- import string
+ from functools import lru_cache
from heapq import heappush, heappop
- from typing import List, Tuple
+ from typing import List

from labml import lab, monit
+ from labml.utils.cache import cache_set
+ from python_autocomplete.dataset import Tokenizer
+ from python_autocomplete.dataset.break_words import SourceCodeTokenizer

- ID_CHARS = set(string.ascii_letters + string.digits + '_')

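+ # `BPE` adapts a trained byte-pair-encoding vocabulary (`BPEEnDe`) to the
+ # `Tokenizer` interface: text is first split into words by `word_tokenizer`,
+ # then each word is encoded independently into subword token ids.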
+ class BPE(Tokenizer):
+     def __init__(self, bpe_en_de: 'BPEEnDe', word_tokenizer):
+         self.bpe = bpe_en_de
+         self.word_tokenizer = word_tokenizer
+         self.is_trained = True

- class BPE:
-     def __init__(self):
-         self.char_itos = []
-         self.char_stoi = {}
-         self.bpe_itos = []
-         self.bpe = []
-         self.common = {}
+     @property
+     def n_tokens(self):
+         return len(self.bpe.bpe)

-         self.bpe_itos = self.calc_bpe_itos()
+     @property
+     def itos(self):
+         return self.bpe.bpe_itos

-     def to_char_stoi(self, w: str):
-         return [self.char_stoi[c] for c in w]
+     @property
+     def stoi(self):
+         return self.bpe.bpe_stoi

-     def calc_bpe_itos(self):
-         itos = list(self.char_itos)
-         itos += [itos[p1] + itos[p2] for p1, p2 in self.bpe[len(self.char_itos):]]
-         return itos
+     def encode(self, data: str, *, is_silent: bool = True):
+         words = self.word_tokenizer.tokenize(data, is_silent=is_silent)

+         res = []
+         for w in monit.iterate('Encode words', words, is_silent=is_silent):
+             res += self.bpe.encode(w)

- class Tokenizer:
-     def collect_words(self, data: str):
-         raise NotImplementedError
+         return res

-     def get_words(self) -> Tuple[List[str], List[int]]:
-         raise NotImplementedError
+     def __call__(self, data: str):
+         encoded = self.encode(data)
+         return [self.itos[c] for c in encoded]

-     def tokenize(self, data: str) -> List[str]:
-         raise NotImplementedError
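+     # Drop the last word (it may be incomplete at the end of the given text)
+     # and return the retained prefix together with its encoding.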
+     def rstrip(self, data: str):
+         words = self.word_tokenizer.tokenize(data, is_silent=True)
+         words = words[:-1]
+         res = []
+         for w in words:
+             res += self.bpe.encode(w)

+         return ''.join(words), res

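+ # A minimal usage sketch (hypothetical values; assumes a `BPEEnDe` already
+ # loaded, e.g. from the 'bpe' cache written by `main()` below):
+ #
+ #     bpe = BPE(bpe_en_de, SourceCodeTokenizer())
+ #     ids = bpe.encode('def foo():')  # subword token ids
+ #     tokens = bpe('def foo():')      # subword strings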
- class SourceCodeTokenizer(Tokenizer):
-     def __init__(self):
-         self.words = {}

-     def add_word(self, word):
-         if not word:
-             return
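+ # `_BPEEncoder` applies learned merges to one word. The codes form a doubly
+ # linked list via `next_idx`/`prev_idx`, and a min-heap orders candidate
+ # pairs by merge rank, so earlier-learned merges are applied first.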
+ class _BPEEncoder:
+     def __init__(self, pairs):
+         self.pairs = pairs
+         self.codes = []
+         self.next_idx = []
+         self.prev_idx = []
+         self.heap = []

-         if word not in self.words:
-             self.words[word] = 1
-         else:
-             self.words[word] += 1
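+     # Encode one word: seed the heap with all adjacent pairs, then repeatedly
+     # merge the lowest-ranked pair; merged-away positions are marked with -1.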
+     def encode(self, codes: List[int]):
+         self.codes = codes
+         self.next_idx = BPELearner.default_next_pointers(len(codes))
+         self.prev_idx = BPELearner.default_prev_pointers(len(codes))
+         self.heap = []

-     def tokenize(self, data: str) -> List[str]:
-         last_idx = 0
-         is_id = False
-         res = []
+         for i in range(len(self.codes) - 1):
+             self.add_pair((self.codes[i], self.codes[i + 1]), i)

-         for i, c in monit.enum('Collect words', data):
-             if c in ID_CHARS:
-                 if not is_id:
-                     if last_idx < i:
-                         res.append(data[last_idx:i])
-                     last_idx = i
-                     is_id = True
-             else:
-                 if is_id:
-                     if last_idx < i:
-                         res.append(data[last_idx:i])
-                     last_idx = i
-                     is_id = False
-
-         if last_idx < len(data):
-             res.append(data[last_idx:])
+         while self.heap:
+             _, idx, pair = heappop(self.heap)
+             self.merge(idx, pair)

-         return res
+         return [c for c in self.codes if c != -1]

-     def collect_words(self, data: str):
-         last_idx = 0
-         is_id = False
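+     # Merge the pair at positions (p2, p3): p2 takes the merged code, p3 is
+     # tombstoned with -1, and new candidate pairs with the neighbours p1 and
+     # p4 are pushed. Stale heap entries fail the code check and are skipped.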
+     def merge(self, p2, pair):
+         p3 = self.next_idx[p2]
+
+         if p3 == -1 or pair[0] != self.codes[p2] or pair[1] != self.codes[p3]:
+             return

-         for i, c in monit.enum('Collect words', data):
-             if c in ID_CHARS:
-                 if not is_id:
-                     self.add_word(data[last_idx:i])
-                     last_idx = i
-                     is_id = True
-             else:
-                 if is_id:
-                     self.add_word(data[last_idx:i])
-                     last_idx = i
-                     is_id = False
+         self.codes[p2] = self.pairs[pair]
+         self.codes[p3] = -1
+         p1 = self.prev_idx[p2]
+         p4 = self.next_idx[p3]

-         self.add_word(data[last_idx:])
+         if p1 != -1:
+             self.add_pair((self.codes[p1], self.codes[p2]), p1)
+         self.next_idx[p2] = p4
+         if p4 != -1:
+             self.prev_idx[p4] = p2
+             self.add_pair((self.codes[p2], self.codes[p4]), p2)

-     def get_words(self):
-         words_list = [(f, w) for w, f in self.words.items()]
-         words_list.sort(key=lambda x: -x[0])
+     def add_pair(self, pair, idx):
+         if pair not in self.pairs:
+             return

-         return [w for _, w in words_list], [f for f, _ in words_list]
+         heappush(self.heap, (self.pairs[pair], idx, pair))


- class NoTokenizer(Tokenizer):
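+ # `BPEEnDe` holds the trained vocabulary: the character tables, the learned
+ # merge list, and the derived lookup tables used for encoding and decoding.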
+ class BPEEnDe:
    def __init__(self):
-         self.data = ''
+         self.char_itos = []
+         self.char_stoi = {}
+         self.bpe = []
+         self.popular_words = {}
+
+         self.bpe_itos = []
+         self.bpe_stoi = {}
+         self.pairs = {}
+         self.encoder = None
+
+     def load(self, char_itos, char_stoi, bpe):
+         self.char_itos = char_itos
+         self.char_stoi = char_stoi
+         self.bpe = bpe
+
+         self.calc()
+
+     def set_popular_words(self, popular_words):
+         self.popular_words = popular_words
+
+     def calc(self):
+         self.bpe_itos = self.calc_bpe_itos()
+         self.bpe_stoi = {s: i for i, s in enumerate(self.bpe_itos)}
+         self.pairs = {(p[0], p[1]): c for c, p in enumerate(self.bpe) if not isinstance(p, int)}

-     def collect_words(self, data):
-         self.data += data
+         self.encoder = _BPEEncoder(self.pairs)

-     def get_words(self):
-         return [self.data], [1]
+     def to_char_stoi(self, w: str):
+         return [self.char_stoi[c] for c in w]

-     def tokenize(self, data: str) -> List[str]:
-         return [data]
+     def calc_bpe_itos(self):
+         itos = list(self.char_itos)
+         for p1, p2 in self.bpe[len(self.char_itos):]:
+             itos.append(itos[p1] + itos[p2])
+         return itos
+
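+     # Per-word memoisation: identical words share one encoding. Note that
+     # `lru_cache` on a method also keys on `self`, so the cache keeps the
+     # instance alive for its lifetime.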
+     @lru_cache(1024)
+     def encode(self, word: str):
+         if word in self.popular_words:
+             return self.popular_words[word]
+
+         return self.encoder.encode([self.char_stoi[c] for c in word if c in self.char_stoi])


class BPELearner:
@@ -284,7 +314,7 @@ def main():
    path = lab.get_data_path() / 'train.py'

    with open(str(path), 'r') as f:
-         data = f.read()[:100_000]
+         data = f.read()

    tokenizer = SourceCodeTokenizer()
    tokenizer.collect_words(data)
@@ -295,6 +325,15 @@ def main():
    print(bpe.bpe_itos()[len(bpe.char_itos):])
    print(len(data), bpe.get_length())

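+     # Persist the trained vocabulary with labml's cache so a later run can
+     # rebuild a `BPEEnDe` via `load()` without retraining.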
+     cache_set('bpe', {
+         'char_itos': bpe.char_itos,
+         'char_stoi': bpe.char_stoi,
+         'bpe': bpe.bpe
+     })
+
+     bpe_en_de = BPEEnDe()
+     bpe_en_de.load(bpe.char_itos, bpe.char_stoi, bpe.bpe)
+

if __name__ == '__main__':
    main()