Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ Load model:
```python
from litmodels import load_model

model_ = load_model(name="your_org/your_team/torch-model")
# when loading a Pytorch model a instance of the model is also needed
model_ = load_model(name="your_org/your_team/torch-model", model_instance=model)
```

</details>
Expand All @@ -100,7 +101,7 @@ trainer.fit(BoringModel())
# Upload the best model to cloud storage
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
# Define the model name - this should be unique to your model
upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>")
upload_model(model=checkpoint_path, name="<organization>/<teamspace>/lightning-model")
```

Load model:
Expand All @@ -113,7 +114,7 @@ from litmodels.demos import BoringModel
# Load the model from cloud storage
checkpoint_path = download_model(
# Define the model name and version - this needs to be unique to your model
name="<organization>/<teamspace>/<model-name>:<model-version>",
name="<organization>/<teamspace>/lightning-model:<model-version>",
download_dir="my_models",
)
print(f"model: {checkpoint_path}")
Expand Down Expand Up @@ -147,7 +148,7 @@ model = keras.Sequential(
model.compile(optimizer="adam", loss="categorical_crossentropy")

# Save the model
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
save_model("<organization>/<teamspace>/tf-keras-model", model=model)
```

Load model:
Expand All @@ -156,7 +157,7 @@ Load model:
from litmodels import load_model

model_ = load_model(
"lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model"
"<organization>/<teamspace>/tf-keras-model", download_dir="./my-model"
)
```

Expand Down Expand Up @@ -185,7 +186,7 @@ model = svm.SVC()
model.fit(X_train, y_train)

# Upload the saved model using litmodels
save_model(model=model, name="your_org/your_team/sklearn-svm-model")
save_model(model=model, name="<organization>/<teamspace>/sklearn-model")
```

Use model:
Expand All @@ -195,7 +196,7 @@ from litmodels import load_model

# Download and load the model file from cloud storage
model = load_model(
name="your_org/your_team/sklearn-svm-model", download_dir="my_models"
name="<organization>/<teamspace>/sklearn-model", download_dir="my_models"
)

# Example: run inference with the loaded model
Expand Down
11 changes: 10 additions & 1 deletion src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,16 @@ def download_model(
)


def load_model(name: str, download_dir: str = ".") -> Any:
def load_model(name: str, download_dir: str = ".", model_instance: Optional[object] = None) -> Any:
"""Download a model from the model store and load it into memory.

Args:
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
download_dir: A path to directory where the model should be downloaded. Defaults
to the current working directory.
model_instance: Optional argument needed if loading a pure Pytorch model. Pass in a initialized
instance of the model to load the model weights into.

Returns:
The loaded model.
Expand All @@ -159,6 +161,13 @@ def load_model(name: str, download_dir: str = ".") -> Any:
model_path = Path(download_dir) / download_paths[0]
if model_path.suffix.lower() == ".ts":
return torch.jit.load(model_path)
if model_path.suffix.lower() == ".pth":
if model_instance is not None and isinstance(model_instance, torch.nn.Module):
return model_instance.load_state_dict(torch.load(model_path))
raise ValueError(
"Trying to load a Pure Pytorch model. Expected the optional `model_instance`"
"to be provided with a instance of the saved model to load the model weights into."
)
if model_path.suffix.lower() == ".keras":
return keras.models.load_model(model_path)
if model_path.suffix.lower() == ".pkl":
Expand Down
18 changes: 18 additions & 0 deletions tests/test_io_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,24 @@ def test_load_model_torch_jit(mock_download_model, tmp_path):
assert isinstance(model, torch.jit.ScriptModule)


@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_torch(mock_download_model, tmp_path):
# craete a dummy file
model_file = tmp_path / "dummy_model.pth"
test_data = torch.nn.Module()
torch.save(test_data, model_file)
mock_download_model.return_value = [str(model_file.name)]

# The lit-logger function is just a wrapper around the SDK function
model = load_model(
name="org-name/teamspace/model-name", download_dir=str(tmp_path), model_instance=torch.nn.Module()
)
mock_download_model.assert_called_once_with(
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
)
assert isinstance(model, torch.nn.Module)


@pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available")
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_tf_keras(mock_download_model, tmp_path):
Expand Down
Loading