Skip to content

Commit 6ea0b98

Browse files
all trainable / loadable
1 parent 59f1c30 commit 6ea0b98

28 files changed

+41
-118
lines changed

doctr/file_utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,6 @@
3535
logging.info("Disabling PyTorch because USE_TF is set")
3636
_torch_available = False
3737

38-
# Compatibility fix to make sure tensorflow.keras stays at Keras 2
39-
if "TF_USE_LEGACY_KERAS" not in os.environ:
40-
os.environ["TF_USE_LEGACY_KERAS"] = "1"
41-
42-
elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
43-
raise ValueError(
44-
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
45-
)
46-
47-
48-
def ensure_keras_v2() -> None: # pragma: no cover
49-
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
50-
os.environ["TF_USE_LEGACY_KERAS"] = "1"
51-
5238

5339
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
5440
_tf_available = importlib.util.find_spec("tensorflow") is not None
@@ -79,7 +65,6 @@ def ensure_keras_v2() -> None: # pragma: no cover
7965
_tf_available = False
8066
else:
8167
logging.info(f"TensorFlow version {_tf_version} available.")
82-
ensure_keras_v2()
8368
import tensorflow as tf
8469

8570
# Enable eager execution - this is required for some models to work properly

doctr/models/classification/magc_resnet/tensorflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"std": (0.299, 0.296, 0.301),
2727
"input_shape": (32, 32, 3),
2828
"classes": list(VOCABS["french"]),
29-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
29+
"url": None,
3030
},
3131
}
3232

doctr/models/classification/mobilenet/tensorflow.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,42 +32,42 @@
3232
"std": (0.299, 0.296, 0.301),
3333
"input_shape": (32, 32, 3),
3434
"classes": list(VOCABS["french"]),
35-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
35+
"url": None,
3636
},
3737
"mobilenet_v3_large_r": {
3838
"mean": (0.694, 0.695, 0.693),
3939
"std": (0.299, 0.296, 0.301),
4040
"input_shape": (32, 32, 3),
4141
"classes": list(VOCABS["french"]),
42-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
42+
"url": None,
4343
},
4444
"mobilenet_v3_small": {
4545
"mean": (0.694, 0.695, 0.693),
4646
"std": (0.299, 0.296, 0.301),
4747
"input_shape": (32, 32, 3),
4848
"classes": list(VOCABS["french"]),
49-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
49+
"url": None,
5050
},
5151
"mobilenet_v3_small_r": {
5252
"mean": (0.694, 0.695, 0.693),
5353
"std": (0.299, 0.296, 0.301),
5454
"input_shape": (32, 32, 3),
5555
"classes": list(VOCABS["french"]),
56-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
56+
"url": None,
5757
},
5858
"mobilenet_v3_small_crop_orientation": {
5959
"mean": (0.694, 0.695, 0.693),
6060
"std": (0.299, 0.296, 0.301),
6161
"input_shape": (128, 128, 3),
6262
"classes": [0, -90, 180, 90],
63-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
63+
"url": None,
6464
},
6565
"mobilenet_v3_small_page_orientation": {
6666
"mean": (0.694, 0.695, 0.693),
6767
"std": (0.299, 0.296, 0.301),
6868
"input_shape": (512, 512, 3),
6969
"classes": [0, -90, 180, 90],
70-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
70+
"url": None,
7171
},
7272
}
7373

doctr/models/classification/resnet/tensorflow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,35 +24,35 @@
2424
"std": (0.299, 0.296, 0.301),
2525
"input_shape": (32, 32, 3),
2626
"classes": list(VOCABS["french"]),
27-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
27+
"url": None,
2828
},
2929
"resnet31": {
3030
"mean": (0.694, 0.695, 0.693),
3131
"std": (0.299, 0.296, 0.301),
3232
"input_shape": (32, 32, 3),
3333
"classes": list(VOCABS["french"]),
34-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
34+
"url": None,
3535
},
3636
"resnet34": {
3737
"mean": (0.694, 0.695, 0.693),
3838
"std": (0.299, 0.296, 0.301),
3939
"input_shape": (32, 32, 3),
4040
"classes": list(VOCABS["french"]),
41-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
41+
"url": None,
4242
},
4343
"resnet50": {
4444
"mean": (0.694, 0.695, 0.693),
4545
"std": (0.299, 0.296, 0.301),
4646
"input_shape": (32, 32, 3),
4747
"classes": list(VOCABS["french"]),
48-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
48+
"url": None,
4949
},
5050
"resnet34_wide": {
5151
"mean": (0.694, 0.695, 0.693),
5252
"std": (0.299, 0.296, 0.301),
5353
"input_shape": (32, 32, 3),
5454
"classes": list(VOCABS["french"]),
55-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
55+
"url": None,
5656
},
5757
}
5858

