You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

122 lines
4.6 KiB

import logging
import os
import sys
# Put the "lib" directory that sits next to this script's directory at the
# front of sys.path, so the imports below are looked up there first.
current_path = os.path.dirname(os.path.realpath(__file__))
app_specific_lib_path = os.path.join(current_path, "..", "lib")
sys.path.insert(0, app_specific_lib_path)
from splunkupgrade.utils.logger_utils import (
initialize_logger_for_completion_script,
)
# The logger is initialised here, before the remaining imports are executed.
# NOTE(review): presumably so that logging done while importing the modules
# below is already configured - confirm before reordering these imports.
initialize_logger_for_completion_script()
import splunk.entity as entity
from splunkupgrade.utils.app_conf import RollingUpgradeConfig
from splunkupgrade.utils.complete_upgrade import StandaloneUpgradeCompletion, UpgradeCompletion
from splunkupgrade.utils.splunk_service import SplunkService
from splunkupgrade.upgrader.upgrader_utils import try_fail_upgrade, UpgraderConfig, StatusUpdater
from collections import namedtuple
from typing import Tuple, Optional
from splunkupgrade.data.parsing import get_field, DataParseException
from splunkupgrade.utils.constants import GeneralConstants, SHC_UPGRADE_NOT_SUPPORTED_PEER
from splunkupgrade.utils.server_roles_mapper import ServerRolesMapper
from splunkupgrade.utils.utils import get_env_variable
from splunkupgrade.upgrader.telemetry_utils import (
TelemetryStatus,
telemetry_shc_log,
role_to_telemetry_deployment_type,
TELEMETRY_VERSION_UNKNOWN,
)
# Result record of get_host_settings(): host, management port and server name.
HostPortServername = namedtuple("HostPortServername", ["host", "port", "servername"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def get_host_settings(session_key: str) -> HostPortServername:
    """Read host, management port and server name from the server settings.

    Fetches the ``settings`` entity under ``/server`` with the given session
    key and extracts its ``host``, ``mgmtHostPort`` and ``serverName`` fields.

    Args:
        session_key: Session key passed to ``entity.getEntity``.

    Returns:
        A ``HostPortServername`` tuple; ``port`` is converted to ``int``.

    Raises:
        DataParseException: If ``mgmtHostPort`` cannot be converted to an int.
            NOTE(review): ``get_field`` presumably raises the same exception
            for missing or mistyped fields - confirm against its definition.
    """
    logger.info("Getting server settings")
    ent = entity.getEntity("/server", "settings", sessionKey=session_key, namespace="-", owner="-")
    host = get_field(ent, "host", str)
    port = get_field(ent, "mgmtHostPort", str)
    server_name = get_field(ent, "serverName", str)
    try:
        int_port = int(port)
    except ValueError as err:
        # Chain explicitly so the original conversion error is reported as the
        # direct cause of the parse failure.
        raise DataParseException("Failed to convert management port field to int") from err
    logger.info(f"Host='{host}', port='{int_port}', servername='{server_name}'")
    return HostPortServername(host, int_port, server_name)
def try_getting_versions(service: Optional[SplunkService]) -> Tuple[str, str]:
    """Best-effort lookup of the upgrade's source and target versions.

    Asks ``StatusUpdater(service)`` for the current progress and returns its
    ``from_version`` and ``to_version``. This function never raises: any
    failure is logged and ``TELEMETRY_VERSION_UNKNOWN`` is returned for every
    value that could not be read.

    Args:
        service: Splunk service to query, or ``None`` when the service could
            not be created (``complete_upgrade`` calls this from its failure
            handler, where that can be the case).

    Returns:
        A ``(from_version, to_version)`` tuple.
    """
    from_version = TELEMETRY_VERSION_UNKNOWN
    to_version = TELEMETRY_VERSION_UNKNOWN
    if service is None:
        # Nothing to query; report it explicitly rather than relying on
        # whatever error StatusUpdater would produce for a None service.
        logger.error(
            "Cannot get from_version/to_version from the kvstore: service is not available"
        )
        return from_version, to_version
    try:
        curr_progress = StatusUpdater(service).get_current_progress()
        from_version = curr_progress.from_version
        to_version = curr_progress.to_version
    except Exception as e:
        # Deliberately broad: telemetry data must not break failure handling.
        logger.error(f"Cannot get from_version/to_version from the kvstore: {e}")
    return from_version, to_version
def complete_upgrade() -> None:
    """Run the upgrade completion flow for this instance.

    The flow is:

    1. Build the trigger file path from ``GeneralConstants.SPLUNK_HOME`` and
       ``GeneralConstants.UPGRADE_FILE_RELATIVE_PATH``. If the file does not
       exist, return without doing anything; otherwise remove it.
    2. Read the session key from stdin and use it to fetch the host settings
       and to create a ``SplunkService``.
    3. Pick the completion implementation from the server roles:
       ``StandaloneUpgradeCompletion`` for a deployer-only instance or a
       standalone search head, ``UpgradeCompletion`` for an SHC peer. Any
       other role is logged as unsupported and the function returns.
    4. Run the completion; a falsy result is treated as a failure.

    On any exception the error and a failure telemetry record are logged and,
    if the service was already created, ``try_fail_upgrade`` is called.
    The exception from the main flow is not propagated to the caller.
    """
    logger.info(f"Starting the completion script, pid: {os.getpid()}")
    # Pre-initialised so the failure handler can use them even when the error
    # happens before they are assigned.
    service = None
    servername = None
    role_mapper = None
    try:
        trigger_file = os.path.join(
            get_env_variable(GeneralConstants.SPLUNK_HOME),
            GeneralConstants.UPGRADE_FILE_RELATIVE_PATH,
        )
        logger.info(f"Trigger file path='{trigger_file}'")
        if not os.path.exists(trigger_file):
            logger.info("Trigger file does not exist")
            return
        os.remove(trigger_file)
        # todo: add a configurable timeout for read()
        logger.info("Reading session key from the stdin")
        session_key = sys.stdin.read()
        host, port, servername = get_host_settings(session_key)
        config = RollingUpgradeConfig()
        service = SplunkService.from_session_key(
            host, port, session_key, config.requests_timeout_config
        )
        role_mapper = ServerRolesMapper(service.get_server_roles())
        if role_mapper.is_deployer_only() or role_mapper.is_standalone_search_head():
            completer_cls = StandaloneUpgradeCompletion
        elif role_mapper.is_shc_peer():
            completer_cls = UpgradeCompletion
        else:
            logger.info(SHC_UPGRADE_NOT_SUPPORTED_PEER)
            return
        completer = completer_cls(service, UpgraderConfig(servername, trigger_file, config))
        if not completer.complete_upgrade():
            raise Exception("Completion script failed")
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"Error during upgrade finalisation: {e}")
        try:
            from_version, to_version = try_getting_versions(service)
            logger.error(
                telemetry_shc_log(
                    "SHC Rolling Upgrade failed",
                    -1,
                    [servername],
                    TelemetryStatus.FAILED,
                    role_to_telemetry_deployment_type(role_mapper),
                    "Error during upgrade finalisation",
                    from_version,
                    to_version,
                )
            )
        finally:
            # Marking the upgrade as failed must happen even if building or
            # logging the telemetry record raises (role_mapper and servername
            # may still be None at this point).
            if service is not None:
                try_fail_upgrade(service)
# Run the completion flow only when this file is executed as a script.
if __name__ == "__main__":
    complete_upgrade()

Powered by BW's shoe-string budget.