Skip to content

Commit 8cf1be8

Browse files
authored
Merge pull request #652 from liyier90/feat-threshold-checker-interval-string
Feat: threshold checker interval string
2 parents ac10c34 + 24dbbc9 commit 8cf1be8

File tree

28 files changed

+317
-398
lines changed

28 files changed

+317
-398
lines changed

lint_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ click == 7.1.2
22
colorama == 0.4.4
33
numpy == 1.17.3
44
opencv-contrib-python >= 4.5.2.54
5+
protobuf <= 3.20.1
56
pyyaml >= 5.3
67
requests == 2.24.0
78
tensorflow == 2.2.0

peekingduck/pipeline/nodes/augment/brightness.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class Node(ThresholdCheckerMixin, AbstractNode):
4646
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any) -> None:
4747
super().__init__(config, node_path=__name__, **kwargs)
4848

49-
self.check_bounds("beta", (-100, 100), "within")
49+
self.check_bounds("beta", "[-100, 100]")
5050

5151
def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
5252
"""Adjusts the brightness of an image frame.

peekingduck/pipeline/nodes/augment/contrast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Node(ThresholdCheckerMixin, AbstractNode):
4444
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any) -> None:
4545
super().__init__(config, node_path=__name__, **kwargs)
4646

47-
self.check_bounds("alpha", (0, 3), "within")
47+
self.check_bounds("alpha", "[0, 3]")
4848

4949
def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
5050
"""Adjusts the contrast of an image frame.

peekingduck/pipeline/nodes/base.py

+76-161
Original file line numberDiff line numberDiff line change
@@ -17,136 +17,93 @@
1717
import hashlib
1818
import operator
1919
import os
20+
import re
2021
import sys
2122
import zipfile
2223
from pathlib import Path
23-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
24+
from typing import Any, Callable, Dict, List, Optional, Set, Union
2425

2526
import requests
2627
from tqdm import tqdm
2728

2829
BASE_URL = "https://storage.googleapis.com/peekingduck/models"
2930
PEEKINGDUCK_WEIGHTS_SUBDIR = "peekingduck_weights"
3031

31-
Number = Union[float, int]
32-
3332

3433
class ThresholdCheckerMixin:
3534
"""Mixin class providing utility methods for checking validity of config
3635
values, typically thresholds.
3736
"""
3837

