Skip to content

Commit 4a21213

Browse files
committed
Fix scale of the pred-prey data
1 parent 7aeb4f6 commit 4a21213

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

neuralprocesses/data/predprey.py

+8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
__all__ = ["PredPreyGenerator", "PredPreyRealGenerator"]
1212

1313

14+
# Determine the true scale of the hare-lynx data set. The simulator will be tuned to
15+
# this.
16+
_true_scale = np.mean(np.array(load()))
17+
18+
1419
def _predprey_step(state, x_y, t, dt, *, alpha, beta, delta, gamma, sigma):
1520
x = x_y[..., 0]
1621
y = x_y[..., 1]
@@ -84,6 +89,9 @@ def collect(t_target, remainder=False):
8489
t = B.to_active_device(B.cast(dtype, B.stack(*t)))
8590
traj = B.stack(*traj, axis=-1)
8691

92+
# Fix the scale of the trajectory.
93+
traj = traj / B.mean(traj, axis=(1, 2), squeeze=False) * _true_scale
94+
8795
# Undo the sorting.
8896
t = B.take(t, inv_perm)
8997
traj = B.take(traj, inv_perm, axis=-1)

0 commit comments

Comments
 (0)