7
7
8
8
class SystemEquation (Equation ):
9
9
10
- def __init__ (self , list_equation , reduction = "mean" ):
10
+ def __init__ (self , list_equation , reduction = None ):
11
11
"""
12
12
System of Equation class for specifing any system
13
13
of equations in PINA.
@@ -19,14 +19,13 @@ def __init__(self, list_equation, reduction="mean"):
19
19
:param Callable equation: A ``torch`` callable equation to
20
20
evaluate the residual
21
21
:param str reduction: Specifies the reduction to apply to the output:
22
- ``none`` | ``mean`` | ``sum`` | `` callable``. ``none`` : no reduction
23
- will be applied, ``mean``: the sum of the output will be divided
22
+ None | ``mean`` | ``sum`` | callable. None : no reduction
23
+ will be applied, ``mean``: the output sum will be divided
24
24
by the number of elements in the output, ``sum``: the output will
25
- be summed. `` callable`` a callable function to perform reduction,
26
- no checks guaranteed. Default: ``mean`` .
25
+ be summed. * callable* is a callable function to perform reduction,
26
+ no checks guaranteed. Default: None .
27
27
"""
28
28
check_consistency ([list_equation ], list )
29
- check_consistency (reduction , str )
30
29
31
30
# equations definition
32
31
self .equations = []
@@ -38,7 +37,7 @@ def __init__(self, list_equation, reduction="mean"):
38
37
self .reduction = torch .mean
39
38
elif reduction == "sum" :
40
39
self .reduction = torch .sum
41
- elif (reduction == "none" ) or callable (reduction ):
40
+ elif (reduction == None ) or callable (reduction ):
42
41
self .reduction = reduction
43
42
else :
44
43
raise NotImplementedError (
@@ -72,7 +71,7 @@ def residual(self, input_, output_, params_=None):
72
71
]
73
72
)
74
73
75
- if self .reduction == "none" :
74
+ if self .reduction is None :
76
75
return residual
77
76
78
77
return self .reduction (residual , dim = - 1 )
0 commit comments