Skip to content

Commit 16261c9

Browse files
authored
Change default reduction in SystemEquation (mathLab#317)
* Update system_equation.py * Update test_systemequation.py
1 parent f9316e3 commit 16261c9

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

pina/equation/system_equation.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class SystemEquation(Equation):
99

10-
def __init__(self, list_equation, reduction="mean"):
10+
def __init__(self, list_equation, reduction=None):
1111
"""
1212
System of Equation class for specifing any system
1313
of equations in PINA.
@@ -19,14 +19,13 @@ def __init__(self, list_equation, reduction="mean"):
1919
:param Callable equation: A ``torch`` callable equation to
2020
evaluate the residual
2121
:param str reduction: Specifies the reduction to apply to the output:
22-
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
23-
will be applied, ``mean``: the sum of the output will be divided
22+
None | ``mean`` | ``sum`` | callable. None: no reduction
23+
will be applied, ``mean``: the output sum will be divided
2424
by the number of elements in the output, ``sum``: the output will
25-
be summed. ``callable`` a callable function to perform reduction,
26-
no checks guaranteed. Default: ``mean``.
25+
be summed. *callable* is a callable function to perform reduction,
26+
no checks guaranteed. Default: None.
2727
"""
2828
check_consistency([list_equation], list)
29-
check_consistency(reduction, str)
3029

3130
# equations definition
3231
self.equations = []
@@ -38,7 +37,7 @@ def __init__(self, list_equation, reduction="mean"):
3837
self.reduction = torch.mean
3938
elif reduction == "sum":
4039
self.reduction = torch.sum
41-
elif (reduction == "none") or callable(reduction):
40+
elif (reduction == None) or callable(reduction):
4241
self.reduction = reduction
4342
else:
4443
raise NotImplementedError(
@@ -72,7 +71,7 @@ def residual(self, input_, output_, params_=None):
7271
]
7372
)
7473

75-
if self.reduction == "none":
74+
if self.reduction is None:
7675
return residual
7776

7877
return self.reduction(residual, dim=-1)

tests/test_equations/test_systemequation.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,18 @@ def test_residual():
3939
u = torch.pow(pts, 2)
4040
u.labels = ['u1', 'u2']
4141

42-
eq_1 = SystemEquation([eq1, eq2])
42+
eq_1 = SystemEquation([eq1, eq2], reduction='mean')
4343
res = eq_1.residual(pts, u)
4444
assert res.shape == torch.Size([10])
4545

4646
eq_1 = SystemEquation([eq1, eq2], reduction='sum')
4747
res = eq_1.residual(pts, u)
4848
assert res.shape == torch.Size([10])
4949

50-
eq_1 = SystemEquation([eq1, eq2], reduction='none')
50+
eq_1 = SystemEquation([eq1, eq2], reduction=None)
51+
res = eq_1.residual(pts, u)
52+
assert res.shape == torch.Size([10, 3])
53+
54+
eq_1 = SystemEquation([eq1, eq2])
5155
res = eq_1.residual(pts, u)
5256
assert res.shape == torch.Size([10, 3])

0 commit comments

Comments
 (0)