diff --git a/pygem/plot/graphics.py b/pygem/plot/graphics.py index 7194dc6a..b70d4a31 100644 --- a/pygem/plot/graphics.py +++ b/pygem/plot/graphics.py @@ -274,12 +274,12 @@ def plot_mcmc_chain( continue # stack predictions first (shape: n_steps x ... x ...) - may end up being 2d or 3d - pred_primes = torch.stack(pred_primes[key]).numpy() - pred_chain = torch.stack(pred_chain[key]).numpy() + pred_primes_key = torch.stack(pred_primes[key]).numpy() + pred_chain_key = torch.stack(pred_chain[key]).numpy() # flatten all axes except the first (n_steps) -> 2D array (n_steps, M) - pred_primes_flat = pred_primes.reshape(pred_primes.shape[0], -1) - pred_chain_flat = pred_chain.reshape(pred_chain.shape[0], -1) + pred_primes_flat = pred_primes_key.reshape(pred_primes_key.shape[0], -1) + pred_chain_flat = pred_chain_key.reshape(pred_chain_key.shape[0], -1) # make obs array broadcastable (flatten if needed) obs_vals_flat = np.ravel(np.array(obs[key][0])) @@ -291,6 +291,9 @@ def plot_mcmc_chain( axes[nparams].plot(mean_resid_primes, '.', ms=ms, c='tab:blue') axes[nparams].plot(mean_resid_chain, '.', ms=ms, c='tab:orange') + axes[nparams].text( + 0.02, 0.02, key, transform=axes[nparams].transAxes, fontsize=fontsize, va='bottom', ha='left' + ) if key == 'elev_change_1d': axes[nparams].set_ylabel(r'$\overline{\hat{dh} - dh}$', fontsize=fontsize) else: