
Commit 0dc6bcf

Merge pull request #369 from apphp/SAM-4-implementation-of-gelu-function
SAM-4: Implementation of GeLU function
2 parents 5a15240 + 3d113a5 commit 0dc6bcf

File tree

7 files changed: +303 additions, -108 deletions

phpstan.neon

Lines changed: 7 additions & 0 deletions
@@ -13,3 +13,10 @@ parameters:
         - src/Backends/Amp.php
         - src/Backends/Swoole.php
         - tests/Backends/SwooleTest.php
+    ignoreErrors:
+        # ------------------------------------
+        # Ignore errors that are caused by NumPower
+        # ------------------------------------
+        - message: '#^Binary operation "\*\*" between NDArray and .* results in an error.#'
+          paths:
+            - src/NeuralNet/ActivationFunctions/*
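
For context, the ignored message targets code like the sketch below (not part of the diff), where the native ** operator is applied to an NDArray: PHPStan has no stubs describing NumPower's operator overloading, so it flags the expression even though it evaluates element-wise at runtime. The NumPower::array() call is assumed here for constructing the NDArray.

<?php

declare(strict_types=1);

// Illustrative only: the kind of expression the ignoreErrors entry suppresses.
$x = NumPower::array([[1.0, 2.0], [3.0, 4.0]]);

// PHPStan reports: Binary operation "**" between NDArray and 3 results in an error,
// although NumPower overloads ** and computes the element-wise power.
$cubed = $x ** 3;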

src/NeuralNet/ActivationFunctions/Base/Contracts/SingleBufferDerivative.php renamed to src/NeuralNet/ActivationFunctions/Base/Contracts/IBufferDerivative.php

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
  * @author Andrew DalPino
  * @author Aleksei Nechaev <[email protected]>
  */
-interface SingleBufferDerivative extends Derivative
+interface IBufferDerivative extends Derivative
 {
     /**
      * Calculate the derivative of the single parameter.
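
The hunk only shows the renamed declaration; a plausible sketch of the full contract, assuming it declares the single-buffer differentiate() method that ELU and GeLU implement below:

<?php

declare(strict_types=1);

namespace Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts;

use NDArray;

interface IBufferDerivative extends Derivative
{
    /**
     * Calculate the derivative of the single parameter.
     *
     * @param NDArray $x
     * @return NDArray
     */
    public function differentiate(NDArray $x) : NDArray;
}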

src/NeuralNet/ActivationFunctions/ELU/ELU.php

Lines changed: 12 additions & 12 deletions
@@ -7,7 +7,7 @@
 use NumPower;
 use NDArray;
 use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
-use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\SingleBufferDerivative;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
 use Rubix\ML\NeuralNet\ActivationFunctions\ELU\Exceptions\InvalidAlphaException;
 
 /**
@@ -26,7 +26,7 @@
  * @author Aleksei Nechaev <[email protected]>
  * @author Samuel Akopyan <[email protected]>
  */
-class ELU implements ActivationFunction, SingleBufferDerivative
+class ELU implements ActivationFunction, IBufferDerivative
 {
     /**
      * Class constructor.
@@ -57,17 +57,17 @@ public function __construct(protected float $alpha = 1.0)
     public function activate(NDArray $input) : NDArray
     {
         // Calculate positive part: x for x > 0
-        $positiveActivation = NumPower::maximum(a: $input, b: 0);
+        $positiveActivation = NumPower::maximum($input, 0);
 
         // Calculate negative part: alpha * (e^x - 1) for x <= 0
-        $negativeMask = NumPower::minimum(a: $input, b: 0);
+        $negativeMask = NumPower::minimum($input, 0);
         $negativeActivation = NumPower::multiply(
-            a: NumPower::expm1($negativeMask),
-            b: $this->alpha
+            NumPower::expm1($negativeMask),
+            $this->alpha
         );
 
         // Combine both parts
-        return NumPower::add(a: $positiveActivation, b: $negativeActivation);
+        return NumPower::add($positiveActivation, $negativeActivation);
     }
 
     /**
@@ -82,17 +82,17 @@ public function activate(NDArray $input) : NDArray
     public function differentiate(NDArray $x) : NDArray
     {
         // For x > 0: 1
-        $positivePart = NumPower::greater(a: $x, b: 0);
+        $positivePart = NumPower::greater($x, 0);
 
         // For x <= 0: α * e^x
-        $negativeMask = NumPower::lessEqual(a: $x, b: 0);
+        $negativeMask = NumPower::lessEqual($x, 0);
         $negativePart = NumPower::multiply(
-            a: NumPower::multiply(a: $negativeMask, b: NumPower::exp($x)),
-            b: $this->alpha
+            NumPower::multiply($negativeMask, NumPower::exp($x)),
+            $this->alpha
        );
 
         // Combine both parts
-        return NumPower::add(a: $positivePart, b: $negativePart);
+        return NumPower::add($positivePart, $negativePart);
     }
 
     /**
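
The ELU changes appear purely cosmetic (named arguments replaced by positional ones in the same order), so behaviour should be unchanged. A small usage sketch with approximate values, assuming NumPower::array() builds an NDArray from a PHP array:

<?php

declare(strict_types=1);

use Rubix\ML\NeuralNet\ActivationFunctions\ELU\ELU;

$activationFn = new ELU(1.0);

$input = NumPower::array([[-1.0, 0.0, 2.0]]);

// ELU(x) = x for x > 0, alpha * (e^x - 1) otherwise:
// approximately [-0.6321, 0.0, 2.0]
$activations = $activationFn->activate($input);

// ELU'(x) = 1 for x > 0, alpha * e^x otherwise:
// approximately [0.3679, 1.0, 1.0]
$derivatives = $activationFn->differentiate($input);

print_r($activations->toArray());
print_r($derivatives->toArray());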

src/NeuralNet/ActivationFunctions/GeLU/GELU.php

Lines changed: 0 additions & 73 deletions
This file was deleted.

src/NeuralNet/ActivationFunctions/GeLU/GeLU.php

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Rubix\ML\NeuralNet\ActivationFunctions\GeLU;
+
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
+
+/**
+ * GeLU
+ *
+ * Gaussian Error Linear Units (GeLUs) are rectifiers that are gated by the magnitude of their input rather
+ * than the sign of their input as with ReLU variants. Their output can be interpreted as the expected value
+ * of a neuron with random dropout regularization applied.
+ *
+ * References:
+ * [1] D. Hendrycks et al. (2018). Gaussian Error Linear Units (GeLUs).
+ *
+ * @category Machine Learning
+ * @package Rubix/ML
+ * @author Andrew DalPino
+ * @author Aleksei Nechaev <[email protected]>
+ * @author Samuel Akopyan <[email protected]>
+ */
+class GeLU implements ActivationFunction, IBufferDerivative
+{
+    /**
+     * The square root of two over pi constant sqrt(2/π).
+     *
+     * @var float
+     */
+    protected const ALPHA = 0.7978845608;
+
+    /**
+     * Gaussian error function approximation term.
+     *
+     * @var float
+     */
+    protected const BETA = 0.044715;
+
+    /**
+     * Apply the GeLU activation function to the input.
+     *
+     * f(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+     *
+     * @param NDArray $input The input values
+     * @return NDArray The activated values
+     */
+    public function activate(NDArray $input) : NDArray
+    {
+        // Calculate x^3
+        $cubed = $input ** 3;
+
+        // Calculate inner term: x + BETA * x^3
+        $innerTerm = NumPower::add(
+            $input,
+            NumPower::multiply(self::BETA, $cubed)
+        );
+
+        // Apply tanh(ALPHA * innerTerm)
+        $tanhTerm = NumPower::tanh(
+            NumPower::multiply(self::ALPHA, $innerTerm)
+        );
+
+        // Calculate 1 + tanhTerm
+        $onePlusTanh = NumPower::add(1.0, $tanhTerm);
+
+        // Calculate 0.5 * x * (1 + tanhTerm)
+        return NumPower::multiply(
+            0.5,
+            NumPower::multiply($input, $onePlusTanh)
+        );
+    }
+
+    /**
+     * Calculate the derivative of the activation function.
+     *
+     * The derivative of GeLU is:
+     * f'(x) = 0.5 * (1 + tanh(α * (x + β * x^3))) +
+     *         0.5 * x * sech^2(α * (x + β * x^3)) * α * (1 + 3β * x^2)
+     *
+     * Where:
+     * - α = sqrt(2/π) ≈ 0.7978845608
+     * - β = 0.044715
+     * - sech^2(z) = (1/cosh(z))^2
+     *
+     * @param NDArray $x Output matrix
+     * @return NDArray Derivative matrix
+     */
+    public function differentiate(NDArray $x) : NDArray
+    {
+        // Calculate x^3
+        $cubed = $x ** 3;
+
+        // Calculate inner term: ALPHA * (x + BETA * x^3)
+        $innerTerm = NumPower::multiply(
+            self::ALPHA,
+            NumPower::add(
+                $x,
+                NumPower::multiply(self::BETA, $cubed)
+            )
+        );
+
+        // Calculate cosh and sech^2
+        $cosh = NumPower::cosh($innerTerm);
+        $sech2 = NumPower::pow(
+            NumPower::divide(1.0, $cosh),
+            2
+        );
+
+        // Calculate 0.5 * (1 + tanh(innerTerm))
+        $firstTerm = NumPower::multiply(
+            0.5,
+            NumPower::add(1.0, NumPower::tanh($innerTerm))
+        );
+
+        // Calculate 0.5 * x * sech^2 * ALPHA * (1 + 3 * BETA * x^2)
+        $secondTerm = NumPower::multiply(
+            NumPower::multiply(
+                NumPower::multiply(
+                    0.5 * self::ALPHA,
+                    $x
+                ),
+                $sech2
+            ),
+            NumPower::add(
+                1.0,
+                NumPower::multiply(
+                    3.0 * self::BETA,
+                    NumPower::pow($x, 2)
+                )
+            )
+        );
+
+        // Combine terms
+        return NumPower::add($firstTerm, $secondTerm);
+    }
+
+    /**
+     * Return the string representation of the activation function.
+     *
+     * @return string String representation
+     */
+    public function __toString() : string
+    {
+        return 'GeLU';
+    }
+}
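
A minimal usage sketch for the new class (not part of the diff; NumPower::array() is assumed for constructing the NDArray, and the commented values are approximate results of the tanh approximation):

<?php

declare(strict_types=1);

use Rubix\ML\NeuralNet\ActivationFunctions\GeLU\GeLU;

$activationFn = new GeLU();

// A small batch of pre-activations.
$input = NumPower::array([[-2.0, 0.0, 2.0]]);

// Forward pass: GeLU(x) ≈ [-0.045, 0.0, 1.955]
$activations = $activationFn->activate($input);

// Backward pass: GeLU'(x) ≈ [-0.086, 0.5, 1.086]
$derivatives = $activationFn->differentiate($input);

print_r($activations->toArray());
print_r($derivatives->toArray());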

tests/NeuralNet/ActivationFunctions/ELU/ELUTest.php

Lines changed: 13 additions & 22 deletions
@@ -7,17 +7,15 @@
 use PHPUnit\Framework\Attributes\CoversClass;
 use PHPUnit\Framework\Attributes\DataProvider;
 use PHPUnit\Framework\Attributes\Group;
+use PHPUnit\Framework\Attributes\Test;
+use PHPUnit\Framework\Attributes\TestDox;
 use NumPower;
 use NDArray;
 use Rubix\ML\NeuralNet\ActivationFunctions\ELU\ELU;
 use PHPUnit\Framework\TestCase;
 use Generator;
 use Rubix\ML\NeuralNet\ActivationFunctions\ELU\Exceptions\InvalidAlphaException;
 
-/**
- * @group ActivationFunctions
- * @covers \Rubix\ML\NeuralNet\ActivationFunctions\ELU\ELU
- */
 #[Group('ActivationFunctions')]
 #[CoversClass(ELU::class)]
 class ELUTest extends TestCase
@@ -93,9 +91,8 @@ protected function setUp() : void
         $this->activationFn = new ELU(1.0);
     }
 
-    /**
-     * @test
-     */
+    #[Test]
+    #[TestDox('Can be constructed with valid alpha parameter')]
     public function testConstructorWithValidAlpha() : void
     {
         $activationFn = new ELU(2.0);
@@ -104,28 +101,24 @@ public function testConstructorWithValidAlpha() : void
         static::assertEquals('ELU (alpha: 2)', (string) $activationFn);
     }
 
-    /**
-     * @test
-     */
+    #[Test]
+    #[TestDox('Throws exception when constructed with invalid alpha parameter')]
    public function testConstructorWithInvalidAlpha() : void
    {
        $this->expectException(InvalidAlphaException::class);
 
        new ELU(-346);
    }
 
-    /**
-     * @test
-     */
+    #[Test]
+    #[TestDox('Can be cast to a string')]
    public function testToString() : void
    {
        static::assertEquals('ELU (alpha: 1)', (string) $this->activationFn);
    }
 
-    /**
-     * @param NDArray $input
-     * @param list<list<float>> $expected
-     */
+    #[Test]
+    #[TestDox('Correctly activates the input')]
     #[DataProvider('computeProvider')]
     public function testActivate(NDArray $input, array $expected) : void
     {
@@ -134,12 +127,10 @@ public function testActivate(NDArray $input, array $expected) : void
         static::assertEquals($expected, $activations);
     }
 
-    /**
-     * @param NDArray $input
-     * @param list<list<float>> $expected
-     */
+    #[Test]
+    #[TestDox('Correctly differentiates the input')]
     #[DataProvider('differentiateProvider')]
-    public function testDifferentiate1(NDArray $input, array $expected) : void
+    public function testDifferentiate(NDArray $input, array $expected) : void
     {
         $derivatives = $this->activationFn->differentiate($input)->toArray();
 
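
The test migration above swaps annotation docblocks for native PHPUnit attributes. As a compact illustration of the resulting style, a hypothetical attribute-based test for the new GeLU class might look like this (GeLUSmokeTest, its provider, and the expected values are assumptions for illustration, not part of the PR; NumPower::array() is assumed for building the NDArray):

<?php

declare(strict_types=1);

use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\Attributes\Group;
use PHPUnit\Framework\Attributes\Test;
use PHPUnit\Framework\Attributes\TestDox;
use PHPUnit\Framework\TestCase;
use Rubix\ML\NeuralNet\ActivationFunctions\GeLU\GeLU;

#[Group('ActivationFunctions')]
class GeLUSmokeTest extends TestCase
{
    public static function activateProvider() : Generator
    {
        // Expected values are approximate outputs of the tanh-based GeLU.
        yield [NumPower::array([[-1.0, 0.0, 2.0]]), [[-0.1588, 0.0, 1.9546]]];
    }

    #[Test]
    #[TestDox('Correctly activates the input')]
    #[DataProvider('activateProvider')]
    public function testActivate(NDArray $input, array $expected) : void
    {
        $activations = (new GeLU())->activate($input)->toArray();

        static::assertEqualsWithDelta($expected, $activations, 1e-3);
    }
}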