Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions docs/examples/tutorial_particle_field_interaction.ipynb

Large diffs are not rendered by default.

11 changes: 3 additions & 8 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,14 @@ def _search_indices_rectilinear(
_raise_field_sampling_error(z, y, x)

if particle:
particle.xi[field.igrid] = xi
particle.yi[field.igrid] = yi
particle.zi[field.igrid] = zi
particle.ei[field.igrid] = field.ravel_index(zi, yi, xi)

return (zeta, eta, xsi, zi, yi, xi)


def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False):
if particle:
xi = particle.xi[field.igrid]
yi = particle.yi[field.igrid]
zi, yi, xi = field.unravel_index(particle.ei[field.igrid])
else:
xi = int(field.grid.xdim / 2) - 1
yi = int(field.grid.ydim / 2) - 1
Expand Down Expand Up @@ -310,9 +307,7 @@ def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=Non
_raise_field_sampling_error(z, y, x)

if particle:
particle.xi[field.igrid] = xi
particle.yi[field.igrid] = yi
particle.zi[field.igrid] = zi
particle.ei[field.igrid] = field.ravel_index(zi, yi, xi)

return (zeta, eta, xsi, zi, yi, xi)

Expand Down
4 changes: 1 addition & 3 deletions parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@
yi += 1
eta = 0

particle.xi[:] = xi
particle.yi[:] = yi
particle.zi[:] = zi
particle.ei[:] = fieldset.U.ravel_index(xi, yi, zi)

Check warning on line 213 in parcels/application_kernels/advection.py

View check run for this annotation

Codecov / codecov/patch

parcels/application_kernels/advection.py#L213

Added line #L213 was not covered by tests

grid = fieldset.U.grid
if grid._gtype < 2:
Expand Down
42 changes: 42 additions & 0 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,48 @@ def computeTimeChunk(self, data, tindex):
self.filebuffers[tindex] = filebuffer
return data

def ravel_index(self, zi, yi, xi):
"""Return the flat index of the given grid points.

Parameters
----------
zi : int
x index
yi : int
y index
xi : int
z index

Returns
-------
int
flat index
"""
return xi + self.grid.xdim * (yi + self.grid.ydim * zi)

def unravel_index(self, ei):
"""Return the zi, yi, xi indices for a given flat index.

Parameters
----------
ei : int
The flat index to be unraveled.

Returns
-------
zi : int
The x index.
yi : int
The y index.
xi : int
The z index.
"""
zi = ei // (self.grid.xdim * self.grid.ydim)
ei = ei % (self.grid.xdim * self.grid.ydim)
yi = ei // self.grid.xdim
xi = ei % self.grid.xdim
return zi, yi, xi


