diff --git a/tests/generate/tokenizer_adapter_test.py b/tests/generate/tokenizer_adapter_test.py
index e7c541d7..1a11bf82 100644
--- a/tests/generate/tokenizer_adapter_test.py
+++ b/tests/generate/tokenizer_adapter_test.py
@@ -52,6 +52,57 @@ def test_default_tokenizer(self):
     self.assertEqual(tokenizer.tokenizer_type, 'sentencepiece')
     self.assertIsNotNone(tokenizer._tokenizer_type, adapter.TokenizerType.SP)
 
+  def test_tokenize_hf(self):
+    model = 'meta-llama/Meta-Llama-3-8B-Instruct'
+    tokenizer = adapter.Tokenizer(
+        tokenizer_type='huggingface', tokenizer_path=model
+    )
+
+    text = 'hello world'
+    tokens = tokenizer.tokenize(text, add_eos=True)
+
+    # Check BOS
+    self.assertEqual(tokens[0], tokenizer.bos_id())
+    # Check EOS
+    self.assertEqual(tokens[-1], tokenizer.eos_id())
+
+    # Check no double BOS (assuming "hello" doesn't tokenize to BOS)
+    self.assertNotEqual(tokens[1], tokenizer.bos_id())
+
+    # Test add_bos=False
+    tokens_no_bos = tokenizer.tokenize(text, add_bos=False, add_eos=True)
+    self.assertNotEqual(tokens_no_bos[0], tokenizer.bos_id())
+    self.assertEqual(tokens_no_bos[-1], tokenizer.eos_id())
+
+    # Decode back. TokenizerAdapter.decode forwards to the underlying HF
+    # tokenizer's decode without extra kwargs, so special tokens may still
+    # appear in the decoded output; just do a rough sanity check.
+    decoded = tokenizer.decode(tokens)
+    self.assertIn(text, decoded)
+    self.assertTrue(len(tokens) > 2)
+
+  def test_special_eos_token(self):
+    model = 'meta-llama/Meta-Llama-3-8B-Instruct'
+    # Use a token that definitely exists and is different from default EOS
+    special_eos = 'world'
+    tokenizer = adapter.Tokenizer(
+        tokenizer_type='huggingface',
+        tokenizer_path=model,
+        special_eos_token=special_eos,
+    )
+
+    text = 'hello'
+    tokens = tokenizer.tokenize(text, add_eos=True)
+
+    # Check EOS is the special token
+    self.assertEqual(tokens[-1], tokenizer.eos_id())
+
+    # Verify the ID matches what we expect from the tokenizer directly
+    hf_tokenizer = tokenizer.tokenizer
+    expected_id = hf_tokenizer.convert_tokens_to_ids(special_eos)
+    self.assertEqual(tokenizer.eos_id(), expected_id)
+    self.assertNotEqual(tokenizer.eos_id(), hf_tokenizer.eos_token_id)
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/tunix/generate/tokenizer_adapter.py b/tunix/generate/tokenizer_adapter.py
index edb134ff..bbf6d53e 100644
--- a/tunix/generate/tokenizer_adapter.py
+++ b/tunix/generate/tokenizer_adapter.py
@@ -20,7 +20,6 @@
 
 from etils import epath
 import numpy as np
-
 import sentencepiece as spm
 
 
@@ -219,6 +218,7 @@
       add_bos: bool | None = True,
       add_eos: bool | None = True,
       hf_access_token: str | None = None,
+      special_eos_token: str | None = None,
   ):
     self.tokenizer_type = tokenizer_type
 
@@ -248,11 +248,26 @@
       raise ValueError(f'Unsupported tokenizer_type: {tokenizer_type}')
     super().__init__(tokenizer)
 
+    self._special_eos_id = None
+    if special_eos_token:
+      if tokenizer_type == 'huggingface':
+        self._special_eos_id = tokenizer.convert_tokens_to_ids(
+            special_eos_token
+        )
+      elif tokenizer_type == 'sentencepiece':
+        self._special_eos_id = tokenizer.PieceToId(special_eos_token)
+
+  def eos_id(self) -> int:
+    if self._special_eos_id is not None:
+      return self._special_eos_id
+    return super().eos_id()
+
   def tokenize(
       self,
       example: str,
       prefix: str = '',
       suffix: str = '',
+      add_bos: bool = True,
       add_eos: bool = True,
   ) -> np.ndarray:
     """The tokenization function.
@@ -261,6 +276,8 @@
       example: Input string to tokenize.
       prefix: Prefix to add to the input string.
       suffix: Suffix to add to the input string.
+      add_bos: If True, add a "beginning of sentence" token at the start of the
+        output sequence.
       add_eos: If True, add an "end of sentence" token at the end of the
         output sequence.
 
@@ -268,15 +285,25 @@
       Tokens corresponding to the input string.
     """
     int_list = []
-    if self.bos_id():
-      int_list.append(self.bos_id())
-    if self.tokenizer_type == 'huggingface':
-      int_list.extend(
-          self.encode(prefix + example + suffix, add_special_tokens=False)
-      )
+    bos_id = self.bos_id()
+    if add_bos and bos_id is not None:
+      int_list.append(bos_id)
+
+    text = prefix + example + suffix
+    if self._tokenizer_type == TokenizerType.HF:
+      int_list.extend(self.encode(text, add_special_tokens=False))
     else:
-      # sentencepiece
-      int_list.extend(self.tokenizer.EncodeAsIds(prefix + example + suffix))
+      int_list.extend(self.encode(text))
+
+    if bos_id is not None:
+      if add_bos:
+        # Deduplicate BOS tokens if the tokenizer added one and we added one.
+        int_list = self.dedup_bos_ids(int_list)
+      else:
+        # Remove BOS token if the tokenizer added one.
+        while int_list and int_list[0] == bos_id:
+          int_list.pop(0)
+
     if add_eos:
       int_list.append(self.eos_id())
     return np.array(int_list, dtype=np.int32)
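For reference, a minimal usage sketch (not part of the patch) of the new `special_eos_token` and `add_bos` options. It assumes the adapter is importable as `tunix.generate.tokenizer_adapter`, reuses the gated model name from the test file (so an `hf_access_token` may be required), and picks `<|eot_id|>` purely as an illustrative override token.

```python
# Illustrative sketch only; mirrors the behavior added in this diff.
from tunix.generate import tokenizer_adapter as adapter

tok = adapter.Tokenizer(
    tokenizer_type='huggingface',
    tokenizer_path='meta-llama/Meta-Llama-3-8B-Instruct',
    special_eos_token='<|eot_id|>',  # eos_id() now returns this token's id
)

# add_bos=False strips any BOS the underlying tokenizer might add;
# add_eos=True appends the (overridden) EOS id as the last token.
ids = tok.tokenize('hello world', add_bos=False, add_eos=True)
assert ids[0] != tok.bos_id()
assert ids[-1] == tok.eos_id()
```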