
Commit 1f06dcb

Correct paths and add SHAP explanations
1 parent 9f20559 commit 1f06dcb

8 files changed, +537 -19 lines changed

.gitignore (+4)

@@ -7,3 +7,7 @@
 pages/
 patterns.json
 *.npy
+*.json
+*.h5
+*.png
+regex101/

Pipfile (+2)

@@ -13,6 +13,8 @@ tqdm = "*"
 pandas = "*"
 scikit-learn = "*"
 tensorflow = "*"
+shap = "*"
+matplotlib = "*"
 
 [dev-packages]

Pipfile.lock (+390 -4)

Some generated files are not rendered by default.

README.md (+38 -2)

@@ -1,21 +1,57 @@
 # Learning from Uncurated Regular Expressions
 
 Dependencies of all Python code are managed with [Pipenv](https://pipenv.pypa.io/en/latest/) and can be installed with `pipenv install`.
-Note that the dataset from the [Sherlock](https://github.com/mitmedialab/sherlock-project) project should be available in a copy of the repository in the same directory as this project.
+Note that the dataset from the [Sherlock](https://github.com/mitmedialab/sherlock-project) project should be available in a copy of the repository alongside the directory for this project.
 [`jq`](https://jqlang.github.io/jq/) is also required for some JSON processing.
 
+## Model training
+
 1. Download all regular expressions from regex101
 
 `./download_patterns.sh`
 
+This will create a directory `regex101`, which has the individual regular expressions, and `patterns.json`, which contains only the expression strings.
+
 2. Compile a database of all the downloaded regular expressions
 
-`pipenv run python compile_db.py < patterns.json`
+`pipenv run python compile_db.py < patterns.json > patterns_final.json`
+
+`patterns_final.json` is a subset of the expressions in `patterns.json` which are supported by Hyperscan.
+This step will also create `hs.db`, which contains the compiled regular expressions that can be used during preprocessing.
 
 3. Preprocess the data to generate feature vectors
 
 `pipenv run python preprocess.py train`
 
+This will generate `preprocessed_train.txt`, which contains all the feature vectors extracted using the regular expressions.
+
 4. Train the model on the extracted features
 
 `pipenv run python train.py`
+
+The model architecture will be stored in `nn_model_sherlock.json` with the weights in `nn_model_weights_sherlock.h5`.
+
+## Evaluation
+
+First, the test data must be preprocessed.
+
+`pipenv run python preprocess.py test`
+
+Then, the model can be evaluated.
+
+`pipenv run python test.py`
+
+## Model explanation
+
+Explanations of predictions for an individual class can be generated using [SHAP](https://shap.readthedocs.io/en/latest/).
+First, follow the steps for training the model above.
+The file `patterns_final.json` will be used to match the patterns back to the original regular expressions.
+
+`pipenv run python find_patterns.py > pattern_ids.txt`
+
+This file of pattern IDs will then be used to label the SHAP plot with the ID of the regular expression.
+To generate the SHAP plot in `shap.png`, run the command below, where `<class_name>` is one of the semantic types defined by Sherlock.
+
+`pipenv run python explain.py <class_name>`
+
+The IDs displayed in the SHAP plot can be used to reference the regular expressions by ID in the `regex101/regexes` directory or to view them directly on regex101 at the URL `https://regex101.com/library/<ID>`.
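For context (not part of the diff), a minimal sketch of turning an ID shown in the SHAP plot back into a regex101 link, assuming `pattern_ids.txt` holds one ID per feature column as described above; the `feature_index` value is purely illustrative:

    # Illustrative only: map a feature column index from the SHAP plot to its
    # regex101 library URL using the IDs written by find_patterns.py.
    feature_index = 42  # hypothetical index read off the SHAP plot
    pattern_ids = [line.strip() for line in open("pattern_ids.txt")]
    print("https://regex101.com/library/%s" % pattern_ids[feature_index])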

compile_db.py (+6 -5)

@@ -20,18 +20,19 @@
     except hyperscan.error as e:
         pass
 
+# Build input for the final Hyperscan database
 db = hyperscan.Database()
-num_patterns = 0
 patterns = []
 ids = []
 flags = []
-for regex in regexes:
+for (i, regex) in enumerate(regexes):
+    print(json.dumps(regex))
     patterns.append(regex.encode("utf8"))
-    ids.append(num_patterns)
+    ids.append(i)
     flags.append(hyperscan.HS_FLAG_SINGLEMATCH | hyperscan.HS_FLAG_UTF8)
-    num_patterns += 1
 
+# Compile the final database and save to file
 sys.stderr.write("Compiling %d patterns...\n" % len(patterns))
 db.compile(expressions=patterns, ids=ids, flags=flags)
 with open("hs.db", "wb") as f:
-    pickle.dump([num_patterns, hyperscan.dumpb(db)], f)
+    pickle.dump([len(patterns), hyperscan.dumpb(db)], f)
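As a rough illustration (not part of this commit), the serialized database written above could be reloaded and used to scan a single value. This sketch assumes the `loadb` and `scan` calls from the same python-hyperscan package that compile_db.py uses; the scanned string is hypothetical:

    import pickle

    import hyperscan

    # Reload the pickled [pattern count, serialized database] pair written by compile_db.py
    with open("hs.db", "rb") as f:
        num_patterns, serialized = pickle.load(f)
    db = hyperscan.loadb(serialized)

    # Collect the id of every pattern that matches one hypothetical value
    matched = []
    def on_match(pat_id, start, end, flags, context):
        matched.append(pat_id)

    db.scan("2023-01-01".encode("utf8"), match_event_handler=on_match)
    print("%d of %d patterns matched" % (len(matched), num_patterns))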

download_patterns.sh (+8 -8)

@@ -1,22 +1,22 @@
 #!/bin/bash
 
 # Download each page of search results
-mkdir -p pages/
-wget "https://regex101.com/api/library/1/?orderBy=MOST_POINTS&search=" -O pages/1.json
-PAGES=$(jq -r .pages pages/1.json)
+mkdir -p regex101/pages/
+wget "https://regex101.com/api/library/1/?orderBy=MOST_POINTS&search=" -O regex101/pages/1.json
+PAGES=$(jq -r .pages regex101/pages/1.json)
 for i in $(seq 2 $PAGES); do
     # Fetch this page of regular expressions
-    wget "https://regex101.com/api/library/$i/?orderBy=MOST_POINTS&search=" -O "pages/$i.json"
+    wget "https://regex101.com/api/library/$i/?orderBy=MOST_POINTS&search=" -O "regex101/pages/$i.json"
     sleep 1
 done
 
 # Extract all fragments from each page to get individual regexes
-mkdir -p regexes/
-jq -cr '.data[] | (.permalinkFragment + " https://regex101.com/api/regex/" + .permalinkFragment + "/" + (.version | tostring))' pages/*.json | \
+mkdir -p regex101/regexes/
+jq -cr '.data[] | (.permalinkFragment + " https://regex101.com/api/regex/" + .permalinkFragment + "/" + (.version | tostring))' regex101/pages/*.json | \
 while read -r frag url; do
     # If the regex has not already been fetched, fetch it
-    [ -f "regexes/$frag.json" ] || (wget -O "regexes/$frag.json" -nc "$url"; sleep 1)
+    [ -f "regex101/regexes/$frag.json" ] || (wget -O "regex101/regexes/$frag.json" -nc "$url"; sleep 1)
 done
 
 # Extract all PCRE regexes without newlines into a file
-jq -c 'select((.flavor == "pcre") and (.regex | contains( "\n") | not)) .regex' ../regex101/regexes/* > patterns.json
+jq -c 'select((.flavor == "pcre") and (.regex | contains( "\n") | not)) .regex' regex101/regexes/* > patterns.json
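As an aside (not part of this commit), the final jq filter above is roughly equivalent to the following Python, which keeps only PCRE-flavored expressions that contain no newline:

    import glob
    import json

    # Rough Python equivalent of the last jq command in download_patterns.sh
    for path in glob.glob("regex101/regexes/*.json"):
        try:
            obj = json.load(open(path))
        except json.decoder.JSONDecodeError:
            continue
        if obj.get("flavor") == "pcre" and "\n" not in obj.get("regex", ""):
            print(json.dumps(obj["regex"]))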

explain.py (+73)

@@ -0,0 +1,73 @@
+import argparse
+import io
+import json
+import random
+
+import matplotlib.pyplot as plt
+import numpy as np
+from pyarrow.parquet import ParquetFile
+import shap
+import tensorflow as tf
+
+BACKGROUND_SIZE = 1000
+SAMPLE_SIZE = 500
+
+
+# Implements reservoir sampling
+def update_sample(samples, N, sample):
+    if sample is None:
+        return
+
+    if len(samples) < BACKGROUND_SIZE:
+        samples.append(str(sample))
+    else:
+        s = int(random.random() * N)
+        if s < BACKGROUND_SIZE:
+            samples[s] = str(sample)
+
+
+class_names = np.load("classes.npy", allow_pickle=True)
+parser = argparse.ArgumentParser()
+parser.add_argument("class_name", choices=class_names)
+args = parser.parse_args()
+
+# Get indexes of samples matching the given
+# class in the first SAMPLE_SIZE values
+pq_labels = ParquetFile("../sherlock-project/data/data/raw/train_labels.parquet")
+class_idx = list(
+    np.where(
+        pq_labels.read(columns=["type"]).columns[0].to_numpy()[:SAMPLE_SIZE]
+        == args.class_name
+    )[0]
+)
+
+# See https://github.com/slundberg/shap/issues/1406
+shap.explainers._deep.deep_tf.op_handlers[
+    "AddV2"
+] = shap.explainers._deep.deep_tf.passthrough
+
+# Load the trained model
+model = tf.keras.models.model_from_json(open("nn_model_sherlock.json").read())
+model.load_weights("nn_model_weights_sherlock.h5")
+
+# Produce a random sample of background from the training data
+background = []
+for (i, line) in enumerate(open("preprocessed_train.txt")):
+    update_sample(background, i, line)
+
+matrix = np.loadtxt(io.StringIO("".join(background)))
+del background
+
+# Load sample values matching the given class
+sample = np.loadtxt(open("preprocessed_train.txt", "r"), max_rows=SAMPLE_SIZE)[
+    class_idx, :
+]
+
+# Use SHAP to create a summary plot
+e = shap.DeepExplainer(model, matrix)
+shap_values = e.shap_values(sample)
+feature_names = [l.strip() for l in open("pattern_ids.txt")]
+shap.summary_plot(
+    shap_values, sample, class_names=class_names, feature_names=feature_names
+)
+plt.savefig("shap.png")
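The `update_sample` routine above follows the standard reservoir sampling idea; a generic, standalone sketch of that algorithm (Algorithm R, not the exact routine committed above) looks like this:

    import random

    def reservoir_sample(stream, k):
        # Keep a uniform random sample of k items from a stream of unknown length
        reservoir = []
        for i, item in enumerate(stream):
            if len(reservoir) < k:
                reservoir.append(item)
            else:
                j = random.randint(0, i)
                if j < k:
                    reservoir[j] = item
        return reservoir

    print(reservoir_sample(range(100000), 5))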

find_patterns.py (+16)

@@ -0,0 +1,16 @@
+import glob
+import json
+
+# Create a dictionary of possible patterns
+pat_dict = {}
+for file in glob.glob("regex101/regexes/*.json"):
+    try:
+        obj = json.load(open(file))
+        pat_dict[obj["regex"]] = file.split("/")[-1].split(".")[0]
+    except json.decoder.JSONDecodeError:
+        pass
+
+# Output the ID of each pattern
+for line in open("patterns_final.json"):
+    pat = json.loads(line)
+    print(pat_dict[pat])
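Since compile_db.py assigns Hyperscan ids by enumerating the same expressions it writes to `patterns_final.json`, line i of `pattern_ids.txt` should name the regex101 entry behind Hyperscan pattern id i. A small sketch (not part of this commit) that zips the two files back together:

    import json

    # Pair each Hyperscan pattern id with its regex101 ID and original expression
    ids = [line.strip() for line in open("pattern_ids.txt")]
    patterns = [json.loads(line) for line in open("patterns_final.json")]
    for hs_id, (regex101_id, pattern) in enumerate(zip(ids, patterns)):
        print(hs_id, regex101_id, pattern)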

0 commit comments
