Skip to content

Commit fe90dd7

Browse files
committed
[cm] Updating non realtime wrapper with class typing
1 parent a7cd4e0 commit fe90dd7

File tree

1 file changed

+32
-74
lines changed

1 file changed

+32
-74
lines changed

neutone_sdk/non_realtime_wrapper.py

Lines changed: 32 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch as tr
77
from torch import Tensor, nn
88

9-
from neutone_sdk import NeutoneModel, constants, NeutoneParameter, NeutoneParameterType
9+
from neutone_sdk import NeutoneModel, constants, NeutoneParameterType
1010
from neutone_sdk.utils import validate_waveform
1111

1212
logging.basicConfig()
@@ -49,6 +49,16 @@ class NonRealtimeBase(NeutoneModel):
4949
NeutoneParameterType.CATEGORICAL,
5050
NeutoneParameterType.TEXT,
5151
}
52+
# TorchScript typing does not support instance attributes, so we need to type them
53+
# as class attributes. This is required for supporting models with no parameters.
54+
# (https://github.com/pytorch/pytorch/issues/51041#issuecomment-767061194)
55+
cont_param_names: List[str]
56+
cont_param_indices: List[int]
57+
cat_param_names: List[str]
58+
cat_param_indices: List[int]
59+
cat_param_n_values: Dict[str, int]
60+
text_param_max_n_chars: List[int]
61+
text_param_default_values: List[str]
5262

