Skip to content

Commit b84d33b

Browse files
committed
torch is optional dep
1 parent 2be9913 commit b84d33b

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

bioimageio/core/weight_converter/torch/_onnx.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Any, List, Sequence, cast
55

66
import numpy as np
7-
import torch
87
from numpy.testing import assert_array_almost_equal
98

109
from bioimageio.spec import load_description
@@ -14,6 +13,11 @@
1413
from ...digest_spec import get_member_id, get_test_inputs
1514
from ...weight_converter.torch._utils import load_torch_model
1615

16+
try:
17+
import torch
18+
except ImportError:
19+
torch = None
20+
1721

1822
def add_onnx_weights(
1923
model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
@@ -48,6 +52,7 @@ def add_onnx_weights(
4852
"The provided model does not have weights in the pytorch state dict format"
4953
)
5054

55+
assert torch is not None
5156
with torch.no_grad():
5257

5358
sample = get_test_inputs(model_spec)

bioimageio/core/weight_converter/torch/_torchscript.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import List, Sequence, Union
44

55
import numpy as np
6-
import torch
76
from numpy.testing import assert_array_almost_equal
87
from typing_extensions import Any, assert_never
98

@@ -12,14 +11,21 @@
1211

1312
from ._utils import load_torch_model
1413

14+
try:
15+
import torch
16+
except ImportError:
17+
torch = None
18+
1519

1620
# FIXME: remove Any
1721
def _check_predictions(
1822
model: Any,
1923
scripted_model: Any,
2024
model_spec: "v0_4.ModelDescr | v0_5.ModelDescr",
21-
input_data: Sequence[torch.Tensor],
25+
input_data: Sequence["torch.Tensor"],
2226
):
27+
assert torch is not None
28+
2329
def _check(input_: Sequence[torch.Tensor]) -> None:
2430
expected_tensors = model(*input_)
2531
if isinstance(expected_tensors, torch.Tensor):
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
from typing import Union
22

3-
import torch
4-
53
from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
64
from bioimageio.spec.model import v0_4, v0_5
75
from bioimageio.spec.utils import download
86

7+
try:
8+
import torch
9+
except ImportError:
10+
torch = None
11+
912

1013
# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
1114
# and for each weight format
1215
def load_torch_model( # pyright: ignore[reportUnknownParameterType]
1316
node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
1417
):
18+
assert torch is not None
1519
model = ( # pyright: ignore[reportUnknownVariableType]
1620
PytorchModelAdapter.get_network(node)
1721
)
18-
state = torch.load( # pyright: ignore[reportUnknownVariableType]
19-
download(node.source).path, map_location="cpu"
20-
)
22+
state = torch.load(download(node.source).path, map_location="cpu")
2123
model.load_state_dict(state) # FIXME: check incompatible keys?
2224
return model.eval() # pyright: ignore[reportUnknownVariableType]

0 commit comments

Comments
 (0)