
Commit a971cfd

merge

2 parents dce4ea8 + 2d65405

File tree

5 files changed, +983 -503 lines changed

images/python-autocomplete.png

227 KB

notebooks/evaluate.ipynb

+207-28
@@ -1,5 +1,44 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![Github](https://img.shields.io/github/stars/lab-ml/python_autocomplete?style=social)](https://github.com/lab-ml/python_autocomplete)\n",
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/python_autocomplete/blob/master/notebooks/evaluate.ipynb)\n",
+    "\n",
+    "# Evaluate a model trained on predicting Python code\n",
+    "\n",
+    "This notebook evaluates a model trained on Python code.\n",
+    "\n",
+    "Here's a link to the [training notebook](https://github.com/lab-ml/python_autocomplete/blob/master/notebooks/train.ipynb)\n",
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/python_autocomplete/blob/master/notebooks/train.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Install dependencies"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%capture\n",
+    "!pip install labml labml_python_autocomplete"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Imports"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
@@ -22,25 +61,31 @@
     "from python_autocomplete.evaluate import evaluate, anomalies, complete, Predictor"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We load the model from a training run. For this demo I'm loading from a run I trained at home.\n",
+    "\n",
+    "[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=39b03a1e454011ebbaff2b26e3148b3d)\n",
+    "\n",
+    "*If you want to try this on Colab, you need to run it in the same workspace where you ran the training, because models are saved locally.*"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'39b03a1e454011ebbaff2b26e3148b3d'"
-      ]
-     },
-     "execution_count": 2,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "TRAINING_RUN_UUID = '39b03a1e454011ebbaff2b26e3148b3d'\n",
-    "TRAINING_RUN_UUID"
+    "TRAINING_RUN_UUID = '39b03a1e454011ebbaff2b26e3148b3d'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We initialize the `Configs` object defined in [`train.py`](https://github.com/lab-ml/python_autocomplete/blob/master/python_autocomplete/train.py)."
    ]
   },
   {
@@ -49,10 +94,32 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "conf = Configs()\n",
+    "conf = Configs()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Create a new experiment in evaluation mode. In evaluation mode a new training run is not created."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "experiment.evaluate()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Load custom configurations/hyper-parameters used in the training run."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 4,
@@ -78,18 +145,15 @@
     }
    ],
    "source": [
-    "conf_dict = experiment.load_configs(TRAINING_RUN_UUID)\n",
-    "conf_dict"
+    "custom_conf = experiment.load_configs(TRAINING_RUN_UUID)\n",
+    "custom_conf"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 5,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "conf_dict['device.cuda_device'] = 1\n",
-    "# conf_dict['device.use_cuda'] = False"
+    "Set the custom configurations"
    ]
   },
   {
@@ -111,7 +175,14 @@
     }
    ],
    "source": [
-    "experiment.configs(conf, conf_dict)"
+    "experiment.configs(conf, custom_conf)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set models for saving and loading. This will load `conf.model` from the specified run."
    ]
   },
   {
@@ -150,6 +221,13 @@
     "experiment.add_pytorch_models({'model': conf.model})"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Specify which run to load from"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 8,
@@ -159,6 +237,13 @@
     "experiment.load(TRAINING_RUN_UUID)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Start the experiment"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 9,
@@ -198,16 +283,47 @@
     "experiment.start()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Initialize the `Predictor` defined in [`evaluate.py`](https://github.com/lab-ml/python_autocomplete/blob/master/python_autocomplete/evaluate.py).\n",
+    "\n",
+    "We load `stoi` and `itos` from cache so that we don't have to read the dataset to generate them. `stoi` maps each character to an integer index and `itos` maps each integer index back to a character. These indexes are used in the model embeddings for each character."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
-    "p = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))\n",
+    "p = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set the model to evaluation mode"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "_ = conf.model.eval()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A Python prompt to test completion."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 11,
@@ -228,6 +344,13 @@
     " n_layers int):\"\"\""
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Get a token. `get_token` predicts character by character greedily (no beam search) until it finds an end-of-token character (a non-alphanumeric character)."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 12,
@@ -250,6 +373,13 @@
     "print('\"' + res + '\"')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Try another token"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 13,
@@ -264,11 +394,17 @@
     }
    ],
    "source": [
-    "PROMPT += res\n",
-    "res = p.get_token(PROMPT)\n",
+    "res = p.get_token(PROMPT + res)\n",
     "print('\"' + res + '\"')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Load a sample Python file to test our model"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 14,
@@ -293,6 +429,22 @@
     "print(sample[-50:])"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Test the model on a sample Python file\n",
+    "\n",
+    "The `evaluate` function defined in\n",
+    "[`evaluate.py`](https://github.com/lab-ml/python_autocomplete/blob/master/python_autocomplete/evaluate.py)\n",
+    "predicts token by token using the `Predictor` and simulates editor autocompletion.\n",
+    "\n",
+    "Colors:\n",
+    "* <span style=\"color:yellow\">yellow</span>: the predicted token is wrong and the user needs to type that character.\n",
+    "* <span style=\"color:blue\">blue</span>: the predicted token is correct and the user selects it with a special key press, such as TAB or ENTER.\n",
+    "* <span style=\"color:green\">green</span>: autocompleted characters based on the prediction"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 15,
@@ -434,6 +586,26 @@
     "evaluate(p, sample)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`accuracy` is the fraction of characters predicted correctly. `key_strokes` is the number of key strokes required to write the code with the help of the model, and `length` is the number of characters in the code, that is, the number of key strokes required without the model.\n",
+    "\n",
+    "*Note that this sample is a classic MNIST example, and the model must have overfitted to similar code (except for its use of [LabML](https://github.com/lab-ml/labml) 😛).*"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Test anomalies in code\n",
+    "\n",
+    "We run the model through the same sample code and visualize the probability of predicting each character.\n",
+    "<span style=\"color:green\">green</span> means the probability of that character is high and\n",
+    "<span style=\"color:red\">red</span> means the probability is low."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 16,
@@ -563,6 +735,13 @@
     "anomalies(p, sample)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here we try to autocomplete 100 characters"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 17,
@@ -633,7 +812,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.7.5"
   }
  },
  "nbformat": 4,
