diff --git a/dp/model/model.py b/dp/model/model.py index 5709dcb..9fb321a 100644 --- a/dp/model/model.py +++ b/dp/model/model.py @@ -1,6 +1,10 @@ from abc import ABC, abstractmethod from enum import Enum +import os +import requests from typing import Tuple, Dict, Any +from cached_path import cached_path +import validators import torch import torch.nn as nn @@ -9,6 +13,8 @@ from dp.model.utils import get_dedup_tokens, _make_len_mask, _generate_square_subsequent_mask, PositionalEncoding from dp.preprocessing.text import Preprocessor +DEFAULT_MODEL_BUCKET = 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer' + class ModelType(Enum): TRANSFORMER = 'transformer' @@ -290,12 +296,12 @@ def create_model(model_type: ModelType, config: Dict[str, Any]) -> Model: return model -def load_checkpoint(checkpoint_path: str, device: str = 'cpu') -> Tuple[Model, Dict[str, Any]]: +def load_checkpoint(checkpoint: str, device: str = 'cpu') -> Tuple[Model, Dict[str, Any]]: """ - Initializes a model from a checkpoint (.pt file). + Initializes a model from a checkpoint (.pt file). If the checkpoint doesn't exist, it is downloaded to a cache. Args: - checkpoint_path (str): Path to checkpoint file (.pt). + checkpoint (str): Path to checkpoint file (.pt) or name of pre-trained model (.pt) or URL to checkpoint (.pt) device (str): Device to put the model to ('cpu' or 'cuda'). Returns: Tuple: The first element is a Model (the loaded model) @@ -303,10 +309,17 @@ def load_checkpoint(checkpoint_path: str, device: str = 'cpu') -> Tuple[Model, D """ device = torch.device(device) + if not checkpoint[-3:] == '.pt': + raise ValueError(f'{checkpoint} is not a valid model file (.pt).') + if not os.path.exists(checkpoint) and not validators.url(checkpoint): + model_pt_name = os.path.basename(checkpoint) + checkpoint = f"{DEFAULT_MODEL_BUCKET}/{model_pt_name}" + checkpoint_path = cached_path(checkpoint) + print(f"Loading model from {checkpoint_path} ...") checkpoint = torch.load(checkpoint_path, map_location=device) model_type = checkpoint['config']['model']['type'] model_type = ModelType(model_type) model = create_model(model_type, config=checkpoint['config']) model.load_state_dict(checkpoint['model']) model.eval() - return model, checkpoint \ No newline at end of file + return model, checkpoint diff --git a/dp/phonemizer.py b/dp/phonemizer.py index 37e8c85..b7a8c60 100644 --- a/dp/phonemizer.py +++ b/dp/phonemizer.py @@ -191,7 +191,7 @@ def from_checkpoint(cls, """Initializes a Phonemizer object from a model checkpoint (.pt file). Args: - checkpoint_path (str): Path to the .pt checkpoint file. + checkpoint_path (str): Path to checkpoint file (.pt), name of pre-trained model (.pt), or checkpoint URL (.pt) device (str): Device to send the model to ('cpu' or 'cuda'). (Default value = 'cpu') lang_phoneme_dict (Dict[str, Dict[str, str]], optional): Word-phoneme dictionary for each language. diff --git a/requirements.txt b/requirements.txt index fecc28a..33bae93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,6 @@ PyYAML>=5.1 tensorboard certifi>=2022.12.7 wheel>=0.38.0 -setuptools>=65.5.1 \ No newline at end of file +setuptools>=65.5.1 +validators>=0.22.0 +cached-path>=1.5.0 diff --git a/setup.py b/setup.py index e879f2b..615218b 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ long_description_content_type='text/x-rst', license='MIT', install_requires=['torch>=1.2.0', 'tqdm>=4.38.0', 'PyYAML>=5.1', 'tensorboard', - 'certifi>=2022.12.7', 'wheel>=0.38.0', 'setuptools>=65.5.1'], + 'certifi>=2022.12.7', 'wheel>=0.38.0', 'setuptools>=65.5.1', + 'cached-path>=1.5.0', 'validators>=0.22.0'], extras_require={ 'tests': ['pytest-cov'], 'docs': ['mkdocs', 'mkdocs-material'],