Skip to content

Commit 598e267

Browse files
veni-vidi-vici-dormivipre-commit-ci[bot]
andauthoredDec 9, 2024··
Let fit_harmonic_model return residuals instead of predictions (#574)
* let fit_harmonic_model return residuals instead of predictions * update calibration files (drop residuals before saving) * CHANGELOG * adjust example --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 23a8ad3 commit 598e267

File tree

8 files changed

+73
-57
lines changed

8 files changed

+73
-57
lines changed
 

‎CHANGELOG.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ Harmonic model
178178
`#424 <https://github.com/MESMER-group/mesmer/pull/424>`_,
179179
`#433 <https://github.com/MESMER-group/mesmer/pull/433>`_,
180180
`#512 <https://github.com/MESMER-group/mesmer/pull/512>`_, and
181-
`#512 <https://github.com/MESMER-group/mesmer/pull/512>`_).
181+
`#574 <https://github.com/MESMER-group/mesmer/pull/574>`_).
182182
- add tests (
183183
`#431 <https://github.com/MESMER-group/mesmer/pull/431>`_, and
184184
`#458 <https://github.com/MESMER-group/mesmer/pull/458>`_)

‎examples/example_mesmer_m.ipynb

+37-28
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,20 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": null,
16-
"metadata": {},
17-
"outputs": [],
15+
"execution_count": 1,
16+
"metadata": {},
17+
"outputs": [
18+
{
19+
"data": {
20+
"text/plain": [
21+
"'0.10.0.post1.dev143+gfc10078.d20241008'"
22+
]
23+
},
24+
"execution_count": 1,
25+
"metadata": {},
26+
"output_type": "execute_result"
27+
}
28+
],
1829
"source": [
1930
"import importlib\n",
2031
"\n",
@@ -41,7 +52,7 @@
4152
},
4253
{
4354
"cell_type": "code",
44-
"execution_count": null,
55+
"execution_count": 2,
4556
"metadata": {},
4657
"outputs": [],
4758
"source": [
@@ -52,7 +63,7 @@
5263
},
5364
{
5465
"cell_type": "code",
55-
"execution_count": null,
66+
"execution_count": 3,
5667
"metadata": {},
5768
"outputs": [],
5869
"source": [
@@ -80,7 +91,7 @@
8091
},
8192
{
8293
"cell_type": "code",
83-
"execution_count": null,
94+
"execution_count": 4,
8495
"metadata": {},
8596
"outputs": [],
8697
"source": [
@@ -120,7 +131,7 @@
120131
},
121132
{
122133
"cell_type": "code",
123-
"execution_count": null,
134+
"execution_count": 5,
124135
"metadata": {},
125136
"outputs": [],
126137
"source": [
@@ -140,7 +151,7 @@
140151
},
141152
{
142153
"cell_type": "code",
143-
"execution_count": null,
154+
"execution_count": 6,
144155
"metadata": {},
145156
"outputs": [],
146157
"source": [
@@ -153,7 +164,7 @@
153164
},
154165
{
155166
"cell_type": "code",
156-
"execution_count": null,
167+
"execution_count": 7,
157168
"metadata": {},
158169
"outputs": [],
159170
"source": [
@@ -174,7 +185,7 @@
174185
},
175186
{
176187
"cell_type": "code",
177-
"execution_count": null,
188+
"execution_count": 8,
178189
"metadata": {},
179190
"outputs": [],
180191
"source": [
@@ -196,17 +207,15 @@
196207
},
197208
{
198209
"cell_type": "code",
199-
"execution_count": null,
210+
"execution_count": 10,
200211
"metadata": {},
201212
"outputs": [],
202213
"source": [
203-
"resids_after_hm = tas_stacked_m - harmonic_model_fit.predictions\n",
204-
"\n",
205214
"pt_coefficients = mesmer.stats.fit_yeo_johnson_transform(\n",
206-
" tas_stacked_y.tas, resids_after_hm.tas\n",
215+
" tas_stacked_y.tas, harmonic_model_fit.residuals\n",
207216
")\n",
208217
"transformed_hm_resids = mesmer.stats.yeo_johnson_transform(\n",
209-
" tas_stacked_y.tas, resids_after_hm.tas, pt_coefficients\n",
218+
" tas_stacked_y.tas, harmonic_model_fit.residuals, pt_coefficients\n",
210219
")"
211220
]
212221
},
@@ -221,7 +230,7 @@
221230
},
222231
{
223232
"cell_type": "code",
224-
"execution_count": null,
233+
"execution_count": 11,
225234
"metadata": {},
226235
"outputs": [],
227236
"source": [
@@ -241,7 +250,7 @@
241250
},
242251
{
243252
"cell_type": "code",
244-
"execution_count": null,
253+
"execution_count": 12,
245254
"metadata": {},
246255
"outputs": [],
247256
"source": [
@@ -282,7 +291,7 @@
282291
},
283292
{
284293
"cell_type": "code",
285-
"execution_count": null,
294+
"execution_count": 13,
286295
"metadata": {},
287296
"outputs": [],
288297
"source": [
@@ -309,7 +318,7 @@
309318
},
310319
{
311320
"cell_type": "code",
312-
"execution_count": null,
321+
"execution_count": 14,
313322
"metadata": {},
314323
"outputs": [],
315324
"source": [
@@ -329,7 +338,7 @@
329338
},
330339
{
331340
"cell_type": "code",
332-
"execution_count": null,
341+
"execution_count": 15,
333342
"metadata": {},
334343
"outputs": [],
335344
"source": [
@@ -355,7 +364,7 @@
355364
},
356365
{
357366
"cell_type": "code",
358-
"execution_count": null,
367+
"execution_count": 16,
359368
"metadata": {},
360369
"outputs": [],
361370
"source": [
@@ -372,7 +381,7 @@
372381
},
373382
{
374383
"cell_type": "code",
375-
"execution_count": null,
384+
"execution_count": 17,
376385
"metadata": {},
377386
"outputs": [],
378387
"source": [
@@ -383,7 +392,7 @@
383392
},
384393
{
385394
"cell_type": "code",
386-
"execution_count": null,
395+
"execution_count": 18,
387396
"metadata": {},
388397
"outputs": [],
389398
"source": [
@@ -401,7 +410,7 @@
401410
},
402411
{
403412
"cell_type": "code",
404-
"execution_count": null,
413+
"execution_count": 19,
405414
"metadata": {},
406415
"outputs": [],
407416
"source": [
@@ -413,7 +422,7 @@
413422
},
414423
{
415424
"cell_type": "code",
416-
"execution_count": null,
425+
"execution_count": 20,
417426
"metadata": {},
418427
"outputs": [],
419428
"source": [
@@ -430,7 +439,7 @@
430439
},
431440
{
432441
"cell_type": "code",
433-
"execution_count": null,
442+
"execution_count": 21,
434443
"metadata": {},
435444
"outputs": [],
436445
"source": [
@@ -460,7 +469,7 @@
460469
},
461470
{
462471
"cell_type": "code",
463-
"execution_count": null,
472+
"execution_count": 22,
464473
"metadata": {},
465474
"outputs": [],
466475
"source": [
@@ -477,7 +486,7 @@
477486
},
478487
{
479488
"cell_type": "code",
480-
"execution_count": null,
489+
"execution_count": 23,
481490
"metadata": {},
482491
"outputs": [],
483492
"source": [

‎mesmer/stats/_harmonic_model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ def fit_harmonic_model(yearly_predictor, monthly_target, max_order=6, time_dim="
265265
Returns
266266
-------
267267
data_vars : `xr.Dataset`
268-
Dataset containing the selected order of Fourier Series (selected_order),
269-
the estimated coefficients of the Fourier Series (coeffs) and the resulting
270-
predictions for monthly values (predictions).
268+
Dataset containing the selected order of Fourier Series (`selected_order`),
269+
the estimated coefficients of the Fourier Series (`coeffs`) and the resulting
270+
residuals of the model (`residuals`).
271271
272272
"""
273273

@@ -300,12 +300,12 @@ def fit_harmonic_model(yearly_predictor, monthly_target, max_order=6, time_dim="
300300

301301
coeffs = coeffs.assign_coords({"coeff": np.arange(coeffs.sizes["coeff"])})
302302

303-
preds = yearly_predictor + preds
303+
resids = monthly_target - (yearly_predictor + preds)
304304

305305
data_vars = {
306306
"selected_order": selected_order,
307307
"coeffs": coeffs,
308-
"predictions": preds.transpose(time_dim, ...),
308+
"residuals": resids.transpose(time_dim, ...),
309309
}
310310

311311
return xr.Dataset(data_vars)

‎tests/integration/test_calibrate_mesmer_m.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,12 @@ def mask_and_stack(ds, threshold_land):
8989
)
9090

9191
# train power transformer
92-
resids_after_hm = tas_stacked_m - harmonic_model_fit.predictions
9392
pt_coefficients = mesmer.stats.fit_yeo_johnson_transform(
9493
tas_stacked_y.tas,
95-
resids_after_hm.tas,
94+
harmonic_model_fit.residuals,
9695
)
9796
transformed_hm_resids = mesmer.stats.yeo_johnson_transform(
98-
tas_stacked_y.tas, resids_after_hm.tas, pt_coefficients
97+
tas_stacked_y.tas, harmonic_model_fit.residuals, pt_coefficients
9998
)
10099

101100
# fit cyclo-stationary AR(1) process
@@ -122,6 +121,11 @@ def mask_and_stack(ds, threshold_land):
122121

123122
# save params
124123
if update_expected_files:
124+
# drop unnecessary variables
125+
harmonic_model_fit = harmonic_model_fit.drop_vars(["residuals", "time"])
126+
AR1_fit = AR1_fit.drop_vars(["residuals", "time"])
127+
128+
# save
125129
harmonic_model_fit.to_netcdf(
126130
TEST_PATH
127131
/ "harmonic_model"

‎tests/unit/test_harmonic_model.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ def test_fit_harmonic_model():
158158
# test if the model can recover the monthly target from perfect fourier series
159159
result = mesmer.stats.fit_harmonic_model(yearly_predictor, monthly_target)
160160
np.testing.assert_equal(result.selected_order.values, orders)
161-
xr.testing.assert_allclose(result["predictions"], monthly_target)
161+
xr.testing.assert_allclose(
162+
result.residuals, xr.zeros_like(monthly_target), atol=1e-6
163+
)
162164

163165
# test if the model can recover the underlying cycle with noise on top of monthly target
164166
rng = np.random.default_rng(0)
@@ -167,35 +169,36 @@ def test_fit_harmonic_model():
167169
)
168170

169171
result = mesmer.stats.fit_harmonic_model(yearly_predictor, noisy_monthly_target)
170-
xr.testing.assert_allclose(result["predictions"], monthly_target, atol=0.1)
172+
predictions = mesmer.stats.predict_harmonic_model(
173+
yearly_predictor, result.coeffs, time=monthly_time
174+
)
175+
xr.testing.assert_allclose(predictions, monthly_target, atol=0.1)
171176

172177
# compare numerically one cell of one year
173178
expected = np.array(
174179
[
175-
7.324277,
176-
9.966644,
177-
9.972146,
178-
7.33931,
179-
2.7736,
180-
-2.501604,
181-
-7.072816,
182-
-9.715184,
183-
-9.720686,
184-
-7.087849,
185-
-2.52214,
186-
2.753065,
180+
0.014026,
181+
0.131156,
182+
-0.232648,
183+
0.040157,
184+
0.088749,
185+
-0.102724,
186+
-0.066836,
187+
0.133832,
188+
0.180308,
189+
-0.12783,
190+
-0.042045,
191+
0.160109,
187192
]
188193
)
189194

190-
result_comp = result.predictions.isel(cells=0, time=slice(0, 12)).values
195+
result_comp = result.residuals.isel(cells=0, time=slice(0, 12)).values
191196
np.testing.assert_allclose(result_comp, expected, atol=1e-6)
192197

193-
# ensure coeffs and predictions are consistent
194-
expected = mesmer.stats.predict_harmonic_model(
195-
yearly_predictor, result.coeffs, result.time
196-
)
198+
# ensure predictons and residuals are consistent
199+
expected = noisy_monthly_target - predictions
197200

198-
xr.testing.assert_equal(expected, result.predictions)
201+
xr.testing.assert_equal(expected, result.residuals)
199202

200203

201204
def test_fit_harmonic_model_checks():

0 commit comments

Comments
 (0)
Please sign in to comment.