|
| 1 | +import argparse |
| 2 | +import io |
| 3 | +import json |
| 4 | +import random |
| 5 | + |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
| 8 | +from pyarrow.parquet import ParquetFile |
| 9 | +import shap |
| 10 | +import tensorflow as tf |
| 11 | + |
| 12 | +BACKGROUND_SIZE = 1000 |
| 13 | +SAMPLE_SIZE = 500 |
| 14 | + |
| 15 | + |
| 16 | +# Implements reservoir sampling |
| 17 | +def update_sample(samples, N, sample): |
| 18 | + if sample is None: |
| 19 | + return |
| 20 | + |
| 21 | + if len(samples) < BACKGROUND_SIZE: |
| 22 | + samples.append(str(sample)) |
| 23 | + else: |
| 24 | + s = int(random.random() * N) |
| 25 | + if s < BACKGROUND_SIZE: |
| 26 | + samples[s] = str(sample) |
| 27 | + |
| 28 | + |
| 29 | +class_names = np.load("classes.npy", allow_pickle=True) |
| 30 | +parser = argparse.ArgumentParser() |
| 31 | +parser.add_argument("class_name", choices=class_names) |
| 32 | +args = parser.parse_args() |
| 33 | + |
| 34 | +# Get indexes of samples matching the given |
| 35 | +# class in the first SAMPLE_SIZE values |
| 36 | +pq_labels = ParquetFile("../sherlock-project/data/data/raw/train_labels.parquet") |
| 37 | +class_idx = list( |
| 38 | + np.where( |
| 39 | + pq_labels.read(columns=["type"]).columns[0].to_numpy()[:SAMPLE_SIZE] |
| 40 | + == args.class_name |
| 41 | + )[0] |
| 42 | +) |
| 43 | + |
| 44 | +# See https://github.com/slundberg/shap/issues/1406 |
| 45 | +shap.explainers._deep.deep_tf.op_handlers[ |
| 46 | + "AddV2" |
| 47 | +] = shap.explainers._deep.deep_tf.passthrough |
| 48 | + |
| 49 | +# Load the trained model |
| 50 | +model = tf.keras.models.model_from_json(open("nn_model_sherlock.json").read()) |
| 51 | +model.load_weights("nn_model_weights_sherlock.h5") |
| 52 | + |
| 53 | +# Produce a randomly sample of background from the training data |
| 54 | +background = [] |
| 55 | +for (i, line) in enumerate(open("preprocessed_train.txt")): |
| 56 | + update_sample(background, i, line) |
| 57 | + |
| 58 | +matrix = np.loadtxt(io.StringIO("".join(background))) |
| 59 | +del background |
| 60 | + |
| 61 | +# Load sample values matching the given class |
| 62 | +sample = np.loadtxt(open("preprocessed_train.txt", "r"), max_rows=SAMPLE_SIZE)[ |
| 63 | + class_idx, : |
| 64 | +] |
| 65 | + |
| 66 | +# Use SHAP to create a summary plot |
| 67 | +e = shap.DeepExplainer(model, matrix) |
| 68 | +shap_values = e.shap_values(sample) |
| 69 | +feature_names = [l.strip() for l in open("pattern_ids.txt")] |
| 70 | +shap.summary_plot( |
| 71 | + shap_values, sample, class_names=class_names, feature_names=feature_names |
| 72 | +) |
| 73 | +plt.savefig("shap.png") |
0 commit comments