diff --git a/README.md b/README.md index e49598a..70250f1 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ see `README_tokenizer.md` for further information. #### Simple use ```python -from convlm.tokenizer import SpokenDialogTokenizer +from turngpt.tokenizer import SpokenDialogTokenizer pretrained_model_name_or_path="microsoft/DialoGPT-small" tokenizer = SpokenDialogTokenizer(pretrained_model_name_or_path) @@ -110,7 +110,7 @@ An un-trained TurnGPT model, loads pre-trained weights by default, and includes ```python from argparse import ArgumentParser - from convlm.turngpt import TurnGPT + from turngpt import TurnGPT parser = ArgumentParser() parser = TurnGPT.add_model_specific_args(parser) diff --git a/turngpt/model.py b/turngpt/model.py index a3b22d8..66db279 100644 --- a/turngpt/model.py +++ b/turngpt/model.py @@ -68,7 +68,7 @@ def idx_to_string(self, idx): idx = idx.item() s = self.tokenizer.convert_ids_to_tokens(idx) s = self.tokenizer.convert_tokens_to_string( - s.strip() + [s.strip()] ) # remove prefix space/symbol return s diff --git a/turngpt/train.py b/turngpt/train.py index 6272450..b5b8018 100644 --- a/turngpt/train.py +++ b/turngpt/train.py @@ -6,7 +6,7 @@ import pytorch_lightning as pl -from datasets_turntaking import DialogTextDM +from datasets_turntaking import ConversationalDM as DialogTextDM from turngpt.model import TurnGPT, TurnGPTWandbCallbacks