dask performance #65

@florianboergel

Again, this still needs to be verified, but the dask documentation advises against calling delayed functions from within other delayed functions. That is why I removed the @dask.delayed decorator from calc_clim(). Instead, I now apply it only to calculate_thresh() and calculate_seas(), which I call independently.
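To illustrate the pattern I mean, here is a minimal sketch with toy data. The names `smooth` and `cell_threshold` are hypothetical stand-ins, not the real package functions: only the outer per-cell function is delayed, and the helper it calls is a plain Python function.

```python
import dask


def smooth(values, half_width):
    # Plain (non-delayed) helper: running mean, window truncated at the edges.
    out = []
    for i in range(len(values)):
        window = values[max(0, i - half_width): i + half_width + 1]
        out.append(sum(window) / len(window))
    return out


@dask.delayed
def cell_threshold(values, half_width):
    # The only delayed function: it calls the plain helper directly,
    # so no delayed object is created inside a running task.
    return max(smooth(values, half_width))


# Toy stand-in for the per-cell timeseries (hypothetical data).
cells = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

# One task per cell, all computed in a single dask.compute call.
tasks = [cell_threshold(c, 1) for c in cells]
results = dask.compute(*tasks)
```

If `smooth` were also decorated with `@dask.delayed`, calling it inside `cell_threshold` would return a lazy object rather than a list, which is the situation the documentation warns about.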

    # Loop over each cell to calculate climatologies, main functions
    # are delayed, so loop is automatically run in parallel

    # Previous version: one delayed calc_clim call per cell
    # for c in ts.cell:
    #     climls.append(
    #         calc_clim(
    #             ts.sel(cell=c),
    #             tdim,
    #             pctile,
    #             windowHalfWidth,
    #             smoothPercentile,
    #             smoothPercentileWidth,
    #             tstep,
    #             skipna,
    #         )
    #     )

    # New version: two independent delayed calls per cell
    thresClimYearDelayed = []
    seasClimYearDelayed = []

    for c in ts.cell:
        thresClimYearDelayed.append(
            calculate_thresh(
                ts.sel(cell=c), pctile, skipna, tstep, windowHalfWidth,
                tdim, smoothPercentile, smoothPercentileWidth,
            )
        )
        seasClimYearDelayed.append(
            calculate_seas(
                ts.sel(cell=c), skipna, tstep, windowHalfWidth,
                tdim, smoothPercentile, smoothPercentileWidth,
            )
        )

    # Each call returns a tuple with one result per cell
    thresClimYear = dask.compute(*thresClimYearDelayed)
    seasClimYear = dask.compute(*seasClimYearDelayed)

    results = [thresClimYear, seasClimYear]

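results is now a list of two tuples (thresholds and seasonal climatologies, each with one entry per cell), so the code further down that unpacks results needs to be adjusted for this different structure.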
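A small sketch of what that structure looks like, using toy delayed functions (`fake_thresh` and `fake_seas` are hypothetical placeholders for the real calculations). It also shows one possible way to get back to per-cell pairs, and a single `dask.compute` call as an alternative I have not tested in the package itself.

```python
import dask


@dask.delayed
def fake_thresh(cell):
    # Hypothetical placeholder for the threshold calculation.
    return cell * 10


@dask.delayed
def fake_seas(cell):
    # Hypothetical placeholder for the seasonal climatology calculation.
    return cell + 0.5


cells = [1, 2, 3]
thresh_tasks = [fake_thresh(c) for c in cells]
seas_tasks = [fake_seas(c) for c in cells]

# Same layout as in the snippet above: two tuples, one entry per cell.
thres = dask.compute(*thresh_tasks)
seas = dask.compute(*seas_tasks)
results = [thres, seas]

# Regroup into one (thresh, seas) pair per cell if downstream code expects that.
per_cell = list(zip(*results))

# Alternative: pass both lists to one compute call; each list comes back
# as a list of computed values.
thres_once, seas_once = dask.compute(thresh_tasks, seas_tasks)
```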

To make sure I only call one delayed function, I also removed the @dask.delayed decorator from runavg. calculate_thresh now looks like this (calculate_seas is changed in the same way):

@dask.delayed(nout=1)
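def calculate_thresh(ts, pctile, skipna, tstep, windowHalfWidth, tdim, smoothPercentile, smoothPercentileWidth):
    """Calculate threshold for one grid cell at a time

    Parameters
    ----------
    ts: xarray DataArray
        Timeseries for a single grid cell
    pctile: int
        Threshold percentile used to detect events
    skipna: bool
        If True percentile and mean function will use skipna=True.
        Using skipna option is much slower
    tstep: bool
        If False, the value for 29 Feb is calculated separately
    windowHalfWidth: int
        Half width w of the window; window_roll stacks ts along a new
        'z' dimension representing a window of width 2*w+1
    tdim: str
        Name of the time dimension
    smoothPercentile: bool
        If True the threshold is smoothed with runavg
    smoothPercentileWidth: int
        Width passed to runavg

    Returns
    -------
    thresh_climYear: xarray DataArray
        Climatological threshold
    """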
    twindow = window_roll(ts, windowHalfWidth, tdim)

    thresh_climYear = twindow.groupby("doy").quantile(
        pctile / 100.0, dim="z", skipna=skipna
    )
    # calculate value for 29 Feb from mean of 28-29 feb and 1 Mar
    if tstep is False:
        thresh_climYear = thresh_climYear.where(
            thresh_climYear.doy != 60, feb29(thresh_climYear)
        )

    if smoothPercentile:
        thresh_climYear = runavg(thresh_climYear, smoothPercentileWidth)
        
    thresh_climYear = thresh_climYear.chunk({"doy": -1})
    return thresh_climYear

If you think this makes sense, I can also open a pull request so the changes can be verified.
