1+ """
2+ Python Glacier Evolution Model (PyGEM)
3+
copyright © 2018 David Rounce <[email protected]>

Distributed under the MIT license
7+
8+ Model statistics module
9+ """
10+
11+ import numpy as np
12+ import arviz as az
13+
def effective_n(x):
    """
    Compute the effective sample size of a trace.

    Takes the trace and computes the effective sample size
    according to its detrended autocorrelation.

    Parameters
    ----------
    x : list or array of chain samples

    Returns
    -------
    effective_n : int or None
        Effective sample size. Returns 1 for a constant trace and None
        if the computation fails (e.g. empty trace, or an autocorrelation
        sum of exactly -0.5 which would make the estimate infinite).
    """
    # A constant trace carries no information beyond a single sample.
    if len(set(x)) == 1:
        return 1
    try:
        # detrend trace using mean to be consistent with statistics
        # definition of autocorrelation
        x = np.asarray(x)
        x = x - x.mean()
        # compute autocorrelation (note: only need second half since
        # they are symmetric)
        rho = np.correlate(x, x, mode='full')
        rho = rho[len(rho) // 2:]
        # normalize the autocorrelation values
        # note: rho[0] is the variance * n_samples, so this is consistent
        # with the statistics definition of autocorrelation on wikipedia
        # (dividing by n_samples gives you the expected value).
        rho_norm = rho / rho[0]
        # Iterate until sum of consecutive estimates of autocorrelation is
        # negative to avoid issues with the sum being -0.5, which returns an
        # effective_n of infinity
        negative_autocorr = False
        t = 1
        n = len(x)
        while not negative_autocorr and (t < n):
            # only evaluate the pairwise sum at even lags
            if not t % 2:
                negative_autocorr = sum(rho_norm[t - 1:t + 1]) < 0
            t += 1
        return int(n / (1 + 2 * rho_norm[1:t].sum()))
    except Exception:
        # was a bare `except:` — narrowed so SystemExit/KeyboardInterrupt
        # propagate; numeric failures (empty trace -> ValueError from
        # np.correlate, inf -> OverflowError in int()) still yield None
        return None
59+
60+
def mcmc_stats(chains_dict,
               params=('tbias', 'kp', 'ddfsnow', 'ddfice', 'rhoabl', 'rhoacc', 'mb_mwea')):
    """
    Compute per-chain and overall summary stats for MCMC samples.

    The input dictionary is modified in place: a '_summary_stats_' entry
    is added holding the computed statistics, and the same (mutated)
    dictionary is returned.

    Parameters
    ----------
    chains_dict : dict
        Dictionary with structure:
        {
            "param1": {
                "chain1": [...],
                "chain2": [...],
                ...
            },
            ...
        }
    params : sequence of str, optional
        Parameter names to summarize; entries of chains_dict whose key is
        not listed here are skipped.

    Returns
    -------
    chains_dict : dict
        The input dictionary with an added '_summary_stats_' entry of the
        form:
        {
            "param1": {
                "mean": [...],    # per chain
                "std": [...],     # per chain (sample std, ddof=1)
                "median": [...],  # per chain
                "q25": [...],     # per chain, 25th percentile
                "q75": [...],     # per chain, 75th percentile
                "ess": [...],     # per chain effective sample size
                "r_hat": ...      # overall (None if only one chain)
            },
            ...
        }
    """
    summary_stats = {}

    for param, chains in chains_dict.items():
        if param not in params:
            continue

        # Stack chains into array: shape (n_chains, n_samples)
        chain_names = sorted(chains)  # ensure consistent order
        samples = np.array([chains[c] for c in chain_names])

        # Per-chain stats
        means = np.mean(samples, axis=1).tolist()
        stds = np.std(samples, axis=1, ddof=1).tolist()
        medians = np.median(samples, axis=1).tolist()
        q25 = np.quantile(samples, 0.25, axis=1).tolist()
        q75 = np.quantile(samples, 0.75, axis=1).tolist()
        ess = [effective_n(x) for x in samples]

        # Overall stats (R-hat) -- requires at least two chains
        if samples.shape[0] > 1:
            # calculate the Gelman-Rubin stat for the variable across all
            # chains: wrap the (n_chains, n_samples) array in an arviz
            # InferenceData object via from_dict(), then let az.rhat
            # compute the statistic
            idata = az.from_dict(posterior={param: samples})
            r_hat = float(az.rhat(idata).to_array().values[0])
        else:
            r_hat = None

        summary_stats[param] = {
            "mean": means,
            "std": stds,
            "median": medians,
            "q25": q25,
            "q75": q75,
            "ess": ess,
            "r_hat": r_hat,
        }

    # NOTE: mutates the caller's dictionary as well as returning it
    chains_dict['_summary_stats_'] = summary_stats

    return chains_dict
0 commit comments