Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def search_indices_vertical_s(


def _search_indices_rectilinear(
field: Field, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False
field: Field, time: float, z: float, y: float, x: float, ti: int, particle=None, search2D=False
):
grid = field.grid

Expand Down Expand Up @@ -223,7 +223,7 @@ def _search_indices_rectilinear(
return (zeta, eta, xsi, zi, yi, xi)


def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False):
def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False):
if particle:
zi, yi, xi = field.unravel_index(particle.ei)
else:
Expand Down
2 changes: 1 addition & 1 deletion parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])

zeta, eta, xsi, zi, yi, xi = fieldset.U._search_indices(
-1, particle.depth, particle.lat, particle.lon, particle=particle
time, particle.depth, particle.lat, particle.lon, ti, particle=particle

Check warning on line 188 in parcels/application_kernels/advection.py

View check run for this annotation

Codecov / codecov/patch

parcels/application_kernels/advection.py#L188

Added line #L188 was not covered by tests
)
if withW:
if abs(xsi - 1) < tol:
Expand Down
22 changes: 11 additions & 11 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,15 +877,15 @@ def cell_areas(self):
"""
return _calc_cell_areas(self.grid)

def _search_indices(self, time, z, y, x, ti=-1, particle=None, search2D=False):
def _search_indices(self, time, z, y, x, ti, particle=None, search2D=False):
if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
return _search_indices_rectilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D)
else:
return _search_indices_curvilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D)

def _interpolator2D(self, ti, z, y, x, particle=None):
def _interpolator2D(self, time, z, y, x, ti, particle=None):
"""Impelement 2D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
(_, eta, xsi, _, yi, xi) = self._search_indices(-1, z, y, x, particle=particle)
(_, eta, xsi, _, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle)
ctx = InterpolationContext2D(self.data, eta, xsi, ti, yi, xi)

try:
Expand All @@ -899,7 +899,7 @@ def _interpolator2D(self, ti, z, y, x, particle=None):
raise RuntimeError(self.interp_method + " is not implemented for 2D grids")
return f(ctx)

def _interpolator3D(self, ti, z, y, x, time, particle=None):
def _interpolator3D(self, time, z, y, x, ti, particle=None):
"""Impelement 3D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
(zeta, eta, xsi, zi, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle)
ctx = InterpolationContext3D(self.data, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype)
Expand Down Expand Up @@ -931,13 +931,13 @@ def temporal_interpolate_fullfield(self, ti, time):
f1 = self.data[ti + 1, :]
return f0 + (f1 - f0) * ((time - t0) / (t1 - t0))

def _spatial_interpolation(self, ti, z, y, x, time, particle=None):
"""Interpolate horizontal field values using a SciPy interpolator."""
def _spatial_interpolation(self, time, z, y, x, ti, particle=None):
"""Interpolate spatial field values."""
try:
if self.grid.zdim == 1:
val = self._interpolator2D(ti, z, y, x, particle=particle)
val = self._interpolator2D(time, z, y, x, ti, particle=particle)
else:
val = self._interpolator3D(ti, z, y, x, time, particle=particle)
val = self._interpolator3D(time, z, y, x, ti, particle=particle)

if np.isnan(val):
# Detect Out-of-bounds sampling and raise exception
Expand Down Expand Up @@ -1001,16 +1001,16 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
if self.gridindexingtype == "croco" and self not in [self.fieldset.H, self.fieldset.Zeta]:
z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle)
if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
f0 = self._spatial_interpolation(ti, z, y, x, time, particle=particle)
f1 = self._spatial_interpolation(ti + 1, z, y, x, time, particle=particle)
f0 = self._spatial_interpolation(time, z, y, x, ti, particle=particle)
f1 = self._spatial_interpolation(time, z, y, x, ti + 1, particle=particle)
t0 = self.grid.time[ti]
t1 = self.grid.time[ti + 1]
value = f0 + (f1 - f0) * ((time - t0) / (t1 - t0))
else:
# Skip temporal interpolation if time is outside
# of the defined time range or if we have hit an
# exact value in the time array.
value = self._spatial_interpolation(ti, z, y, x, self.grid.time[ti], particle=particle)
value = self._spatial_interpolation(self.grid.time[ti], z, y, x, ti, particle=particle)

if applyConversion:
return self.units.to_target(value, z, y, x)
Expand Down
4 changes: 2 additions & 2 deletions parcels/particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
self._ncount = len(lon)

for v in self.ptype.variables:
if v.name in ["ei", "ti"]:
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype)
if v.name == "ei":
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype) # TODO len(lon) can be self._ncount?
else:
self._data[v.name] = np.empty(self._ncount, dtype=v.dtype)

Expand Down
3 changes: 0 additions & 3 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,11 @@ def ArrayClass_init(self, *args, **kwargs):
self.ngrids = type(self).ngrids.initial
if self.ngrids >= 0:
self.ei = np.zeros(self.ngrids, dtype=np.int32)
self.ti = -1 * np.ones(self.ngrids, dtype=np.int32)
super(type(self), self).__init__(*args, **kwargs)

array_class_vdict = {
"ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1),
"ei": Variable("ei", dtype=np.int32, to_write=False),
"ti": Variable("ti", dtype=np.int32, to_write=False, initial=-1),
"__init__": ArrayClass_init,
}
array_class = type(class_name, (pclass,), array_class_vdict)
Expand Down Expand Up @@ -719,7 +717,6 @@ def from_particlefile(
v.name
not in [
"ei",
"ti",
"dt",
"depth",
"id",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def test_fieldset_write(tmp_zarrfile):
def UpdateU(particle, fieldset, time): # pragma: no cover
tmp1, tmp2 = fieldset.UV[particle]
_, yi, xi = fieldset.U.unravel_index(particle.ei)
fieldset.U.data[particle.ti, yi, xi] += 1
fieldset.U.data[0, yi, xi] += 1
fieldset.U.grid.time[0] = time

pset = ParticleSet(fieldset, pclass=Particle, lon=5, lat=5)
Expand Down
Loading