Skip to content

Commit 177f133

Browse files
committed
Fix Nan gradients in Force model with padded_disjoint representaiton
1 parent 58e43b5 commit 177f133

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

kgcnn/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@
88

99
# Behaviour for backend functions.
1010
__safe_scatter_max_min_to_zero__ = True
11+
12+
# Geometry
13+
__geom_euclidean_norm_add_eps__ = False
14+
__geom_euclidean_norm_no_nan__ = True # Only used for inverse norm.

kgcnn/layers/casting.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
def _pad_left(t):
10-
# return ops.concatenate([ops.zeros_like(t[:1]), t], axis=0)
1110
return ops.pad(t, [[1, 0]] + [[0, 0] for _ in range(len(ops.shape(t)) - 1)])
1211

1312

kgcnn/layers/geom.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from kgcnn.layers.polynom import SphericalBesselJnExplicit, SphericalHarmonicsYl
1111
from kgcnn.ops.axis import get_positive_axis
1212
from kgcnn.ops.core import cross as kgcnn_cross
13+
from kgcnn import __geom_euclidean_norm_add_eps__ as global_geom_euclidean_norm_add_eps
14+
from kgcnn import __geom_euclidean_norm_no_nan__ as global_geom_euclidean_norm_no_nan
1315

1416

1517
class NodePosition(Layer):
@@ -142,8 +144,11 @@ class EuclideanNorm(Layer):
142144
with :obj:`invert_norm` layer arguments.
143145
"""
144146

145-
def __init__(self, axis: int = -1, keepdims: bool = False, invert_norm: bool = False, add_eps: bool = False,
146-
no_nan: bool = True, square_norm: bool = False, **kwargs):
147+
def __init__(self, axis: int = -1, keepdims: bool = False,
148+
invert_norm: bool = False,
149+
add_eps: bool = global_geom_euclidean_norm_add_eps,
150+
no_nan: bool = global_geom_euclidean_norm_no_nan,
151+
square_norm: bool = False, **kwargs):
147152
"""Initialize layer.
148153
149154
Args:
@@ -177,7 +182,7 @@ def compute_output_shape(self, input_shape):
177182

178183
@staticmethod
179184
def _compute_euclidean_norm(inputs, axis: int = -1, keepdims: bool = False, invert_norm: bool = False,
180-
add_eps: bool = False, no_nan: bool = True, square_norm: bool = False):
185+
add_eps: bool = False, no_nan: bool = False, square_norm: bool = False):
181186
"""Function to compute euclidean norm for inputs.
182187
183188
Args:
@@ -306,7 +311,10 @@ class NodeDistanceEuclidean(Layer):
306311
the output of :obj:`NodePosition`.
307312
"""
308313

309-
def __init__(self, add_eps: bool = False, no_nan: bool = True, **kwargs):
314+
def __init__(self,
315+
add_eps: bool = global_geom_euclidean_norm_add_eps,
316+
no_nan: bool = global_geom_euclidean_norm_no_nan,
317+
**kwargs):
310318
r"""Initialize layer instance of :obj:`NodeDistanceEuclidean`. """
311319
super(NodeDistanceEuclidean, self).__init__(**kwargs)
312320
self.layer_subtract = Subtract()
@@ -354,7 +362,9 @@ class EdgeDirectionNormalized(Layer):
354362
As the first index defines the incoming edge.
355363
"""
356364

357-
def __init__(self, add_eps: bool = False, no_nan: bool = True, **kwargs):
365+
def __init__(self, add_eps: bool = global_geom_euclidean_norm_add_eps,
366+
no_nan: bool = global_geom_euclidean_norm_no_nan,
367+
**kwargs):
358368
"""Initialize layer."""
359369
super(EdgeDirectionNormalized, self).__init__(**kwargs)
360370
self.layer_subtract = Subtract()

training/train_force.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import argparse
55
import keras as ks
66
from datetime import timedelta
7+
import kgcnn
78
import kgcnn.training.schedule
89
import kgcnn.training.scheduler
910
from kgcnn.data.utils import save_pickle_file
@@ -18,6 +19,9 @@
1819
from kgcnn.metrics.metrics import ScaledMeanAbsoluteError, ScaledForceMeanAbsoluteError
1920
from kgcnn.data.transform.scaler.force import EnergyForceExtensiveLabelScaler
2021

22+
# For force gradients
23+
kgcnn.__geom_euclidean_norm_add_eps__ = True
24+
2125
# Input arguments from command line.
2226
parser = argparse.ArgumentParser(description='Train a GNN on an Energy-Force Dataset.')
2327
parser.add_argument("--hyper", required=False, help="Filepath to hyper-parameter config file (.py or .json).",

0 commit comments

Comments
 (0)