diff --git a/python/interpret_community/mimic/mimic_explainer.py b/python/interpret_community/mimic/mimic_explainer.py
index 66d8e91f..083cf0c6 100644
--- a/python/interpret_community/mimic/mimic_explainer.py
+++ b/python/interpret_community/mimic/mimic_explainer.py
@@ -21,7 +21,8 @@
 from ..common.blackbox_explainer import BlackBoxExplainer
 from .model_distill import _model_distill
-from .models import LGBMExplainableModel
+from .models import LGBMExplainableModel, LinearExplainableModel, SGDExplainableModel, \
+    DecisionTreeExplainableModel
 from ..explanation.explanation import _create_local_explanation, _create_global_explanation, \
     _aggregate_global_from_local_explanation, _aggregate_streamed_local_explanations, \
     _create_raw_feats_global_explanation, _create_raw_feats_local_explanation, \
@@ -133,14 +134,19 @@ class MimicExplainer(BlackBoxExplainer):
     :param reset_index: Uses the pandas DataFrame index column as part of the features when training
         the surrogate model.
     :type reset_index: str
+    :param auto_select_explainable_model: Set this to True to have the MimicExplainer automatically
+        select the surrogate model. Four explainable models (LGBMExplainableModel, LinearExplainableModel,
+        SGDExplainableModel and DecisionTreeExplainableModel) are trained and scored, and the
+        best-scoring model is then used to derive the explanations.
+    :type auto_select_explainable_model: bool
     """
-
     @init_tabular_decorator
     def __init__(self, model, initialization_examples, explainable_model, explainable_model_args=None,
                  is_function=False, augment_data=True, max_num_of_augmentations=10, explain_subset=None,
                  features=None, classes=None, transformations=None, allow_all_transformations=False,
                  shap_values_output=ShapValuesOutput.DEFAULT, categorical_features=None,
-                 model_task=ModelTask.Unknown, reset_index=ResetIndex.Ignore, **kwargs):
+                 model_task=ModelTask.Unknown, reset_index=ResetIndex.Ignore,
+                 auto_select_explainable_model=False, **kwargs):
         """Initialize the MimicExplainer.

         :param model: The black box model or function (if is_function is True) to be explained. Also known
@@ -233,6 +239,11 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
             the index when calling predict on the original model. Only use reset_teacher if the index is
             already featurized as part of the data.
         :type reset_index: str
+        :param auto_select_explainable_model: Set this to True to have the MimicExplainer automatically
+            select the surrogate model. Four explainable models (LGBMExplainableModel, LinearExplainableModel,
+            SGDExplainableModel and DecisionTreeExplainableModel) are trained and scored, and the
+            best-scoring model is then used to derive the explanations.
+        :type auto_select_explainable_model: bool
         """
         if transformations is not None and explain_subset is not None:
             raise ValueError("explain_subset not supported with transformations")
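For context while reviewing, here is a minimal usage sketch of the new flag. The data, teacher model, and feature names are illustrative assumptions, not part of this diff; note that `explainable_model` must still be passed and serves as the fallback surrogate:

```python
# Hedged sketch: synthetic data and RandomForestRegressor are illustrative only.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from interpret_community.mimic.mimic_explainer import MimicExplainer
from interpret_community.mimic.models import LGBMExplainableModel

X = np.random.rand(200, 4)
y = X[:, 0] + 2 * X[:, 1]
teacher = RandomForestRegressor(n_estimators=10).fit(X, y)

# LGBMExplainableModel is only the fallback here; the best-scoring of the four
# candidate surrogates is what actually gets used to derive explanations.
explainer = MimicExplainer(teacher, X, LGBMExplainableModel,
                           auto_select_explainable_model=True,
                           features=['f0', 'f1', 'f2', 'f3'])
print(explainer._all_replication_scores)  # e.g. {'lightgbm': 0.99, 'linear': 0.95, ...}
print(explainer._method)                  # method name of the winning surrogate
```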
@@ -250,8 +261,7 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
         wrapped_model, eval_ml_domain = _wrap_model(model, initialization_examples, model_task, is_function)
         super(MimicExplainer, self).__init__(wrapped_model, is_function=is_function,
                                              model_task=eval_ml_domain, **kwargs)
-        if explainable_model_args is None:
-            explainable_model_args = {}
+
         if categorical_features is None:
             categorical_features = []
         self._logger.debug('Initializing MimicExplainer')
@@ -288,7 +298,6 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
             # Index the categorical string columns for training data
             self._column_indexer = initialization_examples.string_index(columns=categorical_features)
             self._one_hot_encoder = None
-            explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features
         else:
             # One-hot-encode categoricals for models that don't support categoricals natively
             self._column_indexer = initialization_examples.string_index(columns=categorical_features)
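The two removals above are intentional: defaulting `explainable_model_args` and injecting the LightGBM categorical-feature key move into the new `_supplement_explainable_model_args` helper in the next hunk, so each auto-select candidate gets a freshly built argument dict. A hedged sketch of why that isolation matters (the surrogate names and dict keys here are illustrative, not the library's constants):

```python
# A shared dict would leak the LightGBM-only key into every candidate's args.
def supplement(surrogate_name, args=None):
    args = {} if args is None else dict(args)  # fresh copy per candidate
    if surrogate_name == 'lightgbm':
        args['categorical_feature'] = [0, 2]   # tree/LGBM-only key, illustrative
    return args

print(supplement('lightgbm'))  # {'categorical_feature': [0, 2]}
print(supplement('linear'))    # {} (no LightGBM key leaks into the linear surrogate)
```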
@@ -304,14 +313,86 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
         if isinstance(training_data, DenseData):
             training_data = training_data.data

+        self._original_eval_examples = None
+        self._allow_all_transformations = allow_all_transformations
+
+        if auto_select_explainable_model:
+            # Train all available surrogate models and record their replication scores
+            explainable_model_list = [LGBMExplainableModel, LinearExplainableModel,
+                                      SGDExplainableModel, DecisionTreeExplainableModel]
+            self._best_replication_score = None
+            self._all_replication_scores = {}
+            for some_explainable_model in explainable_model_list:
+                try:
+                    # Set params for the explainable model
+                    some_args = self._supplement_explainable_model_args(
+                        explainable_model=some_explainable_model,
+                        explainable_model_args={},
+                        categorical_features=categorical_features,
+                        shap_values_output=shap_values_output)
+                    # Train the explainable model
+                    surrogate_model = _model_distill(self.function, some_explainable_model, training_data,
+                                                     original_training_data, some_args)
+                    # Compute the replication score between the teacher model and the surrogate model
+                    surrogate_replication_score = self._get_surrogate_model_replication_measure(
+                        training_data=training_data,
+                        surrogate_model=surrogate_model)
+                    # Store the replication score
+                    self._all_replication_scores[surrogate_model.method] = surrogate_replication_score
+
+                    # Keep track of the best score and the best trained surrogate model
+                    if self._best_replication_score is None or \
+                            surrogate_replication_score > self._best_replication_score:
+                        self.surrogate_model = surrogate_model
+                        self._best_replication_score = surrogate_replication_score
+                except Exception:
+                    pass
+
+        if not auto_select_explainable_model or \
+                (hasattr(self, "_best_replication_score") and self._best_replication_score is None):
+            # If training/scoring of the auto-selected explainable models failed for some reason,
+            # fall back on the user-specified explainable model and train it.
+            explainable_model_args = self._supplement_explainable_model_args(
+                explainable_model=explainable_model,
+                explainable_model_args=explainable_model_args,
+                categorical_features=categorical_features,
+                shap_values_output=shap_values_output)
+
+            self.surrogate_model = _model_distill(
+                self.function, explainable_model, training_data,
+                original_training_data, explainable_model_args)
+
+            surrogate_replication_score = None
+            try:
+                # Compute the replication score between the teacher model and the surrogate model
+                surrogate_replication_score = self._get_surrogate_model_replication_measure(
+                    training_data=training_data,
+                    surrogate_model=self.surrogate_model)
+            except Exception:
+                pass
+            finally:
+                # Store the replication score (None if scoring failed)
+                self._all_replication_scores = {}
+                self._all_replication_scores[self.surrogate_model.method] = surrogate_replication_score
+                self._best_replication_score = surrogate_replication_score
+
+        self._method = self.surrogate_model.method
+
+    def _supplement_explainable_model_args(self, explainable_model, explainable_model_args,
+                                           categorical_features, shap_values_output):
+        if explainable_model_args is None:
+            explainable_model_args = {}
+
+        if explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE and \
+                self._supports_categoricals(explainable_model):
+            explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features
+
         explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag
         if self._supports_shap_values_output(explainable_model):
             explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
-        self.surrogate_model = _model_distill(self.function, explainable_model, training_data,
-                                              original_training_data, explainable_model_args)
-        self._method = self.surrogate_model._method
-        self._original_eval_examples = None
-        self._allow_all_transformations = allow_all_transformations
+
+        return explainable_model_args

     def _supports_categoricals(self, explainable_model):
         return issubclass(explainable_model, LGBMExplainableModel)
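To restate the control flow added above: each candidate surrogate is trained and scored inside a `try`, failures are silently skipped, and the highest replication score wins. If every candidate fails, `_best_replication_score` stays `None` and the fallback branch trains the user-specified surrogate instead. A compact sketch of the selection rule (method names and scores are illustrative):

```python
# Hedged restatement of the selection rule, not library code.
scores = {'lightgbm': 0.97, 'tree': 0.93, 'linear': 0.88}  # 'sgd' failed, so it is absent
best_method = max(scores, key=scores.get) if scores else None
print(best_method)  # 'lightgbm'
```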
@@ -630,6 +711,43 @@ def _load(model, properties):
             mimic.__dict__[MimicSerializationConstants.ALLOW_ALL_TRANSFORMATIONS] = False
         return mimic

+    def _get_surrogate_model_replication_measure(self, training_data, surrogate_model):
+        """Return a metric indicating how well the surrogate model replicates the teacher model.
+
+        :param training_data: The data used to compute the replication metric.
+        :type training_data: numpy.array or pandas.DataFrame or iml.datatypes.DenseData or
+            scipy.sparse.csr_matrix
+        :param surrogate_model: The trained surrogate model.
+        :type surrogate_model: Any
+        :return: A metric indicating how well the surrogate model replicates the behavior of the teacher model.
+        :rtype: float
+        """
+        try:
+            from sklearn.metrics import accuracy_score
+            from sklearn.metrics import r2_score
+            sklearn_metrics_available = True
+        except ImportError:
+            sklearn_metrics_available = False
+
+        if not sklearn_metrics_available:
+            raise Exception(
+                "Cannot compute replication metrics due to missing sklearn metrics package")
+
+        surrogate_model_predictions = surrogate_model.predict(training_data)
+        teacher_model_predictions = self.model.predict(training_data)
+
+        if self.classes is not None:
+            if len(self.classes) > 2:
+                replication_measure = accuracy_score(teacher_model_predictions, surrogate_model_predictions)
+            else:
+                raise Exception("Replication measure is not supported for binary classification")
+        else:
+            if training_data.shape[0] == 1:
+                raise Exception("Replication measure for regression surrogate not supported "
+                                "because of single instance in training data")
+            replication_measure = r2_score(teacher_model_predictions, surrogate_model_predictions)
+        return replication_measure
+
     def __getstate__(self):
         """Influence how MimicExplainer is pickled.
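A runnable illustration of the two metric branches in `_get_surrogate_model_replication_measure`, using synthetic predictions (for intuition only; the values are not taken from the library):

```python
import numpy as np
from sklearn.metrics import accuracy_score, r2_score

# Regression branch: r2_score of surrogate predictions against teacher predictions.
teacher_reg = np.array([0.9, 2.1, 3.2, 3.9])
surrogate_reg = np.array([1.0, 2.0, 3.0, 4.0])
print(r2_score(teacher_reg, surrogate_reg))        # close to 1.0 -> good replication

# Multiclass branch (> 2 classes): accuracy of the surrogate at mimicking the teacher.
teacher_cls = np.array([0, 1, 2, 2])
surrogate_cls = np.array([0, 1, 2, 1])
print(accuracy_score(teacher_cls, surrogate_cls))  # 0.75; binary classification raises instead
```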
diff --git a/python/interpret_community/mimic/models/explainable_model.py b/python/interpret_community/mimic/models/explainable_model.py
index 5fd52feb..29253e21 100644
--- a/python/interpret_community/mimic/models/explainable_model.py
+++ b/python/interpret_community/mimic/models/explainable_model.py
@@ -82,6 +82,11 @@ def explainable_model_type(self):
         """Retrieve the model type."""
         pass

+    @property
+    def method(self):
+        """Return the name of the explainable model."""
+        return self._method
+
     def __getstate__(self):
         """Influence how SGDExplainableModel is pickled.

diff --git a/test/test_mimic_explainer.py b/test/test_mimic_explainer.py
index 48eb2d73..ba144433 100644
--- a/test/test_mimic_explainer.py
+++ b/test/test_mimic_explainer.py
@@ -451,8 +451,13 @@ def test_explain_model_string_classes(self, mimic_explainer):
                                     transformations=feat_pipe)
         global_explanation = explainer.explain_global(X.iloc[:1000])
         assert global_explanation.method == LINEAR_METHOD
+        assert explainer._all_replication_scores is not None
+        assert 'linear' in explainer._all_replication_scores
+        assert explainer._all_replication_scores['linear'] is None
+        assert explainer._best_replication_score is None

-    def test_linear_explainable_model_regression(self, mimic_explainer):
+    @pytest.mark.parametrize('auto_select_explainable_model', [True, False])
+    def test_linear_explainable_model_regression(self, mimic_explainer, auto_select_explainable_model):
         num_features = 3
         x_train = np.array([['a', 'E', 'x'], ['c', 'D', 'y']])
         y_train = np.array([1, 2])
@@ -464,11 +469,29 @@
         explainable_model = LinearExplainableModel
         explainer = mimic_explainer(model.named_steps['regressor'], x_train, explainable_model,
                                     transformations=transformations, augment_data=False,
+                                    auto_select_explainable_model=auto_select_explainable_model,
                                     explainable_model_args={'sparse_data': True}, features=['f1', 'f2', 'f3'])
         global_explanation = explainer.explain_global(x_train)
         assert global_explanation.method == LINEAR_METHOD
-
-    def test_linear_explainable_model_classification(self, mimic_explainer):
+        assert explainer._all_replication_scores is not None
+        assert explainer._best_replication_score is not None
+        if not auto_select_explainable_model:
+            assert 'linear' in explainer._all_replication_scores
+            assert explainer._all_replication_scores['linear'] is not None
+        else:
+            assert 'linear' in explainer._all_replication_scores
+            assert explainer._all_replication_scores['linear'] is not None
+            assert 'sgd' in explainer._all_replication_scores
+            assert explainer._all_replication_scores['sgd'] is not None
+            assert 'lightgbm' in explainer._all_replication_scores
+            assert explainer._all_replication_scores['lightgbm'] is not None
+            assert 'tree' in explainer._all_replication_scores
+            assert explainer._all_replication_scores['tree'] is not None
+
+    @pytest.mark.parametrize('if_multiclass', [True, False])
+    @pytest.mark.parametrize('auto_select_explainable_model', [True, False])
+    def test_linear_explainable_model_classification(self, mimic_explainer, if_multiclass,
+                                                     auto_select_explainable_model):
         n_samples = 100
         n_cat_features = 15
@@ -476,7 +499,12 @@
         cat_features = np.random.choice(['a', 'b', 'c', 'd'], (n_samples, n_cat_features))
         data_x = pd.DataFrame(cat_features, columns=cat_feature_names)

-        data_y = np.random.choice(['0', '1'], n_samples)
+        if if_multiclass:
+            data_y = np.random.choice([0, 1, 2, 3], n_samples)
+            classes = [0, 1, 2, 3]
+        else:
+            data_y = np.random.choice([0, 1], n_samples)
+            classes = [0, 1]

         # prepare feature encoders
         cat_feature_encoders = [OneHotEncoder().fit(cat_features[:, i].reshape(-1, 1)) for i in
                                 range(n_cat_features)]
@@ -498,11 +526,35 @@
                                     explainable_model_args={'sparse_data': True},
                                     augment_data=False,
                                     features=cat_feature_names,
-                                    classes=['0', '1'],
+                                    classes=classes,
+                                    auto_select_explainable_model=auto_select_explainable_model,
                                     transformations=cat_transformations,
                                     model_task=ModelTask.Classification)
         global_explanation = explainer.explain_global(evaluation_examples=data_x)
-        assert global_explanation.method == LINEAR_METHOD
+
+        if if_multiclass:
+            assert explainer._all_replication_scores is not None
+            assert explainer._best_replication_score is not None
+            if not auto_select_explainable_model:
+                assert global_explanation.method == LINEAR_METHOD
+                assert 'linear' in explainer._all_replication_scores
+                assert explainer._all_replication_scores['linear'] is not None
+            else:
+                assert global_explanation.method == LIGHTGBM_METHOD
+                assert 'linear' in explainer._all_replication_scores
+                assert explainer._all_replication_scores['linear'] is not None
+                assert 'sgd' in explainer._all_replication_scores
+                assert explainer._all_replication_scores['sgd'] is not None
+                assert 'lightgbm' in explainer._all_replication_scores
+                assert explainer._all_replication_scores['lightgbm'] is not None
+                assert 'tree' in explainer._all_replication_scores
+                assert explainer._all_replication_scores['tree'] is not None
+        else:
+            assert global_explanation.method == LINEAR_METHOD
+            assert explainer._all_replication_scores is not None
+            assert 'linear' in explainer._all_replication_scores
+            assert explainer._all_replication_scores['linear'] is None
+            assert explainer._best_replication_score is None

     def test_dense_wide_data(self, mimic_explainer):
         # use 6000 rows instead for real performance testing
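The new assertions can be exercised without running the whole suite; one possible invocation (illustrative, assuming the repository root as the working directory):

```python
# Hedged sketch: run only the tests touched by this change.
import pytest

pytest.main(['test/test_mimic_explainer.py',
             '-k', 'test_linear_explainable_model or test_explain_model_string_classes',
             '-q'])
```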