diff --git a/README.md b/README.md index 8873b35..b9a860f 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,8 @@ Load model: ```python from litmodels import load_model -model_ = load_model(name="your_org/your_team/torch-model") +# when loading a Pytorch model a instance of the model is also needed +model_ = load_model(name="your_org/your_team/torch-model", model_instance=model) ``` @@ -100,7 +101,7 @@ trainer.fit(BoringModel()) # Upload the best model to cloud storage checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") # Define the model name - this should be unique to your model -upload_model(model=checkpoint_path, name="//") +upload_model(model=checkpoint_path, name="//lightning-model") ``` Load model: @@ -113,7 +114,7 @@ from litmodels.demos import BoringModel # Load the model from cloud storage checkpoint_path = download_model( # Define the model name and version - this needs to be unique to your model - name="//:", + name="//lightning-model:", download_dir="my_models", ) print(f"model: {checkpoint_path}") @@ -147,7 +148,7 @@ model = keras.Sequential( model.compile(optimizer="adam", loss="categorical_crossentropy") # Save the model -save_model("lightning-ai/jirka/sample-tf-keras-model", model=model) +save_model("//tf-keras-model", model=model) ``` Load model: @@ -156,7 +157,7 @@ Load model: from litmodels import load_model model_ = load_model( - "lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model" + "//tf-keras-model", download_dir="./my-model" ) ``` @@ -185,7 +186,7 @@ model = svm.SVC() model.fit(X_train, y_train) # Upload the saved model using litmodels -save_model(model=model, name="your_org/your_team/sklearn-svm-model") +save_model(model=model, name="//sklearn-model") ``` Use model: @@ -195,7 +196,7 @@ from litmodels import load_model # Download and load the model file from cloud storage model = load_model( - name="your_org/your_team/sklearn-svm-model", download_dir="my_models" + name="//sklearn-model", download_dir="my_models" ) # Example: run inference with the loaded model diff --git a/src/litmodels/io/gateway.py b/src/litmodels/io/gateway.py index 1137a4d..b446599 100644 --- a/src/litmodels/io/gateway.py +++ b/src/litmodels/io/gateway.py @@ -139,7 +139,7 @@ def download_model( ) -def load_model(name: str, download_dir: str = ".") -> Any: +def load_model(name: str, download_dir: str = ".", model_instance: Optional[object] = None) -> Any: """Download a model from the model store and load it into memory. Args: @@ -147,6 +147,8 @@ def load_model(name: str, download_dir: str = ".") -> Any: where entity is either your username or the name of an organization you are part of. download_dir: A path to directory where the model should be downloaded. Defaults to the current working directory. + model_instance: Optional argument needed if loading a pure Pytorch model. Pass in a initialized + instance of the model to load the model weights into. Returns: The loaded model. @@ -159,6 +161,13 @@ def load_model(name: str, download_dir: str = ".") -> Any: model_path = Path(download_dir) / download_paths[0] if model_path.suffix.lower() == ".ts": return torch.jit.load(model_path) + if model_path.suffix.lower() == ".pth": + if model_instance is not None and isinstance(model_instance, torch.nn.Module): + return model_instance.load_state_dict(torch.load(model_path)) + raise ValueError( + "Trying to load a Pure Pytorch model. Expected the optional `model_instance`" + "to be provided with a instance of the saved model to load the model weights into." + ) if model_path.suffix.lower() == ".keras": return keras.models.load_model(model_path) if model_path.suffix.lower() == ".pkl": diff --git a/tests/test_io_cloud.py b/tests/test_io_cloud.py index 62c3779..f1e596d 100644 --- a/tests/test_io_cloud.py +++ b/tests/test_io_cloud.py @@ -138,6 +138,24 @@ def test_load_model_torch_jit(mock_download_model, tmp_path): assert isinstance(model, torch.jit.ScriptModule) +@mock.patch("litmodels.io.cloud.sdk_download_model") +def test_load_model_torch(mock_download_model, tmp_path): + # craete a dummy file + model_file = tmp_path / "dummy_model.pth" + test_data = torch.nn.Module() + torch.save(test_data, model_file) + mock_download_model.return_value = [str(model_file.name)] + + # The lit-logger function is just a wrapper around the SDK function + model = load_model( + name="org-name/teamspace/model-name", download_dir=str(tmp_path), model_instance=torch.nn.Module() + ) + mock_download_model.assert_called_once_with( + name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True + ) + assert isinstance(model, torch.nn.Module) + + @pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available") @mock.patch("litmodels.io.cloud.sdk_download_model") def test_load_model_tf_keras(mock_download_model, tmp_path):