You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: segmentation_models_pytorch/decoders/deeplabv3/model.py
+30-8
Original file line number
Diff line number
Diff line change
@@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
35
35
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
36
36
**callable** and **None**.
37
37
Default is **None**
38
-
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
38
+
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
39
39
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
40
40
on top of encoder if **aux_params** is not **None** (default). Supported params:
41
41
- classes (int): A number of classes
42
42
- pooling (str): One of "max", "avg". Default is "avg"
43
43
- dropout (float): Dropout factor in [0, 1)
44
44
- activation (str): An activation function to apply "sigmoid"/"softmax"
45
45
(could be **None** to return logits)
46
-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
46
+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
47
+
Keys with ``None`` values are pruned before passing.
47
48
48
49
Returns:
49
50
``torch.nn.Module``: **DeepLabV3**
@@ -72,6 +73,12 @@ def __init__(
72
73
):
73
74
super().__init__()
74
75
76
+
ifencoder_output_stridenotin [8, 16]:
77
+
raiseValueError(
78
+
"DeeplabV3 support output stride 8 or 16, got {}.".format(
@@ -129,16 +144,16 @@ class DeepLabV3Plus(SegmentationModel):
129
144
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
130
145
**callable** and **None**.
131
146
Default is **None**
132
-
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case
133
-
**encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve.
147
+
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity.
134
148
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
135
149
on top of encoder if **aux_params** is not **None** (default). Supported params:
136
150
- classes (int): A number of classes
137
151
- pooling (str): One of "max", "avg". Default is "avg"
138
152
- dropout (float): Dropout factor in [0, 1)
139
153
- activation (str): An activation function to apply "sigmoid"/"softmax"
140
154
(could be **None** to return logits)
141
-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
155
+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
156
+
Keys with ``None`` values are pruned before passing.
142
157
143
158
Returns:
144
159
``torch.nn.Module``: **DeepLabV3Plus**
@@ -167,6 +182,13 @@ def __init__(
167
182
):
168
183
super().__init__()
169
184
185
+
ifencoder_output_stridenotin [8, 16]:
186
+
raiseValueError(
187
+
"DeeplabV3Plus support output stride 8 or 16, got {}.".format(
0 commit comments