6
6
import torch as tr
7
7
from torch import Tensor , nn
8
8
9
- from neutone_sdk import NeutoneModel , constants , NeutoneParameter , NeutoneParameterType
9
+ from neutone_sdk import NeutoneModel , constants , NeutoneParameterType
10
10
from neutone_sdk .utils import validate_waveform
11
11
12
12
logging .basicConfig ()
@@ -49,6 +49,16 @@ class NonRealtimeBase(NeutoneModel):
49
49
NeutoneParameterType .CATEGORICAL ,
50
50
NeutoneParameterType .TEXT ,
51
51
}
52
+ # TorchScript typing does not support instance attributes, so we need to type them
53
+ # as class attributes. This is required for supporting models with no parameters.
54
+ # (https://github.com/pytorch/pytorch/issues/51041#issuecomment-767061194)
55
+ cont_param_names : List [str ]
56
+ cont_param_indices : List [int ]
57
+ cat_param_names : List [str ]
58
+ cat_param_indices : List [int ]
59
+ cat_param_n_values : Dict [str , int ]
60
+ text_param_max_n_chars : List [int ]
61
+ text_param_default_values : List [str ]
52
62
53
63
def __init__ (self , model : nn .Module , use_debug_mode : bool = True ) -> None :
54
64
"""
@@ -73,8 +83,6 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
73
83
self .text_param_max_n_chars = []
74
84
self .text_param_default_values = []
75
85
76
- numerical_default_values = []
77
-
78
86
# We have to keep track of this manually since text params are separate
79
87
numerical_param_idx = 0
80
88
for p in self .get_neutone_parameters ():
@@ -86,33 +94,18 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
86
94
self .n_cont_params += 1
87
95
self .cont_param_names .append (p .name )
88
96
self .cont_param_indices .append (numerical_param_idx )
89
- numerical_default_values .append (p .default_value )
90
97
numerical_param_idx += 1
91
98
elif p .type == NeutoneParameterType .CATEGORICAL :
92
99
self .n_cat_params += 1
93
100
self .cat_param_names .append (p .name )
94
101
self .cat_param_indices .append (numerical_param_idx )
95
102
self .cat_param_n_values [p .name ] = p .n_values
96
- # Convert to float since default value tensor must be all the same type
97
- numerical_default_values .append (float (p .default_value ))
98
103
numerical_param_idx += 1
99
104
elif p .type == NeutoneParameterType .TEXT :
100
105
self .n_text_params += 1
101
106
self .text_param_max_n_chars .append (p .max_n_chars )
102
107
self .text_param_default_values .append (p .default_value )
103
108
104
- # This is needed for TorchScript typing since it doesn't allow empty lists etc.
105
- if not self .n_cont_params :
106
- self .cont_param_names .append ("__torchscript_typing" )
107
- self .cont_param_indices .append (- 1 )
108
- if not self .n_cat_params :
109
- self .cat_param_names .append ("__torchscript_typing" )
110
- self .cat_param_indices .append (- 1 )
111
- self .cat_param_n_values ["__torchscript_typing" ] = - 1
112
- if not self .n_text_params :
113
- self .text_param_max_n_chars .append (- 1 )
114
- self .text_param_default_values .append ("__torchscript_typing" )
115
-
116
109
self .n_numerical_params = self .n_cont_params + self .n_cat_params
117
110
118
111
assert self .n_numerical_params <= constants .NEUTONE_GEN_N_NUMERICAL_PARAMS , (
@@ -126,63 +119,38 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
126
119
if self .n_text_params :
127
120
self .has_text_param = True
128
121
129
- all_neutone_parameters = self ._get_all_neutone_parameters ()
130
- assert len (all_neutone_parameters ) == self ._get_max_n_params ()
131
-
132
122
# This overrides the base class definitions to remove the text param or extra
133
123
# base param since it is handled separately in the UI.
134
124
# TODO(cm): this if statement will be removed once we get rid of the extra
135
125
# core methods we don't need anymore
136
126
if self .has_text_param :
137
127
self .neutone_parameter_names = [
138
128
p .name
139
- for p in all_neutone_parameters
129
+ for p in self . get_neutone_parameters ()
140
130
if p .type != NeutoneParameterType .TEXT
141
131
]
142
132
self .neutone_parameter_descriptions = [
143
133
p .description
144
- for p in all_neutone_parameters
134
+ for p in self . get_neutone_parameters ()
145
135
if p .type != NeutoneParameterType .TEXT
146
136
]
147
137
self .neutone_parameter_types = [
148
138
p .type .value
149
- for p in all_neutone_parameters
139
+ for p in self . get_neutone_parameters ()
150
140
if p .type != NeutoneParameterType .TEXT
151
141
]
152
142
self .neutone_parameter_used = [
153
143
p .used
154
- for p in all_neutone_parameters
144
+ for p in self . get_neutone_parameters ()
155
145
if p .type != NeutoneParameterType .TEXT
156
146
]
157
- else :
158
- self .neutone_parameter_names = self .neutone_parameter_names [
159
- : constants .NEUTONE_GEN_N_NUMERICAL_PARAMS
160
- ]
161
- self .neutone_parameter_descriptions = self .neutone_parameter_descriptions [
162
- : constants .NEUTONE_GEN_N_NUMERICAL_PARAMS
163
- ]
164
- self .neutone_parameter_types = self .neutone_parameter_types [
165
- : constants .NEUTONE_GEN_N_NUMERICAL_PARAMS
166
- ]
167
- self .neutone_parameter_used = self .neutone_parameter_used [
168
- : constants .NEUTONE_GEN_N_NUMERICAL_PARAMS
169
- ]
170
147
171
148
# TODO(cm): this statement will also be removed once core is refactored
172
149
assert (
173
150
len (self .get_default_param_names ())
174
151
== constants .NEUTONE_GEN_N_NUMERICAL_PARAMS
175
152
)
176
153
177
- numerical_param_default_values_t = Tensor (numerical_default_values )
178
- if self .n_numerical_params > 0 :
179
- numerical_param_default_values_t = (
180
- numerical_param_default_values_t .unsqueeze (1 )
181
- )
182
- assert numerical_param_default_values_t .size (0 ) == self .n_numerical_params
183
- # TODO(cm): rename once moved from core
184
- self .register_buffer ("default_param_values" , numerical_param_default_values_t )
185
-
186
154
assert all (
187
155
1 <= n <= 2 for n in self .get_audio_in_channels ()
188
156
), "Input audio channels must be mono or stereo"
@@ -212,23 +180,23 @@ def _get_max_n_params(self) -> int:
212
180
+ constants .NEUTONE_GEN_N_TEXT_PARAMS
213
181
)
214
182
215
- def _get_all_neutone_parameters (self ) -> List [ NeutoneParameter ] :
183
+ def _create_default_param_values (self ) -> Tensor :
216
184
"""
217
- Returns a list of NeutoneParameters that is equal to the maximum number of
218
- parameters that the model can have. Placeholder unused NeutoneParameters are
219
- created if the model has less than the maximum number of parameters.
220
- This should not be overwritten by SDK users .
185
+ Creates the default parameter values tensor, which must be 1-dimensional.
186
+ For NonRealtimeBase models, this is a tensor of the default values for the
187
+ continuous and categorical (numerical) parameters and ignores the default
188
+ values for the text parameters since these are handled separately .
221
189
"""
222
- neutone_parameters = self . get_neutone_parameters ()
223
- if len ( neutone_parameters ) < self ._get_max_n_params ():
224
- neutone_parameters += [
225
- NeutoneParameter (
226
- name = "" ,
227
- description = "" ,
228
- used = False ,
229
- )
230
- ] * ( self . _get_max_n_params () - len ( neutone_parameters ) )
231
- return neutone_parameters
190
+ numerical_default_values = []
191
+ for p in self .get_neutone_parameters ():
192
+ if p . type == NeutoneParameterType . CONTINUOUS :
193
+ numerical_default_values . append ( p . default_value )
194
+ elif p . type == NeutoneParameterType . CATEGORICAL :
195
+ # Convert to float to match the type of the continuous parameters
196
+ numerical_default_values . append ( float ( p . default_value ))
197
+ assert len ( numerical_default_values ) == self . n_numerical_params
198
+ numerical_default_values = tr . tensor ( numerical_default_values )
199
+ return numerical_default_values
232
200
233
201
@abstractmethod
234
202
def get_audio_in_channels (self ) -> List [int ]:
@@ -424,11 +392,7 @@ def forward(
424
392
This method should not be overwritten by SDK users.
425
393
"""
426
394
if text_params is None :
427
- # Needed for TorchScript typing
428
- if self .n_text_params :
429
- text_params = self .text_param_default_values
430
- else :
431
- text_params = []
395
+ text_params = self .text_param_default_values
432
396
433
397
if self .use_debug_mode :
434
398
assert len (audio_in ) == len (self .get_audio_in_channels ())
@@ -455,13 +419,7 @@ def forward(
455
419
456
420
if self .use_debug_mode :
457
421
if numerical_params is not None :
458
- assert numerical_params .ndim == 2
459
- assert (
460
- self .n_numerical_params
461
- <= numerical_params .size (0 )
462
- <= constants .NEUTONE_GEN_N_NUMERICAL_PARAMS
463
- )
464
- assert numerical_params .size (1 ) == in_n
422
+ assert numerical_params .shape == (self .n_numerical_params , in_n )
465
423
if not self .is_one_shot_model () and self .get_native_buffer_sizes ():
466
424
assert (
467
425
in_n in self .get_native_buffer_sizes ()
0 commit comments