Skip to content

Commit 7f8d2ac

Browse files
committed
[cm] Small bug fixes
1 parent fe90dd7 commit 7f8d2ac

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

neutone_sdk/non_realtime_wrapper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
108108

109109
self.n_numerical_params = self.n_cont_params + self.n_cat_params
110110

111+
assert self.get_default_param_values().size(0) == self.n_numerical_params, (
112+
f"Default parameter values tensor first dimension must have the same "
113+
f"size as the number of numerical parameters. Expected size "
114+
f"{self.n_numerical_params}, got {self.get_default_param_values().size(0)}"
115+
)
111116
assert self.n_numerical_params <= constants.NEUTONE_GEN_N_NUMERICAL_PARAMS, (
112117
f"Too many numerical (continuous and categorical) parameters. "
113118
f"Max allowed is {constants.NEUTONE_GEN_N_NUMERICAL_PARAMS}"
@@ -146,10 +151,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
146151
]
147152

148153
# TODO(cm): this statement will also be removed once core is refactored
149-
assert (
150-
len(self.get_default_param_names())
151-
== constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
152-
)
154+
assert len(self.get_default_param_names()) == self.n_numerical_params
153155

154156
assert all(
155157
1 <= n <= 2 for n in self.get_audio_in_channels()
@@ -194,7 +196,6 @@ def _create_default_param_values(self) -> Tensor:
194196
elif p.type == NeutoneParameterType.CATEGORICAL:
195197
# Convert to float to match the type of the continuous parameters
196198
numerical_default_values.append(float(p.default_value))
197-
assert len(numerical_default_values) == self.n_numerical_params
198199
numerical_default_values = tr.tensor(numerical_default_values)
199200
return numerical_default_values
200201

0 commit comments

Comments
 (0)