doctr/models/classification/textnet/tensorflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,21 @@
2222
"std": (0.299, 0.296, 0.301),
2323
"input_shape": (32, 32, 3),
2424
"classes": list(VOCABS["french"]),
25-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
25+
"url": None,
2626
},
2727
"textnet_small": {
2828
"mean": (0.694, 0.695, 0.693),
2929
"std": (0.299, 0.296, 0.301),
3030
"input_shape": (32, 32, 3),
3131
"classes": list(VOCABS["french"]),
32-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
32+
"url": None,
3333
},
3434
"textnet_base": {
3535
"mean": (0.694, 0.695, 0.693),
3636
"std": (0.299, 0.296, 0.301),
3737
"input_shape": (32, 32, 3),
3838
"classes": list(VOCABS["french"]),
39-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
39+
"url": None,
4040
},
4141
}
4242

doctr/models/classification/vgg/tensorflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"std": (1.0, 1.0, 1.0),
2323
"input_shape": (32, 32, 3),
2424
"classes": list(VOCABS["french"]),
25-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
25+
"url": None,
2626
},
2727
}
2828

doctr/models/classification/vit/tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525
"std": (0.299, 0.296, 0.301),
2626
"input_shape": (3, 32, 32),
2727
"classes": list(VOCABS["french"]),
28-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
28+
"url": None,
2929
},
3030
"vit_b": {
3131
"mean": (0.694, 0.695, 0.693),
3232
"std": (0.299, 0.296, 0.301),
3333
"input_shape": (32, 32, 3),
3434
"classes": list(VOCABS["french"]),
35-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
35+
"url": None,
3636
},
3737
}
3838

doctr/models/detection/differentiable_binarization/tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828
"mean": (0.798, 0.785, 0.772),
2929
"std": (0.264, 0.2749, 0.287),
3030
"input_shape": (1024, 1024, 3),
31-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
31+
"url": None,
3232
},
3333
"db_mobilenet_v3_large": {
3434
"mean": (0.798, 0.785, 0.772),
3535
"std": (0.264, 0.2749, 0.287),
3636
"input_shape": (1024, 1024, 3),
37-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
37+
"url": None,
3838
},
3939
}
4040

doctr/models/detection/fast/tensorflow.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@
2828
"input_shape": (1024, 1024, 3),
2929
"mean": (0.798, 0.785, 0.772),
3030
"std": (0.264, 0.2749, 0.287),
31-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
31+
"url": None,
3232
},
3333
"fast_small": {
3434
"input_shape": (1024, 1024, 3),
3535
"mean": (0.798, 0.785, 0.772),
3636
"std": (0.264, 0.2749, 0.287),
37-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
37+
"url": None,
3838
},
3939
"fast_base": {
4040
"input_shape": (1024, 1024, 3),
4141
"mean": (0.798, 0.785, 0.772),
4242
"std": (0.264, 0.2749, 0.287),
43-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
43+
"url": None,
4444
},
4545
}
4646

@@ -342,9 +342,6 @@ def _fast(
342342
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
343343
)
344344

345-
# Build the model for reparameterization to access the layers
346-
_ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
347-
348345
return model
349346

350347

doctr/models/detection/linknet/tensorflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,19 @@
2626
"mean": (0.798, 0.785, 0.772),
2727
"std": (0.264, 0.2749, 0.287),
2828
"input_shape": (1024, 1024, 3),
29-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
29+
"url": None,
3030
},
3131
"linknet_resnet34": {
3232
"mean": (0.798, 0.785, 0.772),
3333
"std": (0.264, 0.2749, 0.287),
3434
"input_shape": (1024, 1024, 3),
35-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
35+
"url": None,
3636
},
3737
"linknet_resnet50": {
3838
"mean": (0.798, 0.785, 0.772),
3939
"std": (0.264, 0.2749, 0.287),
4040
"input_shape": (1024, 1024, 3),
41-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
41+
"url": None,
4242
},
4343
}
4444

doctr/models/factory/hub.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
if is_torch_available():
2929
import torch
3030
elif is_tf_available():
31-
import tensorflow as tf
31+
pass
3232

3333
__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
3434

@@ -76,8 +76,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
7676
torch.save(model.state_dict(), weights_path)
7777
elif is_tf_available():
7878
weights_path = save_directory / "tf_model.weights.h5"
79-
# NOTE: `model.build` is not an option because it doesn't runs in eager mode
80-
_ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
8179
model.save_weights(str(weights_path))
8280

8381
config_path = save_directory / "config.json"
@@ -229,8 +227,6 @@ def from_hub(repo_id: str, **kwargs: Any):
229227
model.load_state_dict(state_dict)
230228
else: # tf
231229
weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
232-
# NOTE: `model.build` is not an option because it doesn't runs in eager mode
233-
_ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
234230
model.load_weights(weights)
235231

236232
return model

doctr/models/recognition/crnn/tensorflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
"mean": (0.694, 0.695, 0.693),
2424
"std": (0.299, 0.296, 0.301),
2525
"input_shape": (32, 128, 3),
26-
"vocab": VOCABS["legacy_french"],
27-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
26+
"vocab": VOCABS["french"],
27+
"url": None,
2828
},
2929
"crnn_mobilenet_v3_small": {
3030
"mean": (0.694, 0.695, 0.693),
3131
"std": (0.299, 0.296, 0.301),
3232
"input_shape": (32, 128, 3),
3333
"vocab": VOCABS["french"],
34-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
34+
"url": None,
3535
},
3636
"crnn_mobilenet_v3_large": {
3737
"mean": (0.694, 0.695, 0.693),
3838
"std": (0.299, 0.296, 0.301),
3939
"input_shape": (32, 128, 3),
4040
"vocab": VOCABS["french"],
41-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
41+
"url": None,
4242
},
4343
}
4444

doctr/models/recognition/master/tensorflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"std": (0.299, 0.296, 0.301),
2626
"input_shape": (32, 128, 3),
2727
"vocab": VOCABS["french"],
28-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
28+
"url": None,
2929
},
3030
}
3131

doctr/models/recognition/parseq/tensorflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"std": (0.299, 0.296, 0.301),
2828
"input_shape": (32, 128, 3),
2929
"vocab": VOCABS["french"],
30-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
30+
"url": None,
3131
},
3232
}
3333

doctr/models/recognition/sar/tensorflow.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"std": (0.299, 0.296, 0.301),
2525
"input_shape": (32, 128, 3),
2626
"vocab": VOCABS["french"],
27-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
27+
"url": None,
2828
},
2929
}
3030

@@ -170,9 +170,7 @@ def call(
170170
for t in range(self.max_length + 1): # 32
171171
if t == 0:
172172
# step to init the first states of the LSTMCell
173-
states = self.lstm_cells.get_initial_state(
174-
inputs=None, batch_size=features.shape[0], dtype=features.dtype
175-
)
173+
states = self.lstm_cells.get_initial_state(batch_size=features.shape[0])
176174
prev_symbol = holistic
177175
elif t == 1:
178176
# step to init a 'blank' sequence of length vocab_size + 1 filled with zeros

doctr/models/recognition/vitstr/tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
"std": (0.299, 0.296, 0.301),
2424
"input_shape": (32, 128, 3),
2525
"vocab": VOCABS["french"],
26-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
26+
"url": None,
2727
},
2828
"vitstr_base": {
2929
"mean": (0.694, 0.695, 0.693),
3030
"std": (0.299, 0.296, 0.301),
3131
"input_shape": (32, 128, 3),
3232
"vocab": VOCABS["french"],
33-
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
33+
"url": None,
3434
},
3535
}
3636

doctr/models/utils/tensorflow.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ def load_pretrained_params(
5959
else:
6060
archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
6161

62-
# Build the model
63-
# NOTE: `model.build` is not an option because it doesn't runs in eager mode
64-
_ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
65-
6662
# Load weights
6763
model.load_weights(archive_path, skip_mismatch=skip_mismatch)
6864

@@ -125,7 +121,7 @@ class IntermediateLayerGetter(Model):
125121
"""
126122

127123
def __init__(self, model: Model, layer_names: List[str]) -> None:
128-
intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
124+
intermediate_fmaps = [model.get_layer(layer_name)._inbound_nodes[0].outputs[0] for layer_name in layer_names]
129125
super().__init__(model.input, outputs=intermediate_fmaps)
130126

131127
def __repr__(self) -> str:

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ dependencies = [
5454
tf = [
5555
# cf. https://github.com/mindee/doctr/pull/1461
5656
"tensorflow>=2.15.0,<3.0.0",
57-
"tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility
5857
"tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0
5958
]
6059
torch = [
@@ -98,7 +97,6 @@ dev = [
9897
# Tensorflow
9998
# cf. https://github.com/mindee/doctr/pull/1461
10099
"tensorflow>=2.15.0,<3.0.0",
101-
"tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility
102100
"tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0
103101
# PyTorch
104102
"torch>=1.12.0,<3.0.0",

0 commit comments

Comments
 (0)