diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index e4e950a7b..fdc2200c4 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -613,13 +613,28 @@ def _cleanup_tokens(self, token_ids, token_byte_positions):
             for i,id in enumerate(joint_token_ids):
                 pos += len(self.tokenizer.tokens[id])
                 token_byte_positions.append(pos)
-            
+
             # ugly hack to deal with sentence peice craziness of space hiding after special tokens TODO: figure out how to make this more robust
-            if token_byte_positions[-1] == last_pos + 1 and self.tokenizer.tokens[token_ids[0]] == b'<s>' and self.tokenizer.tokens[token_ids[1]][0:1] == b' ':
-                for i in range(1, len(token_byte_positions)):
-                    token_byte_positions[i] -= 1
+            if (
+                hasattr(self.tokenizer, "_orig_tokenizer")
+                and hasattr(self.tokenizer._orig_tokenizer, "all_special_tokens")
+                and token_byte_positions[-1] > last_pos
+            ):
+                special_tokens = [
+                    bytes(s, encoding="utf8")
+                    for s in self.tokenizer._orig_tokenizer.all_special_tokens
+                ]
+                spaces_hidden = 0
+                for i in range(0, len(token_byte_positions)):
+                    token_byte_positions[i] -= spaces_hidden
+                    if (
+                        self.tokenizer.tokens[token_ids[i]] in special_tokens
+                        and len(token_ids) > i + 1
+                        and self.tokenizer.tokens[token_ids[i + 1]][0:1] == b" "
+                    ):
+                        spaces_hidden += 1
             assert token_byte_positions[-1] == last_pos
-            
+
         return token_ids, token_byte_positions
 
     def get_logits(self, token_ids, forced_bytes, current_temp):
diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py
index 0464ed6c2..8fcfe958f 100644
--- a/guidance/models/transformers/_transformers.py
+++ b/guidance/models/transformers/_transformers.py
@@ -95,13 +95,14 @@ def _model_and_tokenizer(self, model, tokenizer, **kwargs):
     def _joint_tokenize(self, token_ids):
         # first_decode = self.tokenizer._orig_tokenizer.decode(token_ids)
         first_decode = b''.join([self.tokenizer.tokens[id] for id in token_ids]).decode("utf8")
+
+        # HACK: work around a bug in the HuggingFace tokenizer (that will just add extra spaces during an encode-decode cycle)
+        if hasattr(self.tokenizer._orig_tokenizer, "all_special_tokens"):
+            for special_token in self.tokenizer._orig_tokenizer.all_special_tokens:
+                first_decode = first_decode.replace(f"{special_token} ", special_token)
+
         new_ids = self.tokenizer._orig_tokenizer(first_decode, add_special_tokens=False)["input_ids"]
 
-        # HACK: check for a bug in the HuggingFace tokenizer (that will just add extra spaces during an encode-decode cycle)
-        second_decode = self.tokenizer._orig_tokenizer.decode(new_ids)
-        if second_decode != first_decode and len(second_decode) == len(first_decode) + 1 and second_decode.startswith("<s>  "):
-            new_ids = new_ids[0:1] + new_ids[2:]
-
         return new_ids
 
     def get_logits(self, token_ids, forced_bytes, current_temp):
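
For reference, here is a minimal, self-contained sketch of the byte-position correction that the `_cleanup_tokens` hunk above introduces. The `special_tokens` list, the `tokens` byte strings, and `last_pos` below are made-up stand-ins for guidance's real structures, assuming a Llama-style SentencePiece tokenizer that hides the space following a special token:

```python
# Minimal sketch (assumed example data, not guidance's actual structures) of the
# byte-position correction added to _cleanup_tokens. `special_tokens`, `tokens`,
# and `last_pos` are hypothetical stand-ins for tokenizer.all_special_tokens,
# self.tokenizer.tokens[token_ids[i]], and the tracked prompt length.

special_tokens = [b"<s>", b"</s>"]   # assumed SentencePiece-style special tokens

# Joint re-tokenization of the prompt b"<s>Hello world": the token right after
# <s> carries a leading space byte that the tracked prompt never contained.
tokens = [b"<s>", b" Hello", b" world"]
last_pos = len(b"<s>Hello world")    # 14 bytes in the tracked prompt

# Naive cumulative byte positions computed from token byte lengths.
token_byte_positions = []
pos = 0
for tok in tokens:
    pos += len(tok)
    token_byte_positions.append(pos)
# -> [3, 9, 15]: one byte too long because of the hidden space after <s>.

# Same correction as the patch: every time a special token is followed by a
# space-prefixed token, shift all later positions one byte to the left.
spaces_hidden = 0
for i in range(len(token_byte_positions)):
    token_byte_positions[i] -= spaces_hidden
    if tokens[i] in special_tokens and i + 1 < len(tokens) and tokens[i + 1][:1] == b" ":
        spaces_hidden += 1

assert token_byte_positions[-1] == last_pos
print(token_byte_positions)          # [3, 8, 14]
```

Unlike the removed check, which could only compensate for a single hidden space at the very start of the sequence, the `spaces_hidden` counter used here (and in the patch) handles any number of special tokens appearing anywhere in the token stream.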