diff --git a/train.py b/train.py index 2b5a98bd..6e5d20d0 100644 --- a/train.py +++ b/train.py @@ -196,13 +196,13 @@ def train(): use_fast=False, ) special_tokens_dict = dict() - if tokenizer.pad_token is None: + if not tokenizer.pad_token: special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN - if tokenizer.eos_token is None: + if not tokenizer.eos_token: special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN - if tokenizer.bos_token is None: + if not tokenizer.bos_token: special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN - if tokenizer.unk_token is None: + if not tokenizer.unk_token: special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN smart_tokenizer_and_embedding_resize(