Skip to content

Commit 7c39dbf

Browse files
committed
[cm] Adding generative clipper example
1 parent cb63c13 commit 7c39dbf

9 files changed

+119
-0
lines changed
File renamed without changes.
File renamed without changes.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import logging
2+
import os
3+
import pathlib
4+
from argparse import ArgumentParser
5+
from typing import Dict, List
6+
7+
import torch as tr
8+
import torch.nn as nn
9+
from torch import Tensor
10+
11+
from neutone_sdk import NeutoneParameter, ContinuousNeutoneParameter
12+
from neutone_sdk.non_realtime_wrapper import NonRealtimeBase
13+
14+
logging.basicConfig()
15+
log = logging.getLogger(__name__)
16+
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))
17+
18+
19+
class ClipperModel(nn.Module):
20+
def forward(self,
21+
x: Tensor,
22+
min_val: Tensor,
23+
max_val: Tensor,
24+
gain: Tensor) -> Tensor:
25+
tr.neg(min_val, out=min_val)
26+
tr.mul(gain, min_val, out=min_val)
27+
tr.mul(gain, max_val, out=max_val)
28+
tr.clip(x, min=min_val, max=max_val, out=x)
29+
return x
30+
31+
32+
class NonRealtimeClipperModelWrapper(NonRealtimeBase):
33+
def get_model_name(self) -> str:
34+
return "clipper"
35+
36+
def get_model_authors(self) -> List[str]:
37+
return ["Christopher Mitcheltree"]
38+
39+
def get_model_short_description(self) -> str:
40+
return "Audio clipper."
41+
42+
def get_model_long_description(self) -> str:
43+
return "Clips the input audio between -1 and 1."
44+
45+
def get_technical_description(self) -> str:
46+
return "Clips the input audio between -1 and 1."
47+
48+
def get_technical_links(self) -> Dict[str, str]:
49+
return {
50+
"Code": "https://github.com/QosmoInc/neutone_sdk/blob/main/examples/neutone_gen/example_clipper_gen.py"
51+
}
52+
53+
def get_tags(self) -> List[str]:
54+
return ["clipper"]
55+
56+
def get_model_version(self) -> str:
57+
return "1.0.0"
58+
59+
def is_experimental(self) -> bool:
60+
return False
61+
62+
def get_neutone_parameters(self) -> List[NeutoneParameter]:
63+
return [
64+
ContinuousNeutoneParameter("min", "min clip threshold", default_value=0.15),
65+
ContinuousNeutoneParameter("max", "max clip threshold", default_value=0.15),
66+
ContinuousNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
67+
]
68+
69+
@tr.jit.export
70+
def get_audio_in_channels(self) -> List[int]:
71+
return [2]
72+
73+
@tr.jit.export
74+
def get_audio_out_channels(self) -> List[int]:
75+
return [2]
76+
77+
@tr.jit.export
78+
def get_native_sample_rates(self) -> List[int]:
79+
return [] # Supports all sample rates
80+
81+
@tr.jit.export
82+
def get_native_buffer_sizes(self) -> List[int]:
83+
return [] # Supports all buffer sizes
84+
85+
@tr.jit.export
86+
def is_one_shot_model(self) -> bool:
87+
return False
88+
89+
def aggregate_continuous_params(self, cont_params: Tensor) -> Tensor:
90+
return cont_params # We want sample-level control, so no aggregation
91+
92+
def do_forward_pass(self,
93+
curr_block_idx: int,
94+
audio_in: List[Tensor],
95+
knob_params: Dict[str, Tensor],
96+
text_params: List[str]) -> List[Tensor]:
97+
min_val, max_val, gain = (knob_params["min"],
98+
knob_params["max"],
99+
knob_params["gain"])
100+
audio_out = []
101+
for x in audio_in:
102+
x = self.model.forward(x, min_val, max_val, gain)
103+
audio_out.append(x)
104+
return audio_out
105+
106+
107+
if __name__ == "__main__":
108+
parser = ArgumentParser()
109+
parser.add_argument("-o", "--output", default="export_model")
110+
args = parser.parse_args()
111+
root_dir = pathlib.Path(args.output)
112+
113+
model = ClipperModel()
114+
wrapper = NonRealtimeClipperModelWrapper(model)
115+
116+
# TODO(cm): write export method for nonrealtime models
117+
wrapper.forward(0, [tr.rand(2, 2048)])
118+
ts = tr.jit.script(wrapper)
119+
ts.forward(0, [tr.rand(2, 2048)])

0 commit comments

Comments
 (0)