Skip to content

Commit 827047b

Browse files
committed
Bug fix in exporting extra variables, added MCMC stats
1 parent 29a05e4 commit 827047b

4 files changed

Lines changed: 147 additions & 4 deletions

File tree

pygem/bin/run/run_calibration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
single_flowline_glacier_directory,
4747
single_flowline_glacier_directory_with_calving,
4848
)
49+
from pygem.utils.stats import mcmc_stats
4950

5051
# from oggm.core import climate
5152
# from oggm.core.flowline import FluxBasedModel
@@ -2135,6 +2136,9 @@ def must_melt(kp, tbias, ddfsnow, **kwargs):
21352136
modelprms_export['mb_obs_mwea_err'] = [float(mb_obs_mwea_err)]
21362137
modelprms_export['priors'] = priors
21372138

2139+
# compute stats on mcmc parameters
2140+
modelprms_export = mcmc_stats(modelprms_export)
2141+
21382142
modelprms_fn = glacier_str + '-modelprms_dict.json'
21392143
modelprms_fp = [
21402144
(

pygem/bin/run/run_simulation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def run(list_packed_vars):
606606
rgiid = main_glac_rgi.loc[main_glac_rgi.index.values[glac], 'RGIId']
607607

608608
try:
609-
# for batman in [0]:
609+
# for batman in [0]:
610610

611611
# ===== Load glacier data: area (km2), ice thickness (m), width (km) =====
612612
if (
@@ -1779,6 +1779,7 @@ def run(list_packed_vars):
17791779
sim_endyear=args.sim_endyear,
17801780
option_calibration=args.option_calibration,
17811781
option_bias_adjustment=args.option_bias_adjustment,
1782+
extra_vars=args.export_extra_vars,
17821783
)
17831784
for n_iter in range(nsims):
17841785
# pass model params for iteration and update output dataset model params
@@ -1878,6 +1879,7 @@ def run(list_packed_vars):
18781879
sim_endyear=args.sim_endyear,
18791880
option_calibration=args.option_calibration,
18801881
option_bias_adjustment=args.option_bias_adjustment,
1882+
extra_vars=args.export_extra_vars,
18811883
)
18821884
# create and return xarray dataset
18831885
output_stats.create_xr_ds()

pygem/output.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class single_glacier:
9090
sim_endyear: int
9191
option_calibration: str
9292
option_bias_adjustment: str
93+
extra_vars: bool = False
9394

9495
def __post_init__(self):
9596
"""
@@ -492,7 +493,7 @@ def _update_dicts(self):
492493
}
493494

494495
# optionally store extra variables
495-
if pygem_prms['sim']['out']['export_extra_vars']:
496+
if self.extra_vars:
496497
self.output_coords_dict['glac_prec_monthly'] = collections.OrderedDict(
497498
[('glac', self.glac_values), ('time', self.time_values)]
498499
)
@@ -791,8 +792,8 @@ class binned_stats(single_glacier):
791792
Flag indicating whether additional binned components are included in the dataset.
792793
"""
793794

794-
nbins: int
795-
binned_components: bool
795+
nbins: int = 0
796+
binned_components: bool = False
796797

797798
def __post_init__(self):
798799
"""

pygem/utils/stats.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
Python Glacier Evolution Model (PyGEM)
3+
4+
copyright © 2018 David Rounce <[email protected]>
5+
6+
Distributed under the MIT license
7+
8+
Model statistics module
9+
"""
10+
11+
import numpy as np
12+
import arviz as az
13+
14+
def effective_n(x):
    """
    Compute the effective sample size of a single MCMC trace.

    The trace is detrended by its mean and the effective sample size is
    estimated from the sum of its normalized autocorrelations, stopping
    once the sum of a consecutive pair of autocorrelation estimates
    turns negative (initial-sequence style cutoff).

    Parameters
    ----------
    x : list or array of chain samples

    Returns
    -------
    int or None
        Effective sample size; 1 for a constant trace, None if the
        estimate could not be computed.
    """
    # A constant chain carries a single independent piece of information.
    if len(set(x)) == 1:
        return 1
    try:
        # Detrend the trace using the mean to be consistent with the
        # statistics definition of autocorrelation.
        x = np.asarray(x)
        x = x - x.mean()
        # Full autocorrelation is symmetric, so only the second half
        # (non-negative lags) is needed.
        rho = np.correlate(x, x, mode='full')
        rho = rho[len(rho) // 2:]
        # Normalize: rho[0] is the variance * n_samples, so this is
        # consistent with the statistics definition of autocorrelation
        # (dividing by n_samples gives the expected value).
        rho_norm = rho / rho[0]
        # Iterate until the sum of consecutive estimates of autocorrelation
        # is negative, to avoid issues with the total sum approaching -0.5,
        # which would return an effective_n of infinity.
        negative_autocorr = False
        t = 1
        n = len(x)
        while not negative_autocorr and (t < n):
            if not t % 2:
                negative_autocorr = sum(rho_norm[t - 1:t + 1]) < 0
            t += 1
        return int(n / (1 + 2 * rho_norm[1:t].sum()))
    # Catch numerical failures only (a bare except would also swallow
    # KeyboardInterrupt/SystemExit); signal "no estimate" with None.
    except Exception:
        return None


def mcmc_stats(chains_dict,
               params=('tbias', 'kp', 'ddfsnow', 'ddfice',
                       'rhoabl', 'rhoacc', 'mb_mwea')):
    """
    Compute per-chain and overall summary stats for MCMC samples.

    Parameters
    ----------
    chains_dict : dict
        Dictionary with structure:
        {
            "param1": {
                "chain1": [...],
                "chain2": [...],
                ...
            },
            ...
        }
    params : sequence of str, optional
        Parameter names to summarize; entries of ``chains_dict`` whose
        key is not in this sequence are skipped.

    Returns
    -------
    chains_dict : dict
        The input dictionary, mutated in place with an added
        ``'_summary_stats_'`` key of structure:
        {
            "param1": {
                "mean": [...],    # per chain
                "std": [...],     # per chain (sample std, ddof=1)
                "median": [...],  # per chain
                "q25": [...],     # per chain, 25th percentile
                "q75": [...],     # per chain, 75th percentile
                "ess": [...],     # per chain effective sample size
                "r_hat": ...      # overall; None with a single chain
            },
            ...
        }
    """
    summary_stats = {}

    for param, chains in chains_dict.items():
        if param not in params:
            continue

        # Stack chains into array: shape (n_chains, n_samples).
        # Sort the chain names so per-chain lists have a consistent order.
        chain_names = sorted(chains)
        samples = np.array([chains[c] for c in chain_names])

        # Per-chain stats.
        means = np.mean(samples, axis=1).tolist()
        stds = np.std(samples, axis=1, ddof=1).tolist()
        medians = np.median(samples, axis=1).tolist()
        q25 = np.quantile(samples, 0.25, axis=1).tolist()
        q75 = np.quantile(samples, 0.75, axis=1).tolist()
        ess = [effective_n(x) for x in samples]

        # Overall stats: the Gelman-Rubin statistic (r_hat) needs at
        # least two chains to compare between-chain variance.
        if samples.shape[0] > 1:
            # Pass the stacked chains to arviz via from_dict(), which
            # converts them into an InferenceData object, then compute
            # the Gelman-Rubin statistic (rhat).
            idata = az.from_dict(posterior={param: samples})
            r_hat = float(az.rhat(idata).to_array().values[0])
        else:
            r_hat = None

        summary_stats[param] = {
            "mean": means,
            "std": stds,
            "median": medians,
            "q25": q25,
            "q75": q75,
            "ess": ess,
            "r_hat": r_hat
        }

    chains_dict['_summary_stats_'] = summary_stats

    return chains_dict

0 commit comments

Comments
 (0)