Skip to content

Commit a2c3235

Browse files
committed
[BUGFIX] Support no head key in WaveNet config
1 parent 3377e11 commit a2c3235

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

NAM/wavenet.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
820820
input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
821821
activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params));
822822
}
823-
const bool with_head = !config["head"].is_null();
823+
const bool with_head = config.find("head") != config.end() && !config["head"].is_null();
824824
const float head_scale = config["head_scale"];
825825

826826
if (layer_array_params.empty())

tools/run_tests.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "test/test_wavenet/test_condition_processing.cpp"
2020
#include "test/test_wavenet/test_head1x1.cpp"
2121
#include "test/test_wavenet/test_layer1x1.cpp"
22+
#include "test/test_wavenet/test_factory.cpp"
2223
#include "test/test_gating_activations.cpp"
2324
#include "test/test_wavenet_gating_compatibility.cpp"
2425
#include "test/test_blending_detailed.cpp"
@@ -169,6 +170,7 @@ int main()
169170
test_wavenet::test_layer1x1::test_layer1x1_post_film_inactive_with_layer1x1_inactive();
170171
test_wavenet::test_layer1x1::test_layer1x1_gated();
171172
test_wavenet::test_layer1x1::test_layer1x1_groups();
173+
test_wavenet::test_factory::test_factory_without_head_key();
172174
test_wavenet::test_allocation_tracking_pass();
173175
test_wavenet::test_allocation_tracking_fail();
174176
test_wavenet::test_conv1d_process_realtime_safe();
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Tests for WaveNet Factory
2+
3+
#include <cassert>
4+
#include <cmath>
5+
#include <memory>
6+
#include <vector>
7+
8+
#include "json.hpp"
9+
10+
#include "NAM/get_dsp.h"
11+
#include "NAM/wavenet.h"
12+
13+
namespace test_wavenet
14+
{
15+
namespace test_factory
16+
{
17+
/// Asserts that the model is instantiated correctly when no "head" key is provided.
18+
/// The deprecated "head" key is optional; when absent, with_head should be false.
19+
void test_factory_without_head_key()
20+
{
21+
// Minimal WaveNet config - deliberately omits the "head" key entirely.
22+
// Same structure as wavenet.nam but without "head" in config.
23+
const std::string configStr = R"({
24+
"version": "0.5.4",
25+
"metadata": {},
26+
"architecture": "WaveNet",
27+
"config": {
28+
"layers": [{
29+
"input_size": 1,
30+
"condition_size": 1,
31+
"head_size": 1,
32+
"channels": 1,
33+
"kernel_size": 1,
34+
"dilations": [1],
35+
"activation": "ReLU",
36+
"gated": false,
37+
"head_bias": false
38+
}],
39+
"head_scale": 1.0
40+
},
41+
"weights": [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0],
42+
"sample_rate": 48000
43+
})";
44+
45+
nlohmann::json j = nlohmann::json::parse(configStr);
46+
47+
// Verify the config does not contain "head" key
48+
assert(j["config"].find("head") == j["config"].end());
49+
50+
// Load model via get_dsp - exercises Factory path
51+
std::unique_ptr<nam::DSP> dsp = nam::get_dsp(j);
52+
assert(dsp != nullptr);
53+
54+
// Process audio to verify model works correctly
55+
const int numFrames = 4;
56+
const int maxBufferSize = 64;
57+
dsp->Reset(48000.0, maxBufferSize);
58+
59+
std::vector<NAM_SAMPLE> input(numFrames, 1.0f);
60+
std::vector<NAM_SAMPLE> output(numFrames, 0.0f);
61+
NAM_SAMPLE* inputPtrs[] = {input.data()};
62+
NAM_SAMPLE* outputPtrs[] = {output.data()};
63+
64+
dsp->process(inputPtrs, outputPtrs, numFrames);
65+
66+
assert(static_cast<int>(output.size()) == numFrames);
67+
for (int i = 0; i < numFrames; i++)
68+
{
69+
assert(std::isfinite(output[i]));
70+
}
71+
}
72+
}; // namespace test_factory
73+
}; // namespace test_wavenet

0 commit comments

Comments
 (0)