Skip to content

Commit 2aa8411

Browse files
Merge pull request #304 from bioimage-io/tf-2
Add support for tensorflow 2
2 parents a075275 + 23ef46e commit 2aa8411

File tree

7 files changed

+127
-54
lines changed

7 files changed

+127
-54
lines changed

.github/workflows/build.yml

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ jobs:
6969
shell: bash -l {0}
7070
run: pytest --disable-pytest-warnings
7171

72-
test-base-bioimage-spec-tf-legacy:
72+
test-base-bioimage-spec-tf:
7373
runs-on: ubuntu-latest
7474
strategy:
7575
matrix:
76-
python-version: [3.7]
76+
python-version: [3.7, 3.8, 3.9]
7777
steps:
7878
- uses: actions/checkout@v2
7979
- name: install dependencies
@@ -92,6 +92,33 @@ jobs:
9292
conda remove --force bioimageio.spec
9393
pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io
9494
pip install --no-deps -e .
95+
- name: pytest-base-bioimage-spec-tf
96+
shell: bash -l {0}
97+
run: pytest --disable-pytest-warnings
98+
99+
test-base-bioimage-spec-tf-legacy:
100+
runs-on: ubuntu-latest
101+
strategy:
102+
matrix:
103+
python-version: [3.7]
104+
steps:
105+
- uses: actions/checkout@v2
106+
- name: install dependencies
107+
uses: conda-incubator/setup-miniconda@v2
108+
with:
109+
auto-update-conda: true
110+
# we need mamba to resolve environment-tf
111+
mamba-version: "*"
112+
channel-priority: flexible
113+
activate-environment: bio-core-tf-legacy
114+
environment-file: dev/environment-tf-legacy.yaml
115+
python-version: ${{ matrix.python-version }}
116+
- name: additional setup
117+
shell: bash -l {0}
118+
run: |
119+
conda remove --force bioimageio.spec
120+
pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io
121+
pip install --no-deps -e .
95122
- name: pytest-base-bioimage-spec-tf-legacy
96123
shell: bash -l {0}
97124
run: pytest --disable-pytest-warnings

bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,14 @@ def _forward_tf(self, *input_tensors):
9797

9898
return res
9999

100-
def _forward_keras(self, input_tensors):
100+
def _forward_keras(self, *input_tensors):
101101
tf_tensor = [tf.convert_to_tensor(ipt) for ipt in input_tensors]
102-
result = self._model.forward(*tf_tensor)
102+
103+
try:
104+
result = self._model.forward(*tf_tensor)
105+
except AttributeError:
106+
result = self._model.predict(*tf_tensor)
107+
103108
if not isinstance(result, (tuple, list)):
104109
result = [result]
105110

bioimageio/core/weight_converter/keras/tensorflow.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,28 @@
1111
from tensorflow import saved_model
1212

1313

14+
def _zip_weights(output_path):
15+
zipped_model = f"{output_path}.zip"
16+
# zip the weights
17+
file_paths = []
18+
for folder_names, subfolder, filenames in os.walk(os.path.join(output_path)):
19+
for filename in filenames:
20+
# create complete filepath of file in directory
21+
file_paths.append(os.path.join(folder_names, filename))
22+
23+
with ZipFile(zipped_model, "w") as zip_obj:
24+
for f in file_paths:
25+
# Add file to zip
26+
zip_obj.write(f, os.path.relpath(f, output_path))
27+
28+
try:
29+
shutil.rmtree(output_path)
30+
except Exception:
31+
print("TensorFlow bundled model was not removed after compression")
32+
33+
return zipped_model
34+
35+
1436
# adapted from
1537
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
1638
def _convert_tf1(keras_weight_path, output_path, input_name, output_name, zip_weights):
@@ -40,26 +62,27 @@ def build_tf_model():
4062
build_tf_model()
4163