39-
def check_bounds(
40-
self,
41-
key: Union[str, List[str]],
42-
value: Union[Number, Tuple[Number, Number]],
43-
method: str,
44-
include: Optional[str] = "both",
45-
) -> None:
46-
"""Checks if the configuration value(s) specified by `key` satisties
38+
interval_pattern = re.compile(
39+
r"^[\[\(]\s*[-+]?(inf|\d*\.?\d+)\s*,\s*[-+]?(inf|\d*\.?\d+)\s*[\]\)]$"
40+
)
41+
42+
def check_bounds(self, key: Union[str, List[str]], interval: str) -> None:
43+
"""Checks if the configuration value(s) specified by `key` satisfies
4744
the specified bounds.
4845
4946
Args:
5047
key (Union[str, List[str]]): The specified key or list of keys.
51-
value (Union[Number, Tuple[Number, Number]]): Either a single
52-
number to specify the upper or lower bound or a tuple of
53-
numbers to specify both the upper and lower bounds.
54-
method (str): The bounds checking methods, one of
55-
{"above", "below", "both"}. If "above", checks if the
56-
configuration value is above the specified `value`. If "below",
57-
checks if the configuration value is below the specified
58-
`value`. If "both", checks if the configuration value is above
59-
`value[0]` and below `value[1]`.
60-
include (Optional[str]): Indicates if the `value` itself should be
61-
included in the bound, one of {"lower", "upper", "both", None}.
62-
Please see Technotes for details.
48+
interval (str): An mathematical interval representing the range of
49+
valid values. The syntax of the `interval` string is:
50+
51+
<value> = <number> | "-inf" | "+inf"
52+
<left_bracket> = "(" | "["
53+
<right_bracket> = ")" | "]"
54+
<interval> = <left_bracket> <value> "," <value> <right_bracket>
55+
56+
See Technotes for more details.
6357
6458
Raises:
6559
TypeError: `key` type is not in (List[str], str).
66-
TypeError: If `value` is not a tuple of only float/int.
67-
TypeError: If `value` is not a tuple with 2 elements.
68-
TypeError: If `value` is not a float, int, or tuple.
69-
TypeError: If `value` type is not a tuple when `method` is
70-
"within".
71-
TypeError: If `value` type is a tuple when `method` is
72-
"above"/"below".
73-
ValueError: If `method` is not one of {"above", "below", "within"}.
60+
ValueError: If `interval` does not match the specified format.
61+
ValueError: If the lower bound is larger than the upper bound.
7462
ValueError: If the configuration value fails the bounds comparison.
7563
7664
Technotes:
77-
The behavior of `include` depends on the specified `method`. The
78-
table below shows the comparison done for various argument
79-
combinations.
80-
81-
+-----------+---------+-------------------------------------+
82-
| method | include | comparison |
83-
+===========+=========+=====================================+
84-
| | "lower" | config[key] >= value |
85-
+ +---------+-------------------------------------+
86-
| | "upper" | config[key] > value |
87-
+ +---------+-------------------------------------+
88-
| | "both" | config[key] >= value |
89-
+ +---------+-------------------------------------+
90-
| "above" | None | config[key] > value |
91-
+-----------+---------+-------------------------------------+
92-
| | "lower" | config[key] < value |
93-
+ +---------+-------------------------------------+
94-
| | "upper" | config[key] <= value |
95-
+ +---------+-------------------------------------+
96-
| | "both" | config[key] <= value |
97-
+ +---------+-------------------------------------+
98-
| "below" | None | config[key] < value |
99-
+-----------+---------+-------------------------------------+
100-
| | "lower" | value[0] <= config[key] < value[1] |
101-
+ +---------+-------------------------------------+
102-
| | "upper" | value[0] < config[key] <= value[1] |
103-
+ +---------+-------------------------------------+
104-
| | "both" | value[0] <= config[key] <= value[1] |
105-
+ +---------+-------------------------------------+
106-
| "within" | None | value[0] < config[key] < value[1] |
107-
+-----------+---------+-------------------------------------+
65+
The table below shows the comparison done for various interval
66+
expressions.
67+
68+
+---------------------+-------------------------------------+
69+
| interval | comparison |
70+
+=====================+=====================================+
71+
| [lower, +inf] | |
72+
+---------------------+ |
73+
| [lower, +inf) | config[key] >= lower |
74+
+---------------------+-------------------------------------+
75+
| (lower, +inf] | |
76+
+---------------------+ |
77+
| (lower, +inf) | config[key] > lower |
78+
+---------------------+-------------------------------------+
79+
| [-inf, upper] | |
80+
+---------------------+ |
81+
| (-inf, upper] | config[key] <= upper |
82+
+---------------------+-------------------------------------+
83+
| [-inf, upper) | |
84+
+---------------------+ |
85+
| (-inf, upper) | config[key] < upper |
86+
+---------------------+-------------------------------------+
87+
| [lower, upper] | lower <= config[key] <= upper |
88+
+---------------------+-------------------------------------+
89+
| (lower, upper] | lower < config[key] <= upper |
90+
+---------------------+-------------------------------------+
91+
| [lower, upper) | lower <= config[key] < upper |
92+
+---------------------+-------------------------------------+
93+
| (lower, upper) | lower < config[key] < upper |
94+
+---------------------+-------------------------------------+
10895
"""
109-
# available checking methods
110-
methods = {"above", "below", "within"}
111-
# available options of lower/upper bound inclusion
112-
lower_includes = {"lower", "both"}
113-
upper_includes = {"upper", "both"}
114-
115-
if method not in methods:
116-
raise ValueError(f"`method` must be one of {methods}")
117-
118-
if isinstance(value, tuple):
119-
if not all(isinstance(val, (float, int)) for val in value):
120-
raise TypeError(
121-
"When using tuple for `value`, it must be a tuple of float/int"
122-
)
123-
if len(value) != 2:
124-
raise ValueError(
125-
"When using tuple for `value`, it must contain only 2 elements"
126-
)
127-
elif isinstance(value, (float, int)):
128-
pass
129-
else:
130-
raise TypeError(
131-
"`value` must be a float/int or tuple, but you passed a "
132-
f"{type(value).__name__}"
133-
)
134-
135-
if method == "within":
136-
if not isinstance(value, tuple):
137-
raise TypeError("`value` must be a tuple when `method` is 'within'")
138-
self._check_within_bounds(
139-
key, value, (include in lower_includes, include in upper_includes)
140-
)
141-
else:
142-
if isinstance(value, tuple):
143-
raise TypeError(
144-
"`value` must be a float/int when `method` is 'above'/'below'"
145-
)
146-
if method == "above":
147-
self._check_above_value(key, value, include in lower_includes)
148-
elif method == "below":
149-
self._check_below_value(key, value, include in upper_includes)
96+
if self.interval_pattern.match(interval) is None:
97+
raise ValueError("Badly formatted interval")
98+
99+
left_bracket = interval[0]
100+
right_bracket = interval[-1]
101+
lower, upper = [float(value.strip()) for value in interval[1:-1].split(",")]
102+
103+
if lower > upper:
104+
raise ValueError("Lower bound cannot be larger than upper bound")
105+
106+
self._check_within_bounds(key, lower, upper, left_bracket, right_bracket)
150107

151108
def check_valid_choice(
152109
self, key: str, choices: Set[Union[int, float, str]]
@@ -167,78 +124,36 @@ def check_valid_choice(
167124
if self.config[key] not in choices:
168125
raise ValueError(f"{key} must be one of {choices}")
169126

170-
def _check_above_value(
171-
self, key: Union[str, List[str]], value: Number, inclusive: bool
172-
) -> None:
173-
"""Checks that configuration values specified by `key` is more than
174-
(or equal to) the specified `value`.
175-
176-
Args:
177-
key (Union[str, List[str]]): The specified key or list of keys.
178-
value (Number): The specified value.
179-
inclusive (bool): If `True`, compares `config[key] >= value`. If
180-
`False`, compares `config[key] > value`.
181-
182-
Raises:
183-
TypeError: `key` type is not in (List[str], str).
184-
ValueError: If the configuration value is less than (or equal to)
185-
`value`.
186-
"""
187-
method = operator.ge if inclusive else operator.gt
188-
extra_reason = " or equal to" if inclusive else ""
189-
self._compare(key, value, method, reason=f"more than{extra_reason} {value}")
190-
191-
def _check_below_value(
192-
self, key: Union[str, List[str]], value: Number, inclusive: bool
193-
) -> None:
194-
"""Checks that configuration values specified by `key` is more than
195-
(or equal to) the specified `value`.
196-
197-
Args:
198-
key (Union[str, List[str]]): The specified key or list of keys.
199-
value (Number): The specified value.
200-
inclusive (bool): If `True`, compares `config[key] <= value`. If
201-
`False`, compares `config[key] < value`.
202-
203-
Raises:
204-
TypeError: `key` type is not in (List[str], str).
205-
ValueError: If the configuration value is less than (or equal to)
206-
`value`.
207-
"""
208-
method = operator.le if inclusive else operator.lt
209-
extra_reason = " or equal to" if inclusive else ""
210-
self._compare(key, value, method, reason=f"less than{extra_reason} {value}")
211-
212-
def _check_within_bounds(
127+
def _check_within_bounds( # pylint: disable=too-many-arguments
213128
self,
214129
key: Union[str, List[str]],
215-
bounds: Tuple[Number, Number],
216-
includes: Tuple[bool, bool],
130+
lower: float,
131+
upper: float,
132+
left_bracket: str,
133+
right_bracket: str,
217134
) -> None:
218135
"""Checks that configuration values specified by `key` is within the
219136
specified bounds between `lower` and `upper`.
220137
221138
Args:
222139
key (Union[str, List[str]]): The specified key or list of keys.
223-
(Union[float, int]): The lower bound.
224-
bounds (Tuple[Number, Number]): The lower and upper bounds.
225-
includes (Tuple[bool, bool]): If `True`, compares `config[key] >= value`.
226-
If `False`, compares `config[key] > value`.
227-
inclusive_upper (bool): If `True`, compares `config[key] <= value`.
228-
If `False`, compares `config[key] < value`.
140+
lower (float): The lower bound.
141+
upper (float): The upper bound.
142+
left_bracket (str): Either a "(" for an open lower bound or a "["
143+
for a closed lower bound.
144+
right_bracket (str): Either a ")" for an open upper bound or a "]"
145+
for a closed upper bound.
229146
230147
Raises:
231148
TypeError: `key` type is not in (List[str], str).
232149
ValueError: If the configuration value is not between `lower` and
233150
`upper`.
234151
"""
235-
method_lower = operator.ge if includes[0] else operator.gt
236-
method_upper = operator.le if includes[1] else operator.lt
237-
reason_lower = "[" if includes[0] else "("
238-
reason_upper = "]" if includes[1] else ")"
239-
reason = f"between {reason_lower}{bounds[0]}, {bounds[1]}{reason_upper}"
240-
self._compare(key, bounds[0], method_lower, reason)
241-
self._compare(key, bounds[1], method_upper, reason)
152+
method_lower = operator.ge if left_bracket == "[" else operator.gt
153+
method_upper = operator.le if right_bracket == "]" else operator.lt
154+
reason = f"between {left_bracket}{lower}, {upper}{right_bracket}"
155+
self._compare(key, lower, method_lower, reason)
156+
self._compare(key, upper, method_upper, reason)
242157

243158
def _compare(
244159
self,

peekingduck/pipeline/nodes/model/csrnetv1/csrnet_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
3535
self.config = config
3636
self.logger = logging.getLogger(__name__)
3737

38-
self.check_bounds("width", 0, "above", include=None)
38+
self.check_bounds("width", "(0, +inf]")
3939

4040
model_dir = self.download_weights()
4141
self.predictor = Predictor(

peekingduck/pipeline/nodes/model/efficientdet_d04/efficientdet_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
3939
self.logger = logging.getLogger(__name__)
4040

4141
self.check_valid_choice("model_type", {0, 1, 2, 3, 4})
42-
self.check_bounds("score_threshold", (0, 1), "within")
42+
self.check_bounds("score_threshold", "[0, 1]")
4343

4444
model_dir = self.download_weights()
4545
classes_path = model_dir / self.weights["classes_file"]

peekingduck/pipeline/nodes/model/fairmotv1/fairmot_model.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,8 @@ def __init__(self, config: Dict[str, Any], frame_rate: float) -> None:
6363
self.config = config
6464
self.logger = logging.getLogger(__name__)
6565

66-
self.check_bounds(
67-
["K", "min_box_area", "track_buffer"], 0, "above", include=None
68-
)
69-
self.check_bounds("score_threshold", (0, 1), "within")
66+
self.check_bounds(["K", "min_box_area", "track_buffer"], "(0, +inf]")
67+
self.check_bounds("score_threshold", "[0, 1]")
7068

7169
model_dir = self.download_weights()
7270
self.tracker = Tracker(

peekingduck/pipeline/nodes/model/hrnetv1/hrnet_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
3333
self.config = config
3434
self.logger = logging.getLogger(__name__)
3535

36-
self.check_bounds("score_threshold", (0, 1), "within")
36+
self.check_bounds("score_threshold", "[0, 1]")
3737

3838
model_dir = self.download_weights()
3939
self.detector = Detector(

peekingduck/pipeline/nodes/model/jdev1/jde_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, config: Dict[str, Any], frame_rate: float) -> None:
6565
self.logger = logging.getLogger(__name__)
6666

6767
self.check_bounds(
68-
["iou_threshold", "nms_threshold", "score_threshold"], (0, 1), "within"
68+
["iou_threshold", "nms_threshold", "score_threshold"], "[0, 1]"
6969
)
7070

7171
model_dir = self.download_weights()

peekingduck/pipeline/nodes/model/movenetv1/movenet_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
3838
{"singlepose_lightning", "singlepose_thunder", "multipose_lightning"},
3939
)
4040
self.check_bounds(
41-
["bbox_score_threshold", "keypoint_score_threshold"], (0, 1), "within"
41+
["bbox_score_threshold", "keypoint_score_threshold"], "[0, 1]"
4242
)
4343

4444
model_dir = self.download_weights()

peekingduck/pipeline/nodes/model/mtcnnv1/mtcnn_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def __init__(self, config: Dict[str, Any]) -> None:
3535
self.config = config
3636
self.logger = logging.getLogger(__name__)
3737

38-
self.check_bounds("min_size", 0, "above", include=None)
38+
self.check_bounds("min_size", "(0, +inf]")
3939
self.check_bounds(
40-
["network_thresholds", "scale_factor", "score_threshold"], (0, 1), "within"
40+
["network_thresholds", "scale_factor", "score_threshold"], "[0, 1]"
4141
)
4242

4343
model_dir = self.download_weights()

peekingduck/pipeline/nodes/model/posenetv1/posenet_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
3838
self.logger = logging.getLogger(__name__)
3939

4040
self.check_valid_choice("model_type", {50, 75, 100, "resnet"})
41-
self.check_bounds("score_threshold", (0, 1), "within")
41+
self.check_bounds("score_threshold", "[0, 1]")
4242

4343
model_dir = self.download_weights()
4444
self.predictor = Predictor(

0 commit comments

Comments
 (0)