# Author: Stefan Wunsch CERN  09/2019

################################################################################
# Copyright (C) 1995-2019, Rene Brun and Fons Rademakers.                      #
# All rights reserved.                                                         #
#                                                                              #
# For the licensing terms see $ROOTSYS/LICENSE.                                #
# For the list of contributors see $ROOTSYS/README/CREDITS.                    #
################################################################################

import json
import os
import tempfile

import cppyy


def get_basescore(model):
    """Get base score from an XGBoost sklearn estimator.

    Copy-pasted from XGBoost unit test code.

    The base score is read from the booster's saved JSON configuration, where it
    is itself stored as a JSON-encoded string. That string holds a bare number in
    older XGBoost releases and a list for XGBoost 3.1.0 and later.

    See also:
      * https://github.com/dmlc/xgboost/blob/2463938/python-package/xgboost/testing/updater.py#L43
      * https://github.com/dmlc/xgboost/issues/9347
      * https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192

    Args:
        model: XGBoost sklearn-style estimator providing ``get_booster()``.

    Returns:
        float: The (single) base score of the model.

    Raises:
        ValueError: If the model stores no base score or several base scores.
    """
    config = json.loads(model.get_booster().save_config())
    jintercept = config["learner"]["learner_model_param"]["base_score"]
    out = json.loads(jintercept)
    # json.loads yields an int for integral literals (e.g. "1"), so accept any
    # real number here, not only float. bool is a subclass of int; exclude it.
    if isinstance(out, (int, float)) and not isinstance(out, bool):
        return float(out)
    # For XGBoost 3.1.0 and after, the value is itself a list.
    # However, we don't support multiple base scores yet.
    if len(out) > 1:
        raise ValueError(
            f"Model contains multiple base scores ({out}). "
            "This typically occurs with XGBoost ≥ 3.1.0, which supports multi-target base scores. "
            "This function only supports a single base score. "
        )
    if not out:
        raise ValueError("Model does not contain a base score.")
    return float(out[0])


def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
    """
    Saves the XGBoost model to a ROOT file as a TMVA::Experimental::RBDT object.

    The booster is dumped in XGBoost's text format to a temporary file. That file
    is parsed with ``TMVA::Experimental::RBDT::LoadText``, and the resulting RBDT
    is written to ``output_path``, which is opened in "RECREATE" mode.

    Args:
        xgb_model: The trained XGBoost model (sklearn-style estimator exposing
            ``objective``, ``get_booster()`` and, for "multi:" objectives,
            ``n_classes_``).
        key_name (str): The name to use for storing the RBDT in the output file.
        output_path (str): The path to save the output file.
        num_inputs (int): The number of input features used in the model. Used
            to generate the feature names ``f0``, ``f1``, ... when the booster
            carries no feature names.

    Raises:
        Exception: If the XGBoost model has an unsupported objective.
        ValueError: If the model stores no or several base scores.
    """
    # Extract objective
    objective_map = {
        "multi:softprob": "softmax",  # Naming the objective softmax is more common today
        "binary:logistic": "logistic",
        "reg:linear": "identity",
        "reg:squarederror": "identity",
    }
    model_objective = xgb_model.objective
    if model_objective not in objective_map:
        raise Exception(
            'XGBoost model has unsupported objective "{}". Supported objectives are {}.'.format(
                model_objective, list(objective_map)
            )
        )
    # Compare plain Python strings; no C++ std::string is needed for this flag.
    logistic = objective_map[model_objective] == "logistic"

    # Determine number of outputs
    num_outputs = xgb_model.n_classes_ if "multi:" in model_objective else 1

    booster = xgb_model.get_booster()

    if booster.feature_names is None:
        feature_names = [f"f{i}" for i in range(num_inputs)]
    else:
        feature_names = booster.feature_names
    features = cppyy.gbl.std.vector["std::string"](feature_names)

    bs = get_basescore(xgb_model)
    # With the logistic objective the base score is passed in logit space,
    # log(bs / (1 - bs)); otherwise it is passed unchanged.
    base_score = cppyy.gbl.std.log(bs / (1.0 - bs)) if logistic else bs

    # Dump the model as text into a temporary directory rather than onto
    # output_path, so a failed parse never leaves a stray text dump behind.
    # LoadText has fully consumed the file by the time it returns, so the
    # directory can be removed before the ROOT file is written.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dump_path = os.path.join(tmp_dir, "xgboost_model_dump.txt")
        booster.dump_model(dump_path)
        bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(
            dump_path,
            features,
            num_outputs,
            logistic,
            base_score,
        )

    with cppyy.gbl.TFile.Open(output_path, "RECREATE") as tFile:
        tFile.WriteObject(bdt, key_name)
