@@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap;
 
 type Rank = u32;
 
-fn _byte_pair_merge(
-    ranks: &HashMap<Vec<u8>, Rank>,
-    piece: &[u8],
-) -> Vec<(usize, Rank)> {
+fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
+    // The rank is of the pair starting at position start.
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+
+    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
+    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
+    // merge priority from token index or to prevent specific token merges.
+    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
+    for i in 0..piece.len() - 1 {
+        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
+        if rank < min_rank.0 {
+            min_rank = (rank, i);
+        }
+        parts.push((i, rank));
+    }
+    parts.push((piece.len() - 1, Rank::MAX));
+    parts.push((piece.len(), Rank::MAX));
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
+        |parts: &Vec<(usize, Rank)>, i: usize| {
+            if (i + 3) < parts.len() {
+                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
+                // parts[i + 1], see comment in the main loop.
+                *ranks
+                    .get(&piece[parts[i].0..parts[i + 3].0])
+                    .unwrap_or(&Rank::MAX)
             } else {
-                None
+                Rank::MAX
             }
         }
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
     // If you have n parts and m merges, this does O(mn) work.
     // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    // n is often very small so considerations like cache-locality outweigh the algorithmic
+    // complexity downsides of the `parts` vector.
+    while min_rank.0 != Rank::MAX {
+        let i = min_rank.1;
+        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
+        // `parts.remove(i + 1)` will thrash the cache.
+        if i > 0 {
+            parts[i - 1].1 = get_rank(&parts, i - 1);
         }
+        parts[i].1 = get_rank(&parts, i);
+        parts.remove(i + 1);
 
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+        min_rank = (Rank::MAX, usize::MAX);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
-
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
-            }
-
-            parts.remove(i + 1);
-        } else {
-            break;
-        }
     }
-
     parts
 }
 
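For reference, here is a sketch of how a caller could consume the (start, rank) boundaries that _byte_pair_merge returns: adjacent starts delimit the surviving tokens, and each window of two entries maps back to a byte slice whose rank is the token id. This is an illustrative sketch, not part of the commit; it assumes it sits in the same module as the code above, so HashMap, Rank, and _byte_pair_merge are already in scope (roughly what a byte_pair_encode-style wrapper would do).

// Illustrative sketch: turn the merge boundaries into token ids.
// Assumes `_byte_pair_merge` from the diff above is in scope in the same module.
fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    assert!(piece.len() > 1);
    _byte_pair_merge(ranks, piece)
        .windows(2)
        // Each pair of adjacent starts delimits one surviving token; its bytes
        // are looked up in `ranks` to get the token id.
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}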
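And a minimal, hypothetical test illustrating the "we hash bytes, not token pairs" note: the ranks table is keyed by raw byte sequences, with made-up rank values doubling as merge priority and token id. It builds on the byte_pair_encode sketch above and is not part of the commit.

#[cfg(test)]
mod merge_sketch_tests {
    use super::*;

    #[test]
    fn merges_toy_piece() {
        // Toy vocabulary keyed by raw byte sequences; the rank values are
        // invented for this example.
        let mut ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"ba".to_vec(), 1);
        ranks.insert(b"abab".to_vec(), 2);

        // "ab" (rank 0) merges first at both positions, then the two "ab"
        // pieces merge into "abab" (rank 2), leaving a single token.
        assert_eq!(byte_pair_encode(b"abab", &ranks), vec![2]);
    }
}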