4264
if zip_weights:
43-
zipped_model = f"{output_path}.zip"
44-
# zip the weights
45-
file_paths = []
46-
for folder_names, subfolder, filenames in os.walk(os.path.join(output_path)):
47-
for filename in filenames:
48-
# create complete filepath of file in directory
49-
file_paths.append(os.path.join(folder_names, filename))
50-
51-
with ZipFile(zipped_model, "w") as zip_obj:
52-
for f in file_paths:
53-
# Add file to zip
54-
zip_obj.write(f, os.path.relpath(f, output_path))
55-
56-
try:
57-
shutil.rmtree(output_path)
58-
except Exception:
59-
print("TensorFlow bundled model was not removed after compression")
60-
print("TensorFlow model exported to", zipped_model)
61-
else:
62-
print("TensorFlow model exported to", output_path)
65+
output_path = _zip_weights(output_path)
66+
print("TensorFlow model exported to", output_path)
67+
68+
return 0
69+
70+
71+
def _convert_tf2(keras_weight_path, output_path, zip_weights):
72+
try:
73+
# try to build the tf model with the keras import from tensorflow
74+
from tensorflow import keras
75+
except Exception:
76+
# if the above fails try to export with the standalone keras
77+
import keras
78+
79+
model = keras.models.load_model(keras_weight_path)
80+
keras.models.save_model(model, output_path)
81+
82+
if zip_weights:
83+
output_path = _zip_weights(output_path)
84+
print("TensorFlow model exported to", output_path)
85+
6386
return 0
6487

6588

@@ -104,4 +127,4 @@ def convert_weights_to_tensorflow_saved_model_bundle(
104127
)
105128
return _convert_tf1(weight_path, str(path_), model.inputs[0].name, model.outputs[0].name, zip_weights)
106129
else:
107-
raise NotImplementedError("Weight conversion for tensorflow 2 is not yet implemented.")
130+
return _convert_tf2(weight_path, str(path_), zip_weights)

dev/environment-tf-legacy.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name: bio-core-tf-legacy
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- black
7+
- bioimageio.spec
8+
- conda-build
9+
- h5py >=2.10,<2.11
10+
- mypy
11+
- pip
12+
- pytest
13+
- python >=3.7,<3.8 # this environment is only available for python 3.7
14+
- xarray
15+
- tensorflow >1.14,<2.0
16+
- tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259
17+
- keras

dev/environment-tf.yaml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@ dependencies:
66
- black
77
- bioimageio.spec
88
- conda-build
9-
- h5py >=2.10,<2.11
109
- mypy
1110
- pip
1211
- pytest
13-
- python >=3.7,<3.8 # this environment is only available for python 3.7
12+
- python
1413
- xarray
15-
- tensorflow >1.14,<2.0
14+
- tensorflow >=2.9,<3.0
1615
- tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259
17-
- keras

tests/build_spec/test_build_spec.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from marshmallow import missing
24

35
import bioimageio.spec as spec
@@ -6,6 +8,13 @@
68
from bioimageio.core.resource_io.utils import resolve_source
79
from bioimageio.core.resource_tests import test_model as _test_model
810

11+
try:
12+
import tensorflow
13+
except ImportError:
14+
tf_version = None
15+
else:
16+
tf_version: Optional[str] = ".".join(tensorflow.__version__.split(".")[:2])
17+
918

