 
 use std::collections::HashSet;
 
-use fancy_regex::Regex;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::PyResult;
 use pyo3::types::{PyBytes, PyList, PyTuple};
 use rustc_hash::FxHashMap as HashMap;
 
-use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS};
+use tiktoken::core::{byte_pair_encode, CoreBPE};
 
 #[pyclass]
 pub struct PyCoreBPE {
@@ -26,47 +25,9 @@ impl PyCoreBPE {
         special_tokens_encoder: HashMap<String, usize>,
         pattern: &str,
     ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
-            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
-
-        let special_regex = {
-            let _parts = special_tokens_encoder
-                .keys()
-                .map(|s| fancy_regex::escape(s))
-                .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|"))
-                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
-        };
-
-        let decoder: HashMap<usize, Vec<u8>> =
-            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-
-        assert!(
-            encoder.len() == decoder.len(),
-            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-        );
-
-        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
-            .iter()
-            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
-            .collect();
-
-        // Clone because I don't know how to tell Rust I'm not going to change the map
-        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
-        sorted_token_bytes.sort();
-
-        let core_bpe = CoreBPE {
-            encoder,
-            special_tokens_encoder,
-            decoder,
-            special_tokens_decoder,
-            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
-            special_regex_tls: (0..MAX_NUM_THREADS)
-                .map(|_| special_regex.clone())
-                .collect(),
-            sorted_token_bytes,
-        };
-        Ok(PyCoreBPE { core_bpe })
+        CoreBPE::new(encoder, special_tokens_encoder, pattern)
+            .map(|core_bpe| PyCoreBPE { core_bpe })
+            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
     }
 
     // ====================
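
Note: the leftover `println!("encoder: {:?}", encoder);` debug statement has been dropped from the added lines (it would dump the entire token map on every construction), and the hunk headers are adjusted accordingly. The construction logic deleted above (pattern compilation, the special-token alternation, decoder inversion, per-thread regex clones, sorted token bytes) presumably now lives behind `tiktoken::core::CoreBPE::new`. A minimal sketch of what such a fallible constructor plausibly looks like, reusing the field names from the deleted code; the `fancy_regex::Error` return type and the `MAX_NUM_THREADS` value are assumptions, not the crate's confirmed API:

```rust
use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;

const MAX_NUM_THREADS: usize = 128; // assumed value

pub struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    /// Fallible constructor: a bad pattern becomes an `Err` the caller can
    /// map to a Python `ValueError`, instead of failing inside the binding.
    pub fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> Result<Self, fancy_regex::Error> {
        let regex = Regex::new(pattern)?;
        // One alternation matching any special token, metacharacters escaped.
        let parts: Vec<_> = special_tokens_encoder
            .keys()
            .map(|s| fancy_regex::escape(s))
            .collect();
        let special_regex = Regex::new(&parts.join("|"))?;
        // Invert the encoder; a length mismatch means duplicate token indices.
        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
        assert!(encoder.len() == decoder.len(), "duplicate token indices in encoder");
        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();
        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS).map(|_| special_regex.clone()).collect(),
            sorted_token_bytes,
        })
    }
}
```

Returning `Result` is what lets the new call site chain `.map(...)` and `.map_err(...)` instead of carrying PyO3 error types into the core crate.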
@@ -82,30 +43,7 @@ impl PyCoreBPE {
     }
 
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
-        py.allow_threads(|| {
-            match std::str::from_utf8(bytes) {
-                Ok(text) => self.core_bpe._encode_ordinary_native(text),
-                Err(e) => {
-                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
-                    let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new());
-                    let (mut tokens, last_piece_token_len) =
-                        self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len);
-                    if !tokens.is_empty() && last_piece_token_len > 0 {
-                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes =
-                            self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
-                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
-
-                        tokens.truncate(tokens.len() - last_piece_token_len);
-                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder));
-                    }
-                    tokens
-                }
-            }
-        })
+        py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
     }
 
     fn encode_with_unstable(
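
The deleted `_encode_bytes` body, which encodes the valid UTF-8 prefix with the regex-based encoder and then byte-pair-encodes the trailing invalid bytes, presumably moved onto `CoreBPE` in `tiktoken::core`; the binding now only releases the GIL. A standalone sketch of the recovery split that code performs (`split_at_valid_utf8` is a hypothetical helper, not part of the crate):

```rust
/// Split a byte string at the end of its valid UTF-8 prefix, so the prefix
/// can go through the regex-based encoder and the raw tail through plain BPE.
fn split_at_valid_utf8(bytes: &[u8]) -> (&str, &[u8]) {
    match std::str::from_utf8(bytes) {
        Ok(text) => (text, &[]),
        Err(e) => {
            let valid_len = e.valid_up_to();
            // Safe: `from_utf8` has just verified this prefix.
            let prefix = unsafe { std::str::from_utf8_unchecked(&bytes[..valid_len]) };
            (prefix, &bytes[valid_len..])
        }
    }
}

fn main() {
    // "hé" followed by a lone continuation byte.
    let bytes = b"h\xc3\xa9\xa9";
    let (prefix, tail) = split_at_valid_utf8(bytes);
    assert_eq!(prefix, "hé");
    assert_eq!(tail, &[0xa9][..]);
}
```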
@@ -181,7 +119,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
 mod tests {
     use rustc_hash::FxHashMap as HashMap;
 
-    use crate::tiktoken::byte_pair_split;
+    use tiktoken::core::byte_pair_split;
 
     #[test]
     fn very_simple_test() {
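
For context, `byte_pair_split` (now imported from `tiktoken::core`) splits a byte piece along BPE merge boundaries. A self-contained sketch of that style of rank-based greedy merging; `bpe_split_sketch` is illustrative, not the crate's actual implementation:

```rust
use rustc_hash::FxHashMap as HashMap;

/// Greedy rank-based BPE: start from single bytes and repeatedly merge the
/// adjacent pair whose concatenation has the lowest rank, until none is left.
fn bpe_split_sketch(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<Vec<u8>> {
    let mut parts: Vec<Vec<u8>> = piece.iter().map(|&b| vec![b]).collect();
    loop {
        // Find the lowest-ranked mergeable adjacent pair, if any.
        let mut best: Option<(usize, usize)> = None; // (rank, index)
        for i in 0..parts.len().saturating_sub(1) {
            let mut merged = parts[i].clone();
            merged.extend_from_slice(&parts[i + 1]);
            if let Some(&rank) = ranks.get(&merged) {
                if best.map_or(true, |(r, _)| rank < r) {
                    best = Some((rank, i));
                }
            }
        }
        match best {
            Some((_, i)) => {
                // Merge parts[i] and parts[i + 1].
                let next = parts.remove(i + 1);
                parts[i].extend_from_slice(&next);
            }
            None => return parts,
        }
    }
}

fn main() {
    // Mirrors `very_simple_test`: with ranks {"ab": 1, "cd": 2},
    // "abcd" splits into ["ab", "cd"].
    let mut ranks = HashMap::default();
    ranks.insert(b"ab".to_vec(), 1);
    ranks.insert(b"cd".to_vec(), 2);
    assert_eq!(
        bpe_split_sketch(b"abcd", &ranks),
        vec![b"ab".to_vec(), b"cd".to_vec()]
    );
}
```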