5363
def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
5464
"""
@@ -73,8 +83,6 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
7383
self.text_param_max_n_chars = []
7484
self.text_param_default_values = []
7585

76-
numerical_default_values = []
77-
7886
# We have to keep track of this manually since text params are separate
7987
numerical_param_idx = 0
8088
for p in self.get_neutone_parameters():
@@ -86,33 +94,18 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
8694
self.n_cont_params += 1
8795
self.cont_param_names.append(p.name)
8896
self.cont_param_indices.append(numerical_param_idx)
89-
numerical_default_values.append(p.default_value)
9097
numerical_param_idx += 1
9198
elif p.type == NeutoneParameterType.CATEGORICAL:
9299
self.n_cat_params += 1
93100
self.cat_param_names.append(p.name)
94101
self.cat_param_indices.append(numerical_param_idx)
95102
self.cat_param_n_values[p.name] = p.n_values
96-
# Convert to float since default value tensor must be all the same type
97-
numerical_default_values.append(float(p.default_value))
98103
numerical_param_idx += 1
99104
elif p.type == NeutoneParameterType.TEXT:
100105
self.n_text_params += 1
101106
self.text_param_max_n_chars.append(p.max_n_chars)
102107
self.text_param_default_values.append(p.default_value)
103108

104-
# This is needed for TorchScript typing since it doesn't allow empty lists etc.
105-
if not self.n_cont_params:
106-
self.cont_param_names.append("__torchscript_typing")
107-
self.cont_param_indices.append(-1)
108-
if not self.n_cat_params:
109-
self.cat_param_names.append("__torchscript_typing")
110-
self.cat_param_indices.append(-1)
111-
self.cat_param_n_values["__torchscript_typing"] = -1
112-
if not self.n_text_params:
113-
self.text_param_max_n_chars.append(-1)
114-
self.text_param_default_values.append("__torchscript_typing")
115-
116109
self.n_numerical_params = self.n_cont_params + self.n_cat_params
117110

118111
assert self.n_numerical_params <= constants.NEUTONE_GEN_N_NUMERICAL_PARAMS, (
@@ -126,63 +119,38 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
126119
if self.n_text_params:
127120
self.has_text_param = True
128121

129-
all_neutone_parameters = self._get_all_neutone_parameters()
130-
assert len(all_neutone_parameters) == self._get_max_n_params()
131-
132122
# This overrides the base class definitions to remove the text param or extra
133123
# base param since it is handled separately in the UI.
134124
# TODO(cm): this if statement will be removed once we get rid of the extra
135125
# core methods we don't need anymore
136126
if self.has_text_param:
137127
self.neutone_parameter_names = [
138128
p.name
139-
for p in all_neutone_parameters
129+
for p in self.get_neutone_parameters()
140130
if p.type != NeutoneParameterType.TEXT
141131
]
142132
self.neutone_parameter_descriptions = [
143133
p.description
144-
for p in all_neutone_parameters
134+
for p in self.get_neutone_parameters()
145135
if p.type != NeutoneParameterType.TEXT
146136
]
147137
self.neutone_parameter_types = [
148138
p.type.value
149-
for p in all_neutone_parameters
139+
for p in self.get_neutone_parameters()
150140
if p.type != NeutoneParameterType.TEXT
151141
]
152142
self.neutone_parameter_used = [
153143
p.used
154-
for p in all_neutone_parameters
144+
for p in self.get_neutone_parameters()
155145
if p.type != NeutoneParameterType.TEXT
156146
]
157-
else:
158-
self.neutone_parameter_names = self.neutone_parameter_names[
159-
: constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
160-
]
161-
self.neutone_parameter_descriptions = self.neutone_parameter_descriptions[
162-
: constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
163-
]
164-
self.neutone_parameter_types = self.neutone_parameter_types[
165-
: constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
166-
]
167-
self.neutone_parameter_used = self.neutone_parameter_used[
168-
: constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
169-
]
170147

171148
# TODO(cm): this statement will also be removed once core is refactored
172149
assert (
173150
len(self.get_default_param_names())
174151
== constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
175152
)
176153

177-
numerical_param_default_values_t = Tensor(numerical_default_values)
178-
if self.n_numerical_params > 0:
179-
numerical_param_default_values_t = (
180-
numerical_param_default_values_t.unsqueeze(1)
181-
)
182-
assert numerical_param_default_values_t.size(0) == self.n_numerical_params
183-
# TODO(cm): rename once moved from core
184-
self.register_buffer("default_param_values", numerical_param_default_values_t)
185-
186154
assert all(
187155
1 <= n <= 2 for n in self.get_audio_in_channels()
188156
), "Input audio channels must be mono or stereo"
@@ -212,23 +180,23 @@ def _get_max_n_params(self) -> int:
212180
+ constants.NEUTONE_GEN_N_TEXT_PARAMS
213181
)
214182

215-
def _get_all_neutone_parameters(self) -> List[NeutoneParameter]:
183+
def _create_default_param_values(self) -> Tensor:
216184
"""
217-
Returns a list of NeutoneParameters that is equal to the maximum number of
218-
parameters that the model can have. Placeholder unused NeutoneParameters are
219-
created if the model has less than the maximum number of parameters.
220-
This should not be overwritten by SDK users.
185+
Creates the default parameter values tensor, which must be 1-dimensional.
186+
For NonRealtimeBase models, this is a tensor of the default values for the
187+
continuous and categorical (numerical) parameters and ignores the default
188+
values for the text parameters since these are handled separately.
221189
"""
222-
neutone_parameters = self.get_neutone_parameters()
223-
if len(neutone_parameters) < self._get_max_n_params():
224-
neutone_parameters += [
225-
NeutoneParameter(
226-
name="",
227-
description="",
228-
used=False,
229-
)
230-
] * (self._get_max_n_params() - len(neutone_parameters))
231-
return neutone_parameters
190+
numerical_default_values = []
191+
for p in self.get_neutone_parameters():
192+
if p.type == NeutoneParameterType.CONTINUOUS:
193+
numerical_default_values.append(p.default_value)
194+
elif p.type == NeutoneParameterType.CATEGORICAL:
195+
# Convert to float to match the type of the continuous parameters
196+
numerical_default_values.append(float(p.default_value))
197+
assert len(numerical_default_values) == self.n_numerical_params
198+
numerical_default_values = tr.tensor(numerical_default_values)
199+
return numerical_default_values
232200

233201
@abstractmethod
234202
def get_audio_in_channels(self) -> List[int]:
@@ -424,11 +392,7 @@ def forward(
424392
This method should not be overwritten by SDK users.
425393
"""
426394
if text_params is None:
427-
# Needed for TorchScript typing
428-
if self.n_text_params:
429-
text_params = self.text_param_default_values
430-
else:
431-
text_params = []
395+
text_params = self.text_param_default_values
432396

433397
if self.use_debug_mode:
434398
assert len(audio_in) == len(self.get_audio_in_channels())
@@ -455,13 +419,7 @@ def forward(
455419

456420
if self.use_debug_mode:
457421
if numerical_params is not None:
458-
assert numerical_params.ndim == 2
459-
assert (
460-
self.n_numerical_params
461-
<= numerical_params.size(0)
462-
<= constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
463-
)
464-
assert numerical_params.size(1) == in_n
422+
assert numerical_params.shape == (self.n_numerical_params, in_n)
465423
if not self.is_one_shot_model() and self.get_native_buffer_sizes():
466424
assert (
467425
in_n in self.get_native_buffer_sizes()

0 commit comments

Comments
 (0)