1019
def _test_build_spec(
1120
spec_path,
@@ -175,18 +184,18 @@ def test_build_spec_onnx(any_onnx_model, tmp_path):
175184

176185
def test_build_spec_keras(any_keras_model, tmp_path):
177186
_test_build_spec(
178-
any_keras_model, tmp_path / "model.zip", "keras_hdf5", tensorflow_version="1.12"
187+
any_keras_model, tmp_path / "model.zip", "keras_hdf5", tensorflow_version=tf_version
179188
) # todo: keras for tf 2??
180189

181190

182191
def test_build_spec_tf(any_tensorflow_model, tmp_path):
183192
_test_build_spec(
184-
any_tensorflow_model, tmp_path / "model.zip", "tensorflow_saved_model_bundle", tensorflow_version="1.12"
193+
any_tensorflow_model, tmp_path / "model.zip", "tensorflow_saved_model_bundle", tensorflow_version=tf_version
185194
) # check tf version
186195

187196

188197
def test_build_spec_tfjs(any_tensorflow_js_model, tmp_path):
189-
_test_build_spec(any_tensorflow_js_model, tmp_path / "model.zip", "tensorflow_js", tensorflow_version="1.12")
198+
_test_build_spec(any_tensorflow_js_model, tmp_path / "model.zip", "tensorflow_js", tensorflow_version=tf_version)
190199

191200

192201
def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
@@ -220,7 +229,7 @@ def test_build_spec_parent2(unet2d_nuclei_broad_model, tmp_path):
220229

221230
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
222231
_test_build_spec(
223-
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version="1.12"
232+
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version=tf_version
224233
)
225234

226235

tests/conftest.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,20 @@
1919
torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
2020
onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"]
2121
tensorflow1_models = ["stardist"]
22-
tensorflow2_models = []
23-
keras_models = ["unet2d_keras"]
22+
tensorflow2_models = ["unet2d_keras_tf2"]
23+
keras_tf1_models = ["unet2d_keras"]
24+
keras_tf2_models = ["unet2d_keras_tf2"]
2425
tensorflow_js_models = []
2526

2627
model_sources = {
2728
"unet2d_keras": (
2829
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
2930
"unet2d_keras_tf/rdf.yaml"
3031
),
32+
"unet2d_keras_tf2": (
33+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
34+
"unet2d_keras_tf2/rdf.yaml"
35+
),
3136
"unet2d_nuclei_broad_model": (
3237
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
3338
"unet2d_nuclei_broad/rdf.yaml"
@@ -91,12 +96,6 @@
9196
skip_tensorflow = tensorflow is None
9297
skip_tensorflow_js = True # TODO: add a tensorflow_js example model
9398

94-
try:
95-
import keras
96-
except ImportError:
97-
keras = None
98-
skip_keras = keras is None
99-
10099
# load all model packages we need for testing
101100
load_model_packages = set()
102101
if not skip_torch:
@@ -106,13 +105,14 @@
106105
load_model_packages |= set(onnx_models)
107106

108107
if not skip_tensorflow:
109-
load_model_packages |= set(keras_models)
110108
load_model_packages |= set(tensorflow_js_models)
111109
if tf_major_version == 1:
110+
load_model_packages |= set(keras_tf1_models)
112111
load_model_packages |= set(tensorflow1_models)
113112
load_model_packages.add("stardist_wrong_shape")
114113
load_model_packages.add("stardist_wrong_shape2")
115114
elif tf_major_version == 2:
115+
load_model_packages |= set(keras_tf2_models)
116116
load_model_packages |= set(tensorflow2_models)
117117

118118

@@ -146,14 +146,12 @@ def any_onnx_model(request):
146146
return pytest.model_packages[request.param]
147147

148148

149-
@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models)))
149+
@pytest.fixture(params=[] if skip_tensorflow else tensorflow1_models if tf_major_version == 1 else tensorflow2_models)
150150
def any_tensorflow_model(request):
151-
name = request.param
152-
if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models):
153-
return pytest.model_packages[name]
151+
return pytest.model_packages[request.param]
154152

155153

156-
@pytest.fixture(params=[] if skip_keras else keras_models)
154+
@pytest.fixture(params=[] if skip_tensorflow else keras_tf1_models if tf_major_version == 1 else keras_tf2_models)
157155
def any_keras_model(request):
158156
return pytest.model_packages[request.param]
159157

@@ -178,21 +176,17 @@ def any_model(request):
178176
#
179177

180178

181-
@pytest.fixture(
182-
params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]
183-
)
179+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
184180
def unet2d_fixed_shape_or_not(request):
185181
return pytest.model_packages[request.param]
186182

187183

188-
@pytest.fixture(
189-
params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
190-
)
184+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"])
191185
def unet2d_multi_tensor_or_not(request):
192186
return pytest.model_packages[request.param]
193187

194188

195-
@pytest.fixture(params=[] if skip_keras else ["unet2d_keras"])
189+
@pytest.fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"])
196190
def unet2d_keras(request):
197191
return pytest.model_packages[request.param]
198192

0 commit comments

Comments
 (0)