
Commit 043e4c7

Merge branch 'master' of github.com:lab-ml/python_autocomplete
2 parents 03915d7 + 6ab0b1d

17 files changed: +2188 −956 lines

notebooks/evaluate.ipynb (+47 −486)

Large diffs are not rendered by default.

notebooks/evaluate_old.ipynb (+944)

Large diffs are not rendered by default.

notebooks/highlight.ipynb (+189)

Large diffs are not rendered by default.

python_autocomplete/bundle.py (+6 −2)
@@ -1,5 +1,9 @@
 from labml import experiment, lab
 
 if __name__ == '__main__':
-    experiment.save_bundle(lab.get_path() / 'bundle.tar.gz', '39b03a1e454011ebbaff2b26e3148b3d',
-                           data_files=['cache/itos.json', 'cache/n_tokens.json', 'cache/stoi.json'])
+    experiment.save_bundle(lab.get_path() / 'bundle.tar.gz', 'a6cff3706ec411ebadd9bf753b33bae6',
+                           data_files=['cache/itos.json',
+                                       'cache/n_tokens.json',
+                                       'cache/stoi.json',
+                                       'cache/bpe.json',
+                                       ])
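
The new 'cache/bpe.json' entry matches the cache_set('bpe', ...) call added to bpe.py further down in this commit. A rough sketch of reading that cache file back, assuming it is plain JSON under a cache/ directory (the exact location is managed by labml and is an assumption here, not part of the diff):

import json
from pathlib import Path

# Illustrative only: the cache directory path is an assumption, not part of this commit.
cache_dir = Path('data/cache')

with open(cache_dir / 'bpe.json') as f:
    bpe_cache = json.load(f)

# Keys written by cache_set('bpe', ...) in dataset/bpe.py: char_itos, char_stoi, bpe.
print(list(bpe_cache))
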
python_autocomplete/dataset/__init__.py (+20)

@@ -0,0 +1,20 @@
+import string
+from typing import Dict, List, Tuple
+
+ID_CHARS = set(string.ascii_letters + string.digits + '_')
+
+
+class Tokenizer:
+    n_tokens: int
+    itos: List[str]
+    stoi: Dict[str, int]
+    is_trained: int
+
+    def encode(self, data: str, *, is_silent: bool = True):
+        raise NotImplementedError
+
+    def train(self, data: str):
+        pass
+
+    def rstrip(self, data: str) -> Tuple[str, List[int]]:
+        return data, self.encode(data)
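
The Tokenizer base class above is the interface the rest of the pipeline codes against: a vocabulary (n_tokens, itos, stoi), encode, an optional train step, and rstrip for trimming a possibly incomplete trailing token. As a rough illustration of that contract (hypothetical, not part of this commit), a character-level implementation could look like this:

from typing import Dict, List

from python_autocomplete.dataset import Tokenizer


class CharTokenizer(Tokenizer):
    # Hypothetical example subclass; not included in this commit.

    def __init__(self):
        self.itos: List[str] = []
        self.stoi: Dict[str, int] = {}
        self.n_tokens = 0
        self.is_trained = False

    def train(self, data: str):
        # Build the vocabulary from the characters seen in the training data.
        self.itos = sorted(set(data))
        self.stoi = {c: i for i, c in enumerate(self.itos)}
        self.n_tokens = len(self.itos)
        self.is_trained = True

    def encode(self, data: str, *, is_silent: bool = True):
        # Map each character to its id, skipping characters outside the vocabulary.
        return [self.stoi[c] for c in data if c in self.stoi]
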

python_autocomplete/bpe.py renamed to python_autocomplete/dataset/bpe.py (+122 −83)
@@ -1,116 +1,146 @@
-import string
+from functools import lru_cache
 from heapq import heappush, heappop
-from typing import List, Tuple
+from typing import List
 
 from labml import lab, monit
+from labml.utils.cache import cache_set
+from python_autocomplete.dataset import Tokenizer
+from python_autocomplete.dataset.break_words import SourceCodeTokenizer
 
-ID_CHARS = set(string.ascii_letters + string.digits + '_')
 
+class BPE(Tokenizer):
+    def __init__(self, bpe_en_de: 'BPEEnDe', word_tokenizer):
+        self.bpe = bpe_en_de
+        self.word_tokenizer = word_tokenizer
+        self.is_trained = True
 
-class BPE:
-    def __init__(self):
-        self.char_itos = []
-        self.char_stoi = {}
-        self.bpe_itos = []
-        self.bpe = []
-        self.common = {}
+    @property
+    def n_tokens(self):
+        return len(self.bpe.bpe)
 
-        self.bpe_itos = self.calc_bpe_itos()
+    @property
+    def itos(self):
+        return self.bpe.bpe_itos
 
-    def to_char_stoi(self, w: str):
-        return [self.char_stoi[c] for c in w]
+    @property
+    def stoi(self):
+        return self.bpe.bpe_stoi
 
-    def calc_bpe_itos(self):
-        itos = list(self.char_itos)
-        itos += [itos[p1] + itos[p2] for p1, p2 in self.bpe[len(self.char_itos):]]
-        return itos
+    def encode(self, data: str, *, is_silent: bool = True):
+        words = self.word_tokenizer.tokenize(data, is_silent=is_silent)
 
+        res = []
+        for w in monit.iterate('Encode words', words, is_silent=is_silent):
+            res += self.bpe.encode(w)
 
-class Tokenizer:
-    def collect_words(self, data: str):
-        raise NotImplementedError
+        return res
 
-    def get_words(self) -> Tuple[List[str], List[int]]:
-        raise NotImplementedError
+    def __call__(self, data: str):
+        encoded = self.encode(data)
+        return [self.itos[c] for c in encoded]
 
-    def tokenize(self, data: str) -> List[str]:
-        raise NotImplementedError
+    def rstrip(self, data: str):
+        words = self.word_tokenizer.tokenize(data, is_silent=True)
+        words = words[:-1]
+        res = []
+        for w in words:
+            res += self.bpe.encode(w)
 
+        return ''.join(words), res
 
-class SourceCodeTokenizer(Tokenizer):
-    def __init__(self):
-        self.words = {}
 
-    def add_word(self, word):
-        if not word:
-            return
+class _BPEEncoder:
+    def __init__(self, pairs):
+        self.pairs = pairs
+        self.codes = []
+        self.next_idx = []
+        self.prev_idx = []
+        self.heap = []
 
-        if word not in self.words:
-            self.words[word] = 1
-        else:
-            self.words[word] += 1
+    def encode(self, codes: List[int]):
+        self.codes = codes
+        self.next_idx = BPELearner.default_next_pointers(len(codes))
+        self.prev_idx = BPELearner.default_prev_pointers(len(codes))
+        self.heap = []
 
-    def tokenize(self, data: str) -> List[str]:
-        last_idx = 0
-        is_id = False
-        res = []
+        for i in range(len(self.codes) - 1):
+            self.add_pair((self.codes[i], self.codes[i + 1]), i)
 
-        for i, c in monit.enum('Collect words', data):
-            if c in ID_CHARS:
-                if not is_id:
-                    if last_idx < i:
-                        res.append(data[last_idx:i])
-                    last_idx = i
-                    is_id = True
-            else:
-                if is_id:
-                    if last_idx < i:
-                        res.append(data[last_idx:i])
-                    last_idx = i
-                    is_id = False
-
-        if last_idx < len(data):
-            res.append(data[last_idx:])
+        while self.heap:
+            _, idx, pair = heappop(self.heap)
+            self.merge(idx, pair)
 
-        return res
+        return [c for c in self.codes if c != -1]
 
-    def collect_words(self, data: str):
-        last_idx = 0
-        is_id = False
+    def merge(self, p2, pair):
+        p3 = self.next_idx[p2]
+
+        if p3 == -1 or pair[0] != self.codes[p2] or pair[1] != self.codes[p3]:
+            return
 
-        for i, c in monit.enum('Collect words', data):
-            if c in ID_CHARS:
-                if not is_id:
-                    self.add_word(data[last_idx:i])
-                    last_idx = i
-                    is_id = True
-            else:
-                if is_id:
-                    self.add_word(data[last_idx:i])
-                    last_idx = i
-                    is_id = False
+        self.codes[p2] = self.pairs[pair]
+        self.codes[p3] = -1
+        p1 = self.prev_idx[p2]
+        p4 = self.next_idx[p3]
 
-        self.add_word(data[last_idx:])
+        if p1 != -1:
+            self.add_pair((self.codes[p1], self.codes[p2]), p1)
+        self.next_idx[p2] = p4
+        if p4 != -1:
+            self.prev_idx[p4] = p2
+            self.add_pair((self.codes[p2], self.codes[p4]), p2)
 
-    def get_words(self):
-        words_list = [(f, w) for w, f in self.words.items()]
-        words_list.sort(key=lambda x: -x[0])
+    def add_pair(self, pair, idx):
+        if pair not in self.pairs:
+            return
 
-        return [w for _, w in words_list], [f for f, _ in words_list]
+        heappush(self.heap, (self.pairs[pair], idx, pair))
 
 
-class NoTokenizer(Tokenizer):
+class BPEEnDe:
     def __init__(self):
-        self.data = ''
+        self.char_itos = []
+        self.char_stoi = {}
+        self.bpe = []
+        self.popular_words = {}
+
+        self.bpe_itos = []
+        self.bpe_stoi = {}
+        self.pairs = {}
+        self.encoder = None
+
+    def load(self, char_itos, char_stoi, bpe):
+        self.char_itos = char_itos
+        self.char_stoi = char_stoi
+        self.bpe = bpe
+
+        self.calc()
+
+    def set_popular_words(self, popular_words):
+        self.popular_words = popular_words
+
+    def calc(self):
+        self.bpe_itos = self.calc_bpe_itos()
+        self.bpe_stoi = {s: i for i, s in enumerate(self.bpe_itos)}
+        self.pairs = {(p[0], p[1]): c for c, p in enumerate(self.bpe) if not isinstance(p, int)}
 
-    def collect_words(self, data):
-        self.data += data
+        self.encoder = _BPEEncoder(self.pairs)
 
-    def get_words(self):
-        return [self.data], [1]
+    def to_char_stoi(self, w: str):
+        return [self.char_stoi[c] for c in w]
 
-    def tokenize(self, data: str) -> List[str]:
-        return [data]
+    def calc_bpe_itos(self):
+        itos = list(self.char_itos)
+        for p1, p2 in self.bpe[len(self.char_itos):]:
+            itos.append(itos[p1] + itos[p2])
+        return itos
+
+    @lru_cache(1024)
+    def encode(self, word: str):
+        if word in self.popular_words:
+            return self.popular_words[word]
+
+        return self.encoder.encode([self.char_stoi[c] for c in word if c in self.char_stoi])
 
 
 class BPELearner:

@@ -284,7 +314,7 @@ def main():
     path = lab.get_data_path() / 'train.py'
 
     with open(str(path), 'r') as f:
-        data = f.read()[:100_000]
+        data = f.read()
 
     tokenizer = SourceCodeTokenizer()
     tokenizer.collect_words(data)

@@ -295,6 +325,15 @@ def main():
     print(bpe.bpe_itos()[len(bpe.char_itos):])
     print(len(data), bpe.get_length())
 
+    cache_set('bpe', {
+        'char_itos': bpe.char_itos,
+        'char_stoi': bpe.char_stoi,
+        'bpe': bpe.bpe
+    })
+
+    bpe_en_de = BPEEnDe()
+    bpe_en_de.load(bpe.char_itos, bpe.char_stoi, bpe.bpe)
+
 
 if __name__ == '__main__':
     main()
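
Taken together, the rewritten module splits BPE into three parts: BPELearner learns the merges, BPEEnDe holds the learned vocabulary plus a fast _BPEEncoder, and BPE adapts it all to the Tokenizer interface via a word tokenizer. A minimal sketch of how the pieces fit together, using a tiny hand-built merge table; the list layout of the merges (ints for single characters followed by [p1, p2] pairs) is inferred from calc() and calc_bpe_itos() above, and a real run would instead restore char_itos/char_stoi/bpe from the 'bpe' cache written in main():

from python_autocomplete.dataset.bpe import BPE, BPEEnDe
from python_autocomplete.dataset.break_words import SourceCodeTokenizer

# Toy vocabulary: four characters and one learned merge ('d' + 'e' -> 'de').
chars = ['d', 'e', 'f', ' ']
char_stoi = {c: i for i, c in enumerate(chars)}
merges = list(range(len(chars))) + [[char_stoi['d'], char_stoi['e']]]

bpe_en_de = BPEEnDe()
bpe_en_de.load(chars, char_stoi, merges)

tokenizer = BPE(bpe_en_de, SourceCodeTokenizer())
print(tokenizer('def def'))         # expected: ['de', 'f', ' ', 'de', 'f']
print(tokenizer.encode('def def'))  # the same tokens as vocabulary ids
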
python_autocomplete/dataset/break_words.py (+91)

@@ -0,0 +1,91 @@
+from typing import List, Tuple
+
+from labml import monit
+from python_autocomplete.dataset import ID_CHARS
+
+
+class WordTokenizer:
+    def collect_words(self, data: str):
+        raise NotImplementedError
+
+    def get_words(self) -> Tuple[List[str], List[int]]:
+        raise NotImplementedError
+
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
+        raise NotImplementedError
+
+
+class SourceCodeTokenizer(WordTokenizer):
+    def __init__(self):
+        self.words = {}
+
+    def add_word(self, word):
+        if not word:
+            return
+
+        if word not in self.words:
+            self.words[word] = 1
+        else:
+            self.words[word] += 1
+
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
+        last_idx = 0
+        is_id = False
+        res = []
+
+        for i, c in monit.enum('Collect words', data, is_silent=is_silent):
+            if c in ID_CHARS:
+                if not is_id:
+                    if last_idx < i:
+                        res.append(data[last_idx:i])
+                    last_idx = i
+                    is_id = True
+            else:
+                if is_id:
+                    if last_idx < i:
+                        res.append(data[last_idx:i])
+                    last_idx = i
+                    is_id = False
+
+        if last_idx < len(data):
+            res.append(data[last_idx:])
+
+        return res
+
+    def collect_words(self, data: str):
+        last_idx = 0
+        is_id = False
+
+        for i, c in monit.enum('Collect words', data):
+            if c in ID_CHARS:
+                if not is_id:
+                    self.add_word(data[last_idx:i])
+                    last_idx = i
+                    is_id = True
+            else:
+                if is_id:
+                    self.add_word(data[last_idx:i])
+                    last_idx = i
+                    is_id = False
+
+        self.add_word(data[last_idx:])
+
+    def get_words(self):
+        words_list = [(f, w) for w, f in self.words.items()]
+        words_list.sort(key=lambda x: -x[0])
+
+        return [w for _, w in words_list], [f for f, _ in words_list]
+
+
+class NoTokenizer(WordTokenizer):
+    def __init__(self):
+        self.data = ''
+
+    def collect_words(self, data):
+        self.data += data
+
+    def get_words(self):
+        return [self.data], [1]
+
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
+        return [data]
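
SourceCodeTokenizer splits source text into alternating runs of identifier characters (ID_CHARS) and everything else; these word-level pieces are what BPE then encodes one at a time. A quick example of the split, with the expected output worked out from the logic above:

from python_autocomplete.dataset.break_words import SourceCodeTokenizer

tokenizer = SourceCodeTokenizer()
tokens = tokenizer.tokenize('def foo(x):', is_silent=True)
print(tokens)  # ['def', ' ', 'foo', '(', 'x', '):']
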
