File tree 3 files changed +21
-8
lines changed
bioimageio/core/weight_converter/torch 3 files changed +21
-8
lines changed Original file line number Diff line number Diff line change 4
4
from typing import Any , List , Sequence , cast
5
5
6
6
import numpy as np
7
- import torch
8
7
from numpy .testing import assert_array_almost_equal
9
8
10
9
from bioimageio .spec import load_description
14
13
from ...digest_spec import get_member_id , get_test_inputs
15
14
from ...weight_converter .torch ._utils import load_torch_model
16
15
16
+ try :
17
+ import torch
18
+ except ImportError :
19
+ torch = None
20
+
17
21
18
22
def add_onnx_weights (
19
23
model_spec : "str | Path | v0_4.ModelDescr | v0_5.ModelDescr" ,
@@ -48,6 +52,7 @@ def add_onnx_weights(
48
52
"The provided model does not have weights in the pytorch state dict format"
49
53
)
50
54
55
+ assert torch is not None
51
56
with torch .no_grad ():
52
57
53
58
sample = get_test_inputs (model_spec )
Original file line number Diff line number Diff line change 3
3
from typing import List , Sequence , Union
4
4
5
5
import numpy as np
6
- import torch
7
6
from numpy .testing import assert_array_almost_equal
8
7
from typing_extensions import Any , assert_never
9
8
12
11
13
12
from ._utils import load_torch_model
14
13
14
+ try :
15
+ import torch
16
+ except ImportError :
17
+ torch = None
18
+
15
19
16
20
# FIXME: remove Any
17
21
def _check_predictions (
18
22
model : Any ,
19
23
scripted_model : Any ,
20
24
model_spec : "v0_4.ModelDescr | v0_5.ModelDescr" ,
21
- input_data : Sequence [torch .Tensor ],
25
+ input_data : Sequence [" torch.Tensor" ],
22
26
):
27
+ assert torch is not None
28
+
23
29
def _check (input_ : Sequence [torch .Tensor ]) -> None :
24
30
expected_tensors = model (* input_ )
25
31
if isinstance (expected_tensors , torch .Tensor ):
Original file line number Diff line number Diff line change 1
1
from typing import Union
2
2
3
- import torch
4
-
5
3
from bioimageio .core .model_adapters ._pytorch_model_adapter import PytorchModelAdapter
6
4
from bioimageio .spec .model import v0_4 , v0_5
7
5
from bioimageio .spec .utils import download
8
6
7
+ try :
8
+ import torch
9
+ except ImportError :
10
+ torch = None
11
+
9
12
10
13
# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
11
14
# and for each weight format
12
15
def load_torch_model ( # pyright: ignore[reportUnknownParameterType]
13
16
node : Union [v0_4 .PytorchStateDictWeightsDescr , v0_5 .PytorchStateDictWeightsDescr ],
14
17
):
18
+ assert torch is not None
15
19
model = ( # pyright: ignore[reportUnknownVariableType]
16
20
PytorchModelAdapter .get_network (node )
17
21
)
18
- state = torch .load ( # pyright: ignore[reportUnknownVariableType]
19
- download (node .source ).path , map_location = "cpu"
20
- )
22
+ state = torch .load (download (node .source ).path , map_location = "cpu" )
21
23
model .load_state_dict (state ) # FIXME: check incompatible keys?
22
24
return model .eval () # pyright: ignore[reportUnknownVariableType]
You can’t perform that action at this time.
0 commit comments