"""Completion script for the SHC rolling upgrade.

If the upgrade trigger file exists, it is removed, a session key is read from
stdin, and the upgrade is finalised for this instance according to its server
role. Any failure is logged, reported via telemetry, and (when a service
connection was established) the upgrade is marked as failed.
"""
import logging
import os
import sys

# The app's bundled libraries must be importable before any splunkupgrade import.
current_path = os.path.dirname(os.path.realpath(__file__))
app_specific_lib_path = os.path.join(current_path, "..", "lib")
sys.path.insert(0, app_specific_lib_path)

from splunkupgrade.utils.logger_utils import (  # noqa: E402
    initialize_logger_for_completion_script,
)

# Logging is initialised before the remaining imports, deliberately kept in this order.
initialize_logger_for_completion_script()

import splunk.entity as entity  # noqa: E402
from splunkupgrade.utils.app_conf import RollingUpgradeConfig  # noqa: E402
from splunkupgrade.utils.complete_upgrade import (  # noqa: E402
    StandaloneUpgradeCompletion,
    UpgradeCompletion,
)
from splunkupgrade.utils.splunk_service import SplunkService  # noqa: E402
from splunkupgrade.upgrader.upgrader_utils import (  # noqa: E402
    try_fail_upgrade,
    UpgraderConfig,
    StatusUpdater,
)
from collections import namedtuple  # noqa: E402
from typing import Tuple, Optional  # noqa: E402
from splunkupgrade.data.parsing import get_field, DataParseException  # noqa: E402
from splunkupgrade.utils.constants import (  # noqa: E402
    GeneralConstants,
    SHC_UPGRADE_NOT_SUPPORTED_PEER,
)
from splunkupgrade.utils.server_roles_mapper import ServerRolesMapper  # noqa: E402
from splunkupgrade.utils.utils import get_env_variable  # noqa: E402
from splunkupgrade.upgrader.telemetry_utils import (  # noqa: E402
    TelemetryStatus,
    telemetry_shc_log,
    role_to_telemetry_deployment_type,
    TELEMETRY_VERSION_UNKNOWN,
)

# host: str, port: int (management port), servername: str
HostPortServername = namedtuple("HostPortServername", ["host", "port", "servername"])

logger = logging.getLogger(__name__)


def get_host_settings(session_key: str) -> HostPortServername:
    """Read host, management port and server name from the ``/server/settings`` entity.

    :param session_key: Splunk session key used to authenticate the REST call.
    :returns: ``HostPortServername`` with the port converted to ``int``.
    :raises DataParseException: if the management port is not an integer
        (``get_field`` may also raise for missing/mistyped fields).
    """
    logger.info("Getting server settings")
    ent = entity.getEntity(
        "/server", "settings", sessionKey=session_key, namespace="-", owner="-"
    )
    host = get_field(ent, "host", str)
    port = get_field(ent, "mgmtHostPort", str)
    server_name = get_field(ent, "serverName", str)
    try:
        int_port = int(port)
    except ValueError as err:
        # Chain the cause so the offending value stays visible in tracebacks.
        raise DataParseException("Failed to convert management port field to int") from err
    logger.info(f"Host='{host}', port='{int_port}', servername='{server_name}'")
    return HostPortServername(host, int_port, server_name)


def try_getting_versions(service: Optional[SplunkService]) -> Tuple[str, str]:
    """Best-effort lookup of the upgrade's from/to versions for telemetry.

    :param service: connected service, or ``None`` if the failure happened
        before a connection was established.
    :returns: ``(from_version, to_version)``; each falls back to
        ``TELEMETRY_VERSION_UNKNOWN`` when it cannot be determined. Never raises.
    """
    from_version = TELEMETRY_VERSION_UNKNOWN
    to_version = TELEMETRY_VERSION_UNKNOWN
    if service is None:
        # No connection: nothing to query, and no point logging a spurious error.
        return from_version, to_version
    try:
        curr_progress = StatusUpdater(service).get_current_progress()
        from_version = curr_progress.from_version
        to_version = curr_progress.to_version
    except Exception as e:
        logger.error(f"Cannot get from_version/to_version from the kvstore: {e}")
    return from_version, to_version


def complete_upgrade() -> None:
    """Finalise the upgrade on this instance if the trigger file is present.

    Flow: consume (delete) the trigger file, read the session key from stdin,
    connect to the local management port and dispatch on server role:
    deployer-only or standalone search head -> ``StandaloneUpgradeCompletion``;
    SHC peer -> ``UpgradeCompletion``; anything else is logged and skipped.

    Exceptions are not propagated: they are logged, a FAILED telemetry event is
    emitted, and ``try_fail_upgrade`` is called when a service connection exists.
    """
    logger.info(f"Starting the completion script, pid: {os.getpid()}")
    service: Optional[SplunkService] = None
    servername: Optional[str] = None
    role_mapper: Optional[ServerRolesMapper] = None
    try:
        trigger_file = os.path.join(
            get_env_variable(GeneralConstants.SPLUNK_HOME),
            GeneralConstants.UPGRADE_FILE_RELATIVE_PATH,
        )
        logger.info(f"Trigger file path='{trigger_file}'")
        # EAFP: exists()-then-remove() races with a concurrent run; the loser's
        # remove() would raise and be reported as a failed upgrade.
        try:
            os.remove(trigger_file)
        except FileNotFoundError:
            logger.info("Trigger file does not exist")
            return

        # todo: add a configurable timeout for read()
        logger.info("Reading session key from the stdin")
        # Strip the trailing newline/whitespace the caller may write after the key.
        session_key = sys.stdin.read().strip()
        if not session_key:
            raise ValueError("No session key received on stdin")

        host, port, servername = get_host_settings(session_key)
        config = RollingUpgradeConfig()
        service = SplunkService.from_session_key(
            host, port, session_key, config.requests_timeout_config
        )
        role_mapper = ServerRolesMapper(service.get_server_roles())
        upgrader_config = UpgraderConfig(servername, trigger_file, config)

        if role_mapper.is_deployer_only() or role_mapper.is_standalone_search_head():
            completer = StandaloneUpgradeCompletion(service, upgrader_config)
        elif role_mapper.is_shc_peer():
            completer = UpgradeCompletion(service, upgrader_config)
        else:
            logger.info(SHC_UPGRADE_NOT_SUPPORTED_PEER)
            return

        if not completer.complete_upgrade():
            raise Exception("Completion script failed")
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message lost.
        logger.exception(f"Error during upgrade finalisation: {e}")
        from_version, to_version = try_getting_versions(service)
        logger.error(
            telemetry_shc_log(
                "SHC Rolling Upgrade failed",
                -1,
                [servername],
                TelemetryStatus.FAILED,
                role_to_telemetry_deployment_type(role_mapper),
                "Error during upgrade finalisation",
                from_version,
                to_version,
            )
        )
        if service is not None:
            try_fail_upgrade(service)


if __name__ == "__main__":
    complete_upgrade()