class VectorField:
"""Class VectorField stores 2 or 3 fields which defines together a vector field.
Expand Down
2 changes: 1 addition & 1 deletion parcels/particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
self._ncount = len(lon)

for v in self.ptype.variables:
if v.name in ["xi", "yi", "zi", "ti"]:
if v.name in ["ei", "ti"]:
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype)
else:
self._data[v.name] = np.empty(self._ncount, dtype=v.dtype)
Expand Down
21 changes: 6 additions & 15 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,13 @@
type(self).ngrids.initial = numgrids
self.ngrids = type(self).ngrids.initial
if self.ngrids >= 0:
for index in ["xi", "yi", "zi", "ti"]:
if index != "ti":
setattr(self, index, np.zeros(self.ngrids, dtype=np.int32))
else:
setattr(self, index, -1 * np.ones(self.ngrids, dtype=np.int32))
self.ei = np.zeros(self.ngrids, dtype=np.int32)
self.ti = -1 * np.ones(self.ngrids, dtype=np.int32)

Check warning on line 131 in parcels/particleset.py

View check run for this annotation

Codecov / codecov/patch

parcels/particleset.py#L130-L131

Added lines #L130 - L131 were not covered by tests
super(type(self), self).__init__(*args, **kwargs)

array_class_vdict = {
"ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1),
"xi": Variable("xi", dtype=np.int32, to_write=False),
"yi": Variable("yi", dtype=np.int32, to_write=False),
"zi": Variable("zi", dtype=np.int32, to_write=False),
"ei": Variable("ei", dtype=np.int32, to_write=False),

Check warning on line 136 in parcels/particleset.py

View check run for this annotation

Codecov / codecov/patch

parcels/particleset.py#L136

Added line #L136 was not covered by tests
"ti": Variable("ti", dtype=np.int32, to_write=False, initial=-1),
"__init__": ArrayClass_init,
}
Expand Down Expand Up @@ -436,7 +431,7 @@

# TODO: This method is only tested in tutorial notebook. Add unit test?
def populate_indices(self):
"""Pre-populate guesses of particle xi/yi indices using a kdtree.
"""Pre-populate guesses of particle ei (element id) indices using a kdtree.

This is only intended for curvilinear grids, where the initial index search
may be quite expensive.
Expand All @@ -454,10 +449,8 @@
_, idx_nan = tree.query(pts.astype(tree_data.dtype))

idx = np.where(IN)[0][idx_nan]
yi, xi = np.unravel_index(idx, grid.lon.shape)

self.particledata.data["xi"][:, i] = xi
self.particledata.data["yi"][:, i] = yi
self.particledata.data["ei"][:, i] = idx # assumes that we are in the surface layer (zi=0)

Check warning on line 453 in parcels/particleset.py

View check run for this annotation

Codecov / codecov/patch

parcels/particleset.py#L453

Added line #L453 was not covered by tests

@classmethod
def from_list(
Expand Down Expand Up @@ -725,9 +718,7 @@
elif (
v.name
not in [
"xi",
"yi",
"zi",
"ei",

Check warning on line 721 in parcels/particleset.py

View check run for this annotation

Codecov / codecov/patch

parcels/particleset.py#L721

Added line #L721 was not covered by tests
"ti",
"dt",
"depth",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ def test_fieldset_write(tmp_zarrfile):

def UpdateU(particle, fieldset, time): # pragma: no cover
tmp1, tmp2 = fieldset.UV[particle]
fieldset.U.data[particle.ti, particle.yi, particle.xi] += 1
_, yi, xi = fieldset.U.unravel_index(particle.ei)
fieldset.U.data[particle.ti, yi, xi] += 1
fieldset.U.grid.time[0] = time

pset = ParticleSet(fieldset, pclass=Particle, lon=5, lat=5)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_fieldset_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def test_verticalsampling(zdir):
fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0, depth=0.7 * zdir)
pset.execute(AdvectionRK4, dt=1.0, runtime=1.0)
assert pset[0].zi == [2]
zi, yi, xi = fieldset.U.unravel_index(pset[0].ei)
assert zi == [2]


def test_pset_from_field():
Expand Down
13 changes: 8 additions & 5 deletions tests/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,14 @@ def sampleTemp(particle, fieldset, time): # pragma: no cover
pset.execute(AdvectionRK4 + pset.Kernel(sampleTemp), runtime=3, dt=1)

# check if particle xi and yi are different for the two grids
# assert np.all([pset.xi[i, 0] != pset.xi[i, 1] for i in range(3)])
# assert np.all([pset.yi[i, 0] != pset.yi[i, 1] for i in range(3)])
assert np.all([pset[i].xi[0] != pset[i].xi[1] for i in range(3)])
assert np.all([pset[i].yi[0] != pset[i].yi[1] for i in range(3)])

# xi check from unraveled index
assert np.all(
[fieldset.U.unravel_index(pset[i].ei)[2][0] != fieldset.U.unravel_index(pset[i].ei)[2][1] for i in range(3)]
)
# yi check from unraveled index
assert np.all(
[fieldset.U.unravel_index(pset[i].ei)[1][0] != fieldset.U.unravel_index(pset[i].ei)[1][1] for i in range(3)]
)
# advect without updating temperature to test particle deletion
pset.remove_indices(np.array([1]))
pset.execute(AdvectionRK4, runtime=1, dt=1)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def Get_XiYi(particle, fieldset, time): # pragma: no cover
and that the first outputted value is zero.
Be careful when using multiple grids, as the index may be different for the grids.
"""
particle.pxi0 = particle.xi[0]
particle.pxi1 = particle.xi[1]
particle.pyi = particle.yi[0]
particle.pxi0 = fieldset.U.unravel_index(particle.ei[0])[2]
particle.pxi1 = fieldset.U.unravel_index(particle.ei[1])[2]
particle.pyi = fieldset.U.unravel_index(particle.ei[0])[1]

def SampleP(particle, fieldset, time): # pragma: no cover
if time > 5 * 3600:
Expand Down
Loading