Description
Hello, I loaded the trained model with:

```python
from fairseq.models.bart import BARTModel

model = BARTModel.from_pretrained(
    './src/utils/datasets/snips/exp_0_10/bart_word_mask_0.40_checkpoints',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='./src/utils/datasets/snips/exp_0_10/jointdatabin',
)
```

but I got this error:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/fairseq/models/bart/model.py", line 115, in from_pretrained
    x = hub_utils.from_pretrained(
  File "/opt/conda/lib/python3.8/site-packages/fairseq/hub_utils.py", line 70, in from_pretrained
    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
  File "/opt/conda/lib/python3.8/site-packages/fairseq/checkpoint_utils.py", line 279, in load_model_ensemble_and_task
    state = load_checkpoint_to_cpu(filename, arg_overrides)
  File "/opt/conda/lib/python3.8/site-packages/fairseq/checkpoint_utils.py", line 232, in load_checkpoint_to_cpu
    state = _upgrade_state_dict(state)
  File "/opt/conda/lib/python3.8/site-packages/fairseq/checkpoint_utils.py", line 434, in _upgrade_state_dict
    registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
KeyError: 'mask_s2s'
```
What should I do to load the trained model correctly?
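The last frame of the traceback shows where loading fails: the checkpoint stores the name of the task it was trained with (`state["args"].task`, here `mask_s2s`), and fairseq looks that name up in `tasks.TASK_REGISTRY`. The sketch below is a simplified stand-in for that decorator-based registry pattern, not fairseq's actual code; `TASK_REGISTRY`, `register_task`, `lookup_task` and `MaskS2STask` here are illustrative names. It shows why the lookup raises `KeyError` until the module that defines the task has been imported.

```python
# Simplified stand-in for a decorator-based task registry, modeled on the
# pattern behind fairseq's tasks.TASK_REGISTRY (NOT fairseq's actual code).
TASK_REGISTRY = {}


def register_task(name):
    """Return a decorator that records a task class under `name` at import time."""
    def decorator(cls):
        if name in TASK_REGISTRY:
            raise ValueError(f"cannot register duplicate task ({name})")
        TASK_REGISTRY[name] = cls
        return cls
    return decorator


def lookup_task(saved_task_name):
    """Mimic the checkpoint-loading step: resolve the task name saved in the checkpoint."""
    return TASK_REGISTRY[saved_task_name]  # raises KeyError if never registered


# Before the module defining the custom task has been imported, the lookup
# fails in the same way as the traceback above.
try:
    lookup_task("mask_s2s")
except KeyError as err:
    print("KeyError:", err)


# Importing the module that defines the task runs a decorator like this one
# (this class is hypothetical, standing in for the repository's task class).
@register_task("mask_s2s")
class MaskS2STask:
    pass


# After registration, the same lookup succeeds.
print(lookup_task("mask_s2s").__name__)
```

If the repository follows this pattern, the lookup should succeed once the module that registers `mask_s2s` is imported in the same Python session before `BARTModel.from_pretrained` is called. Which module that is depends on the repository and is not stated in the issue.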