Skip to content

Commit 7dc58c2

Browse files
authored
[BUGFIX, BREAKING] Make activation base class abstract, fix PReLU implementation (#223)
* Make activation apply method pure virtual instead of no-op default
* Fix bugs
* Refactor to throw std::invalid_argument in debug mode, add tests
1 parent 95c7aa6 commit 7dc58c2

File tree

3 files changed

+89
-12
lines changed

3 files changed

+89
-12
lines changed

NAM/activations.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <cassert>
44
#include <cmath> // expf
5+
#include <iostream> // std::cerr (kept for potential debug use)
6+
#include <stdexcept> // std::invalid_argument
57
#include <functional>
68
#include <memory>
79
#include <optional>
@@ -150,7 +152,7 @@ class Activation
150152
{
151153
apply(block.data(), block.rows() * block.cols());
152154
}
153-
virtual void apply(float* data, long size) {}
155+
virtual void apply(float* data, long size) = 0;
154156

155157
static Ptr get_activation(const std::string name);
156158
static Ptr get_activation(const ActivationConfig& config);
@@ -165,13 +167,13 @@ class Activation
165167
static std::unordered_map<std::string, Ptr> _activations;
166168
};
167169

168-
// identity function activation
170+
// identity function activation--"do nothing"
169171
class ActivationIdentity : public nam::activations::Activation
170172
{
171173
public:
172174
ActivationIdentity() = default;
173175
~ActivationIdentity() = default;
174-
// Inherit the default apply methods which do nothing
176+
virtual void apply(float* data, long size) override {};
175177
};
176178

177179
class ActivationTanh : public Activation
@@ -276,6 +278,24 @@ class ActivationPReLU : public Activation
276278
}
277279
ActivationPReLU(std::vector<float> ns) { negative_slopes = ns; }
278280

281+
void apply(float* data, long size) override
282+
{
283+
// Assume column-major (this is brittle)
284+
#ifndef NDEBUG
285+
if (size % negative_slopes.size() != 0)
286+
{
287+
throw std::invalid_argument("PReLU.apply(*data, size) was given an array of size " + std::to_string(size)
288+
+ " but the activation has " + std::to_string(negative_slopes.size())
289+
+ " channels, which doesn't divide evenly.");
290+
}
291+
#endif
292+
for (long pos = 0; pos < size; pos++)
293+
{
294+
const float negative_slope = negative_slopes[pos % negative_slopes.size()];
295+
data[pos] = leaky_relu(data[pos], negative_slope);
296+
}
297+
}
298+
279299
void apply(Eigen::MatrixXf& matrix) override
280300
{
281301
// Matrix is organized as (channels, time_steps)
@@ -285,7 +305,14 @@ class ActivationPReLU : public Activation
285305
std::vector<float> slopes_for_channels = negative_slopes;
286306

287307
// Fail loudly if input has more channels than activation
288-
assert(actual_channels == negative_slopes.size());
308+
#ifndef NDEBUG
309+
if (actual_channels != negative_slopes.size())
310+
{
311+
throw std::invalid_argument("PReLU: Received " + std::to_string(actual_channels)
312+
+ " channels, but activation has " + std::to_string(negative_slopes.size())
313+
+ " channels");
314+
}
315+
#endif
289316

290317
// Apply each negative slope to its corresponding channel
291318
for (unsigned long channel = 0; channel < actual_channels; channel++)

tools/run_tests.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ int main()
4848

4949
test_activations::TestPReLU::test_core_function();
5050
test_activations::TestPReLU::test_per_channel_behavior();
51-
// This is enforced by an assert so it doesn't need to be tested
52-
// test_activations::TestPReLU::test_wrong_number_of_channels();
51+
test_activations::TestPReLU::test_wrong_number_of_channels_matrix();
52+
test_activations::TestPReLU::test_wrong_size_array();
53+
test_activations::TestPReLU::test_valid_array_size();
5354

5455
// Typed ActivationConfig tests
5556
test_activations::TestTypedActivationConfig::test_simple_config();

tools/test/test_activations.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,10 @@ class TestPReLU
220220
assert(fabs(data(1, 2) - 0.0f) < 1e-6); // 0.0 (unchanged)
221221
}
222222

223-
static void test_wrong_number_of_channels()
223+
static void test_wrong_number_of_channels_matrix()
224224
{
225-
// Test that we fail when we have more channels than slopes
225+
// Test that we fail when matrix has more channels than slopes
226+
// Note: This validation only runs in debug builds (#ifndef NDEBUG)
226227
Eigen::MatrixXf data(3, 2); // 3 channels, 2 time steps
227228

228229
// Initialize with test data
@@ -232,21 +233,69 @@ class TestPReLU
232233
std::vector<float> slopes = {0.01f, 0.05f};
233234
nam::activations::ActivationPReLU prelu(slopes);
234235

235-
// Apply the activation
236+
#ifndef NDEBUG
237+
// In debug mode, this should throw std::invalid_argument
236238
bool caught = false;
237239
try
238240
{
239241
prelu.apply(data);
240242
}
241-
catch (const std::runtime_error& e)
243+
catch (const std::invalid_argument& e)
242244
{
243245
caught = true;
244246
}
245-
catch (...)
247+
assert(caught && "Expected std::invalid_argument for channel count mismatch");
248+
#endif
249+
}
250+
251+
static void test_wrong_size_array()
252+
{
253+
// Test that we fail when array size doesn't divide evenly by channel count
254+
// Note: This validation only runs in debug builds (#ifndef NDEBUG)
255+
256+
// Create PReLU with 2 channels
257+
std::vector<float> slopes = {0.01f, 0.05f};
258+
nam::activations::ActivationPReLU prelu(slopes);
259+
260+
// Array of size 5 doesn't divide evenly by 2 channels
261+
std::vector<float> data = {-1.0f, -2.0f, 0.5f, 1.0f, -0.5f};
262+
263+
#ifndef NDEBUG
264+
// In debug mode, this should throw std::invalid_argument
265+
bool caught = false;
266+
try
267+
{
268+
prelu.apply(data.data(), (long)data.size());
269+
}
270+
catch (const std::invalid_argument& e)
246271
{
272+
caught = true;
247273
}
274+
assert(caught && "Expected std::invalid_argument for array size mismatch");
275+
#endif
276+
}
277+
278+
static void test_valid_array_size()
279+
{
280+
// Test that valid array sizes work correctly
281+
282+
// Create PReLU with 2 channels
283+
std::vector<float> slopes = {0.1f, 0.2f};
284+
nam::activations::ActivationPReLU prelu(slopes);
285+
286+
// Array of size 6 divides evenly by 2 channels (3 time steps per channel)
287+
std::vector<float> data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
288+
289+
// Should not throw
290+
prelu.apply(data.data(), (long)data.size());
248291

249-
assert(caught);
292+
// Verify results: alternating between slope 0.1 and 0.2
293+
assert(fabs(data[0] - (-0.1f)) < 1e-6); // channel 0, slope 0.1
294+
assert(fabs(data[1] - (-0.2f)) < 1e-6); // channel 1, slope 0.2
295+
assert(fabs(data[2] - (-0.1f)) < 1e-6); // channel 0, slope 0.1
296+
assert(fabs(data[3] - (-0.2f)) < 1e-6); // channel 1, slope 0.2
297+
assert(fabs(data[4] - (-0.1f)) < 1e-6); // channel 0, slope 0.1
298+
assert(fabs(data[5] - (-0.2f)) < 1e-6); // channel 1, slope 0.2
250299
}
251300
};
252301

0 commit comments

Comments (0)