From a2ce2d48be948eb4976b7e7633759a914012d7a1 Mon Sep 17 00:00:00 2001 From: Emmaaycoberry Date: Tue, 10 Feb 2026 17:15:48 +0100 Subject: [PATCH 1/4] Add a python script to plot the results --- shine/plot_shine_results.py | 506 ++++++++++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) create mode 100755 shine/plot_shine_results.py diff --git a/shine/plot_shine_results.py b/shine/plot_shine_results.py new file mode 100755 index 0000000..85906b1 --- /dev/null +++ b/shine/plot_shine_results.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python3 +""" +SHINE Results Visualization Script + +This script loads and visualizes results from a SHINE inference run. +It generates comprehensive diagnostic plots including: +- Observation visualization (image, PSF, noise map) +- Posterior distributions for all parameters +- Corner plot with confidence intervals +- Shear parameter analysis (if applicable) +- Summary statistics +- Trace plots (if MCMC chains present) +- Parameter correlation matrix + +Usage: + python plot_shine_results.py --output my_output/ +""" + +import argparse +import sys +from pathlib import Path + +import arviz as az +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr + + +def setup_plot_style(): + """Configure matplotlib plot style.""" + plt.rcParams['figure.figsize'] = (12, 8) + plt.rcParams['font.size'] = 10 + plt.style.use('seaborn-v0_8-darkgrid') + + +def load_data(output_dir): + """Load observation and posterior data from the output directory. + + Args: + output_dir: Path to directory containing observation.npz and posterior.nc + + Returns: + tuple: (obs_data, idata, posterior) + """ + output_path = Path(output_dir) + + # Load observation data + obs_file = output_path / 'observation.npz' + if not obs_file.exists(): + raise FileNotFoundError(f"observation.npz not found in {output_dir}") + obs_data = np.load(obs_file) + print(f"Observation data loaded from {obs_file}") + print(f"Available keys: {list(obs_data.keys())}") + + # Load posterior estimates + posterior_file = output_path / 'posterior.nc' + if not posterior_file.exists(): + raise FileNotFoundError(f"posterior.nc not found in {output_dir}") + idata = az.from_netcdf(posterior_file) + posterior = idata.posterior + print(f"\nPosterior data loaded from {posterior_file}") + print(f"Dataset structure:") + print(posterior) + + return obs_data, idata, posterior + + +def plot_observation(obs_data, output_dir): + """Visualize the observed galaxy image, PSF, and noise map. + + Args: + obs_data: Loaded observation data + output_dir: Directory to save the plot + """ + print("\n" + "="*70) + print("Plotting Observation") + print("="*70) + + image = obs_data.get('image', None) + psf = obs_data.get('psf', None) + noise_map = obs_data.get('noise_map', None) + + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + + # Plot the galaxy image + if image is not None: + im1 = axes[0].imshow(image, origin='lower', cmap='viridis') + axes[0].set_title('Observed Galaxy Image', fontsize=14, fontweight='bold') + axes[0].set_xlabel('X pixel') + axes[0].set_ylabel('Y pixel') + plt.colorbar(im1, ax=axes[0], label='Flux') + axes[0].text(0.02, 0.98, f'Max: {image.max():.2e}\nMin: {image.min():.2e}', + transform=axes[0].transAxes, verticalalignment='top', + bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + # Plot the PSF + if psf is not None: + im2 = axes[1].imshow(psf, origin='lower', cmap='hot') + axes[1].set_title('PSF Model', fontsize=14, fontweight='bold') + axes[1].set_xlabel('X pixel') + axes[1].set_ylabel('Y pixel') + plt.colorbar(im2, ax=axes[1], label='Normalized Flux') + + # Plot the noise map + if noise_map is not None: + if noise_map.ndim == 0: # Scalar noise + axes[2].text(0.5, 0.5, f'Uniform Noise\nσ = {float(noise_map):.2e}', + ha='center', va='center', fontsize=16, + transform=axes[2].transAxes) + axes[2].set_title('Noise Map', fontsize=14, fontweight='bold') + axes[2].axis('off') + else: # Spatial noise map + im3 = axes[2].imshow(noise_map, origin='lower', cmap='plasma') + axes[2].set_title('Noise Map (σ)', fontsize=14, fontweight='bold') + axes[2].set_xlabel('X pixel') + axes[2].set_ylabel('Y pixel') + plt.colorbar(im3, ax=axes[2], label='Noise σ') + + output_file = Path(output_dir) / 'observation_visual.png' + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + plt.close() + print(f"✓ Observation visualization saved to {output_file}") + + +def plot_posterior_distributions(posterior, param_names, output_dir): + """Plot posterior distributions for all parameters. + + Args: + posterior: Posterior dataset + param_names: List of parameter names + output_dir: Directory to save the plot + """ + print("\n" + "="*70) + print("Plotting Posterior Distributions") + print("="*70) + + n_params = len(param_names) + n_cols = min(3, n_params) + n_rows = int(np.ceil(n_params / n_cols)) + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows)) + if n_params == 1: + axes = np.array([axes]) + axes = axes.flatten() + + for idx, param in enumerate(param_names): + ax = axes[idx] + samples = posterior[param].values + + if samples.ndim > 1: + samples = samples.flatten() + + ax.hist(samples, bins=50, density=True, alpha=0.7, + color='steelblue', edgecolor='black') + + mean_val = np.mean(samples) + median_val = np.median(samples) + std_val = np.std(samples) + + ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, + label=f'Mean: {mean_val:.4f}') + ax.axvline(median_val, color='green', linestyle=':', linewidth=2, + label=f'Median: {median_val:.4f}') + + ax.set_xlabel(f'{param}', fontsize=12) + ax.set_ylabel('Density', fontsize=12) + ax.set_title(f'{param} Posterior\nσ = {std_val:.4f}', fontsize=12, fontweight='bold') + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + + # Hide empty subplots + for idx in range(n_params, len(axes)): + axes[idx].axis('off') + + output_file = Path(output_dir) / 'posterior_distributions.png' + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + plt.close() + print(f"✓ Posterior distributions saved to {output_file}") + + +def plot_corner(posterior, param_names, output_dir): + """Create corner plot with confidence intervals. + + Args: + posterior: Posterior dataset + param_names: List of parameter names + output_dir: Directory to save the plot + """ + if len(param_names) <= 1: + print("\nCorner plot requires at least 2 parameters. Skipping.") + return + + print("\n" + "="*70) + print("Plotting Corner Plot") + print("="*70) + + try: + import corner + except ImportError: + print("Installing corner package...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "corner"]) + import corner + + # Prepare data: stack all parameters as columns + samples_array = np.column_stack([posterior[param].values.flatten() for param in param_names]) + + # Create corner plot with confidence intervals + fig = corner.corner( + samples_array, + labels=param_names, + quantiles=[0.16, 0.5, 0.84], # 16th, 50th, 84th percentiles (±1σ) + levels=(0.68, 0.95), # 68% and 95% confidence intervals + show_titles=True, + title_fmt='.4f', + smooth=1.0, + plot_datapoints=True, + plot_density=True, + fill_contours=True, + color='steelblue', + truth_color='red', + title_kwargs={"fontsize": 11}, + ) + + plt.suptitle('Corner Plot: Joint & Marginal Distributions', + fontsize=14, fontweight='bold', y=0.995) + + output_file = Path(output_dir) / 'corner_plot.png' + plt.savefig(output_file, dpi=150, bbox_inches='tight') + plt.close() + print(f"✓ Corner plot saved to {output_file}") + + +def plot_shear_analysis(posterior, param_names, output_dir): + """Create corner plot for shear parameters (g1, g2) only. + + Args: + posterior: Posterior dataset + param_names: List of parameter names + output_dir: Directory to save the plot + """ + shear_params = [p for p in param_names if 'g1' in p.lower() or 'g2' in p.lower() or 'shear' in p.lower()] + + if not shear_params: + print("\nNo shear parameters found. Skipping shear analysis.") + return + + print("\n" + "="*70) + print("Plotting Shear Analysis") + print("="*70) + print(f"Found shear parameters: {shear_params}") + + g1_param = next((p for p in param_names if 'g1' in p.lower()), None) + g2_param = next((p for p in param_names if 'g2' in p.lower()), None) + + if not (g1_param and g2_param): + print("Could not identify both g1 and g2 parameters. Skipping.") + return + + # Import corner package + try: + import corner + except ImportError: + print("Installing corner package...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "corner"]) + import corner + + g1_samples = posterior[g1_param].values.flatten() + g2_samples = posterior[g2_param].values.flatten() + + # Prepare data: stack g1 and g2 as columns + samples_array = np.column_stack([g1_samples, g2_samples]) + + # Calculate statistics + g1_mean = np.mean(g1_samples) + g1_std = np.std(g1_samples) + g2_mean = np.mean(g2_samples) + g2_std = np.std(g2_samples) + + # Create corner plot for shear parameters only + fig = corner.corner( + samples_array, + labels=['g1', 'g2'], + quantiles=[0.16, 0.5, 0.84], # 16th, 50th, 84th percentiles (±1σ) + levels=(0.68, 0.95), # 68% and 95% confidence intervals + show_titles=True, + title_fmt='.4f', + smooth=1.0, + plot_datapoints=True, + plot_density=True, + fill_contours=True, + color='steelblue', + truth_color='red', + title_kwargs={"fontsize": 11}, + ) + + plt.suptitle('Shear Parameters Corner Plot (g1, g2)', + fontsize=14, fontweight='bold', y=0.995) + + output_file = Path(output_dir) / 'shear_analysis.png' + plt.savefig(output_file, dpi=150, bbox_inches='tight') + plt.close() + + print(f"\nShear estimates:") + print(f" g1 = {g1_mean:.6f} ± {g1_std:.6f}") + print(f" g2 = {g2_mean:.6f} ± {g2_std:.6f}") + print(f"✓ Shear analysis saved to {output_file}") + + +def print_summary_statistics(posterior, param_names): + """Print summary statistics for all parameters. + + Args: + posterior: Posterior dataset + param_names: List of parameter names + """ + print("\n" + "="*70) + print("POSTERIOR SUMMARY STATISTICS") + print("="*70) + print(f"{'Parameter':<20} {'Mean':<12} {'Std':<12} {'Median':<12} {'95% CI':<20}") + print("-"*70) + + for param in param_names: + samples = posterior[param].values.flatten() + mean_val = np.mean(samples) + std_val = np.std(samples) + median_val = np.median(samples) + ci_low = np.percentile(samples, 2.5) + ci_high = np.percentile(samples, 97.5) + + print(f"{param:<20} {mean_val:<12.6f} {std_val:<12.6f} {median_val:<12.6f} [{ci_low:.6f}, {ci_high:.6f}]") + + print("="*70) + + +def plot_trace(posterior, param_names, output_dir): + """Plot trace plots for MCMC convergence diagnostics. + + Args: + posterior: Posterior dataset + param_names: List of parameter names + output_dir: Directory to save the plot + """ + has_chains = any(dim in posterior.dims for dim in ['chain', 'draw', 'sample']) + + if not has_chains: + print("\nNo chain/draw dimensions found - likely MAP or point estimate.") + print("Skipping trace plots.") + return + + print("\n" + "="*70) + print("Plotting Trace Plots") + print("="*70) + + n_params = len(param_names) + fig, axes = plt.subplots(n_params, 1, figsize=(12, 3*n_params)) + + if n_params == 1: + axes = [axes] + + for idx, param in enumerate(param_names): + samples = posterior[param].values + + # Trace plot + if samples.ndim >= 2: + for chain in range(samples.shape[0]): + axes[idx].plot(samples[chain], alpha=0.7, label=f'Chain {chain}') + else: + axes[idx].plot(samples, alpha=0.7) + + axes[idx].set_ylabel(param, fontsize=11) + axes[idx].set_xlabel('Iteration', fontsize=11) + axes[idx].set_title(f'{param} - Trace', fontsize=12, fontweight='bold') + axes[idx].grid(True, alpha=0.3) + if samples.ndim >= 2 and samples.shape[0] <= 10: + axes[idx].legend(fontsize=8) + + output_file = Path(output_dir) / 'trace_plots.png' + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + plt.close() + print(f"✓ Trace plots saved to {output_file}") + + +def plot_correlation_matrix(posterior, param_names, output_dir): + """Plot parameter correlation matrix. + + Args: + posterior: Posterior dataset + param_names: List of parameter names + output_dir: Directory to save the plot + """ + if len(param_names) <= 1: + print("\nOnly one parameter - skipping correlation matrix.") + return + + print("\n" + "="*70) + print("Plotting Correlation Matrix") + print("="*70) + + # Create correlation matrix + data_matrix = np.column_stack([posterior[param].values.flatten() for param in param_names]) + corr_matrix = np.corrcoef(data_matrix.T) + + # Plot correlation matrix + fig, ax = plt.subplots(figsize=(10, 8)) + im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto') + + # Set ticks + ax.set_xticks(range(len(param_names))) + ax.set_yticks(range(len(param_names))) + ax.set_xticklabels(param_names, rotation=45, ha='right') + ax.set_yticklabels(param_names) + + # Add colorbar + cbar = plt.colorbar(im, ax=ax) + cbar.set_label('Correlation', fontsize=12) + + # Add correlation values + for i in range(len(param_names)): + for j in range(len(param_names)): + ax.text(j, i, f'{corr_matrix[i, j]:.2f}', + ha="center", va="center", color="black", fontsize=10) + + ax.set_title('Parameter Correlation Matrix', fontsize=14, fontweight='bold', pad=20) + + output_file = Path(output_dir) / 'correlation_matrix.png' + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + plt.close() + print(f"✓ Correlation matrix saved to {output_file}") + + +def main(): + """Main function to run the visualization pipeline.""" + parser = argparse.ArgumentParser( + description='Visualize SHINE inference results', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__ + ) + parser.add_argument( + '--output', + type=str, + required=True, + help='Directory containing observation.npz and posterior.nc files (and where plots will be saved)' + ) + + args = parser.parse_args() + + # Verify output directory exists + output_dir = Path(args.output) + if not output_dir.exists(): + print(f"Error: Directory {output_dir} does not exist") + sys.exit(1) + + print("="*70) + print("SHINE RESULTS VISUALIZATION") + print("="*70) + print(f"Output directory: {output_dir.absolute()}") + + # Setup plotting style + setup_plot_style() + + # Load data + try: + obs_data, idata, posterior = load_data(output_dir) + except FileNotFoundError as e: + print(f"\nError: {e}") + sys.exit(1) + + # Get parameter names + param_names = list(posterior.data_vars) + print(f"\nInferred parameters: {param_names}") + + # Generate all plots + plot_observation(obs_data, output_dir) + plot_posterior_distributions(posterior, param_names, output_dir) + plot_corner(posterior, param_names, output_dir) + plot_shear_analysis(posterior, param_names, output_dir) + print_summary_statistics(posterior, param_names) + plot_trace(posterior, param_names, output_dir) + plot_correlation_matrix(posterior, param_names, output_dir) + + # Final summary + print("\n" + "="*70) + print("VISUALIZATION COMPLETE") + print("="*70) + print(f"\nAll plots saved to: {output_dir.absolute()}") + print("\nGenerated plots:") + print(" • observation_visual.png - Observed image, PSF, and noise map") + print(" • posterior_distributions.png - Posterior distributions for all parameters") + if len(param_names) > 1: + print(" • corner_plot.png - Corner plot with confidence intervals") + print(" • correlation_matrix.png - Parameter correlation matrix") + if any('g1' in p.lower() or 'g2' in p.lower() for p in param_names): + print(" • shear_analysis.png - Detailed shear parameter analysis") + if any(dim in posterior.dims for dim in ['chain', 'draw', 'sample']): + print(" • trace_plots.png - MCMC trace plots") + print("\n✓ All visualizations completed successfully!") + + +if __name__ == '__main__': + main() From dbdeaff04d0c2ef0753e7a84ed4a41b5fa666515 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:34:52 +0000 Subject: [PATCH 2/4] Fix code review issues in plot_shine_results.py - Add full PEP 484 type hints to all functions - Convert all docstrings to Google-style format with detailed Args, Returns, and Raises sections - Remove automatic package installation security issue: replaced with ImportError messages - User must now manually install 'corner' package if needed Addresses review feedback on PR #15. Co-authored-by: Francois Lanusse --- shine/plot_shine_results.py | 162 ++++++++++++++++++++++++------------ 1 file changed, 110 insertions(+), 52 deletions(-) diff --git a/shine/plot_shine_results.py b/shine/plot_shine_results.py index 85906b1..4e0a655 100755 --- a/shine/plot_shine_results.py +++ b/shine/plot_shine_results.py @@ -19,6 +19,7 @@ import argparse import sys from pathlib import Path +from typing import Any, Dict, List, Tuple import arviz as az import matplotlib.pyplot as plt @@ -26,21 +27,33 @@ import xarray as xr -def setup_plot_style(): - """Configure matplotlib plot style.""" +def setup_plot_style() -> None: + """Configure matplotlib plot style for all visualizations. + + Sets default figure size, font size, and applies seaborn darkgrid style + to all matplotlib plots generated by this script. + """ plt.rcParams['figure.figsize'] = (12, 8) plt.rcParams['font.size'] = 10 plt.style.use('seaborn-v0_8-darkgrid') -def load_data(output_dir): +def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Dataset]: """Load observation and posterior data from the output directory. - + Args: - output_dir: Path to directory containing observation.npz and posterior.nc - + output_dir (Path): Path to directory containing observation.npz and posterior.nc files. + Returns: - tuple: (obs_data, idata, posterior) + tuple: A tuple containing: + - obs_data (Dict[str, Any]): Loaded observation data from observation.npz containing + keys like 'image', 'psf', and 'noise_map'. + - idata (az.InferenceData): ArviZ InferenceData object loaded from posterior.nc. + - posterior (xr.Dataset): Posterior dataset extracted from idata containing parameter + samples from inference. + + Raises: + FileNotFoundError: If observation.npz or posterior.nc are not found in output_dir. """ output_path = Path(output_dir) @@ -65,12 +78,16 @@ def load_data(output_dir): return obs_data, idata, posterior -def plot_observation(obs_data, output_dir): +def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None: """Visualize the observed galaxy image, PSF, and noise map. - + + Creates a 3-panel figure showing the observed galaxy image, PSF model, and noise map. + Handles both scalar and spatial noise maps. Saves the visualization as 'observation_visual.png'. + Args: - obs_data: Loaded observation data - output_dir: Directory to save the plot + obs_data (Dict[str, Any]): Loaded observation data containing 'image', 'psf', and + 'noise_map' keys. + output_dir (Path): Directory where the plot will be saved. """ print("\n" + "="*70) print("Plotting Observation") @@ -123,13 +140,17 @@ def plot_observation(obs_data, output_dir): print(f"✓ Observation visualization saved to {output_file}") -def plot_posterior_distributions(posterior, param_names, output_dir): +def plot_posterior_distributions(posterior: xr.Dataset, param_names: List[str], output_dir: Path) -> None: """Plot posterior distributions for all parameters. - + + Creates histogram plots showing the posterior distribution for each parameter with + mean and median indicators. Arranges plots in a grid layout and saves as + 'posterior_distributions.png'. + Args: - posterior: Posterior dataset - param_names: List of parameter names - output_dir: Directory to save the plot + posterior (xr.Dataset): Posterior dataset containing parameter samples. + param_names (List[str]): List of parameter names to plot. + output_dir (Path): Directory where the plot will be saved. """ print("\n" + "="*70) print("Plotting Posterior Distributions") @@ -180,29 +201,37 @@ def plot_posterior_distributions(posterior, param_names, output_dir): print(f"✓ Posterior distributions saved to {output_file}") -def plot_corner(posterior, param_names, output_dir): +def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path) -> None: """Create corner plot with confidence intervals. - + + Generates a corner plot showing joint and marginal distributions for all parameters + with 68% and 95% confidence intervals. Requires at least 2 parameters. + Saves as 'corner_plot.png'. + Args: - posterior: Posterior dataset - param_names: List of parameter names - output_dir: Directory to save the plot + posterior (xr.Dataset): Posterior dataset containing parameter samples. + param_names (List[str]): List of parameter names to include in corner plot. + output_dir (Path): Directory where the plot will be saved. + + Raises: + ImportError: If the 'corner' package is not installed. Install it with: + pip install corner """ if len(param_names) <= 1: print("\nCorner plot requires at least 2 parameters. Skipping.") return - + print("\n" + "="*70) print("Plotting Corner Plot") print("="*70) - + try: import corner except ImportError: - print("Installing corner package...") - import subprocess - subprocess.check_call([sys.executable, "-m", "pip", "install", "corner"]) - import corner + raise ImportError( + "The 'corner' package is required for corner plots. " + "Install it with: pip install corner" + ) # Prepare data: stack all parameters as columns samples_array = np.column_stack([posterior[param].values.flatten() for param in param_names]) @@ -233,13 +262,20 @@ def plot_corner(posterior, param_names, output_dir): print(f"✓ Corner plot saved to {output_file}") -def plot_shear_analysis(posterior, param_names, output_dir): +def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_dir: Path) -> None: """Create corner plot for shear parameters (g1, g2) only. - + + Generates a specialized corner plot focusing on shear parameters (g1, g2) with + confidence intervals and prints summary statistics. Saves as 'shear_analysis.png'. + Args: - posterior: Posterior dataset - param_names: List of parameter names - output_dir: Directory to save the plot + posterior (xr.Dataset): Posterior dataset containing parameter samples. + param_names (List[str]): List of parameter names to search for shear parameters. + output_dir (Path): Directory where the plot will be saved. + + Raises: + ImportError: If the 'corner' package is not installed. Install it with: + pip install corner """ shear_params = [p for p in param_names if 'g1' in p.lower() or 'g2' in p.lower() or 'shear' in p.lower()] @@ -263,10 +299,10 @@ def plot_shear_analysis(posterior, param_names, output_dir): try: import corner except ImportError: - print("Installing corner package...") - import subprocess - subprocess.check_call([sys.executable, "-m", "pip", "install", "corner"]) - import corner + raise ImportError( + "The 'corner' package is required for shear analysis plots. " + "Install it with: pip install corner" + ) g1_samples = posterior[g1_param].values.flatten() g2_samples = posterior[g2_param].values.flatten() @@ -310,12 +346,15 @@ def plot_shear_analysis(posterior, param_names, output_dir): print(f"✓ Shear analysis saved to {output_file}") -def print_summary_statistics(posterior, param_names): +def print_summary_statistics(posterior: xr.Dataset, param_names: List[str]) -> None: """Print summary statistics for all parameters. - + + Prints a formatted table showing mean, standard deviation, median, and 95% credible + intervals for each parameter to the console. Statistics are not saved to a file. + Args: - posterior: Posterior dataset - param_names: List of parameter names + posterior (xr.Dataset): Posterior dataset containing parameter samples. + param_names (List[str]): List of parameter names to summarize. """ print("\n" + "="*70) print("POSTERIOR SUMMARY STATISTICS") @@ -336,13 +375,18 @@ def print_summary_statistics(posterior, param_names): print("="*70) -def plot_trace(posterior, param_names, output_dir): +def plot_trace(posterior: xr.Dataset, param_names: List[str], output_dir: Path) -> None: """Plot trace plots for MCMC convergence diagnostics. - + + Creates trace plots showing parameter values across MCMC iterations for each chain. + Useful for assessing convergence of MCMC sampling. Only generated if chain/draw + dimensions are present. Saves as 'trace_plots.png'. + Args: - posterior: Posterior dataset - param_names: List of parameter names - output_dir: Directory to save the plot + posterior (xr.Dataset): Posterior dataset containing parameter samples with + chain/draw dimensions. + param_names (List[str]): List of parameter names to plot. + output_dir (Path): Directory where the plot will be saved. """ has_chains = any(dim in posterior.dims for dim in ['chain', 'draw', 'sample']) @@ -385,13 +429,17 @@ def plot_trace(posterior, param_names, output_dir): print(f"✓ Trace plots saved to {output_file}") -def plot_correlation_matrix(posterior, param_names, output_dir): +def plot_correlation_matrix(posterior: xr.Dataset, param_names: List[str], output_dir: Path) -> None: """Plot parameter correlation matrix. - + + Creates a heatmap showing the correlation coefficients between all pairs of parameters. + Values range from -1 (perfect negative correlation) to +1 (perfect positive correlation). + Requires at least 2 parameters. Saves as 'correlation_matrix.png'. + Args: - posterior: Posterior dataset - param_names: List of parameter names - output_dir: Directory to save the plot + posterior (xr.Dataset): Posterior dataset containing parameter samples. + param_names (List[str]): List of parameter names to include in the correlation matrix. + output_dir (Path): Directory where the plot will be saved. """ if len(param_names) <= 1: print("\nOnly one parameter - skipping correlation matrix.") @@ -434,8 +482,18 @@ def plot_correlation_matrix(posterior, param_names, output_dir): print(f"✓ Correlation matrix saved to {output_file}") -def main(): - """Main function to run the visualization pipeline.""" +def main() -> None: + """Main function to run the visualization pipeline. + + Parses command-line arguments, loads observation and posterior data, and generates + all visualization plots including observation visuals, posterior distributions, + corner plots, shear analysis, trace plots, and correlation matrices. All plots + are saved to the specified output directory. + + Raises: + SystemExit: If the output directory does not exist or required data files + (observation.npz, posterior.nc) are not found. + """ parser = argparse.ArgumentParser( description='Visualize SHINE inference results', formatter_class=argparse.RawDescriptionHelpFormatter, From f3bb9df6cf1958f8e1fc4674d9c1c94bd3e08f7f Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:42:49 +0000 Subject: [PATCH 3/4] Add corner package as a dependency Added corner>=2.0.0 to project dependencies in pyproject.toml for visualization and plotting functionality. Co-authored-by: Francois Lanusse --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 644821f..8e20217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "pyyaml>=6.0", "arviz>=0.17.0", "matplotlib>=3.7.0", + "corner>=2.0.0", ] [project.optional-dependencies] From e909dd7c934abb5b0095f9c90b0a23bb887b7973 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 17:18:38 +0000 Subject: [PATCH 4/4] Move plot_shine_results.py to scripts/ directory Utility scripts should be kept separate from the package itself. Created scripts/ directory for utility scripts. Co-authored-by: Francois Lanusse --- {shine => scripts}/plot_shine_results.py | 126 +++++++++++------------ 1 file changed, 63 insertions(+), 63 deletions(-) rename {shine => scripts}/plot_shine_results.py (98%) mode change 100755 => 100644 diff --git a/shine/plot_shine_results.py b/scripts/plot_shine_results.py old mode 100755 new mode 100644 similarity index 98% rename from shine/plot_shine_results.py rename to scripts/plot_shine_results.py index 4e0a655..bdad11f --- a/shine/plot_shine_results.py +++ b/scripts/plot_shine_results.py @@ -56,7 +56,7 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da FileNotFoundError: If observation.npz or posterior.nc are not found in output_dir. """ output_path = Path(output_dir) - + # Load observation data obs_file = output_path / 'observation.npz' if not obs_file.exists(): @@ -64,7 +64,7 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da obs_data = np.load(obs_file) print(f"Observation data loaded from {obs_file}") print(f"Available keys: {list(obs_data.keys())}") - + # Load posterior estimates posterior_file = output_path / 'posterior.nc' if not posterior_file.exists(): @@ -74,7 +74,7 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da print(f"\nPosterior data loaded from {posterior_file}") print(f"Dataset structure:") print(posterior) - + return obs_data, idata, posterior @@ -92,13 +92,13 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None: print("\n" + "="*70) print("Plotting Observation") print("="*70) - + image = obs_data.get('image', None) psf = obs_data.get('psf', None) noise_map = obs_data.get('noise_map', None) - + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - + # Plot the galaxy image if image is not None: im1 = axes[0].imshow(image, origin='lower', cmap='viridis') @@ -109,7 +109,7 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None: axes[0].text(0.02, 0.98, f'Max: {image.max():.2e}\nMin: {image.min():.2e}', transform=axes[0].transAxes, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) - + # Plot the PSF if psf is not None: im2 = axes[1].imshow(psf, origin='lower', cmap='hot') @@ -117,7 +117,7 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None: axes[1].set_xlabel('X pixel') axes[1].set_ylabel('Y pixel') plt.colorbar(im2, ax=axes[1], label='Normalized Flux') - + # Plot the noise map if noise_map is not None: if noise_map.ndim == 0: # Scalar noise @@ -132,7 +132,7 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None: axes[2].set_xlabel('X pixel') axes[2].set_ylabel('Y pixel') plt.colorbar(im3, ax=axes[2], label='Noise σ') - + output_file = Path(output_dir) / 'observation_visual.png' plt.tight_layout() plt.savefig(output_file, dpi=150, bbox_inches='tight') @@ -155,45 +155,45 @@ def plot_posterior_distributions(posterior: xr.Dataset, param_names: List[str], print("\n" + "="*70) print("Plotting Posterior Distributions") print("="*70) - + n_params = len(param_names) n_cols = min(3, n_params) n_rows = int(np.ceil(n_params / n_cols)) - + fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows)) if n_params == 1: axes = np.array([axes]) axes = axes.flatten() - + for idx, param in enumerate(param_names): ax = axes[idx] samples = posterior[param].values - + if samples.ndim > 1: samples = samples.flatten() - + ax.hist(samples, bins=50, density=True, alpha=0.7, color='steelblue', edgecolor='black') - + mean_val = np.mean(samples) median_val = np.median(samples) std_val = np.std(samples) - + ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}') ax.axvline(median_val, color='green', linestyle=':', linewidth=2, label=f'Median: {median_val:.4f}') - + ax.set_xlabel(f'{param}', fontsize=12) ax.set_ylabel('Density', fontsize=12) ax.set_title(f'{param} Posterior\nσ = {std_val:.4f}', fontsize=12, fontweight='bold') ax.legend(fontsize=9) ax.grid(True, alpha=0.3) - + # Hide empty subplots for idx in range(n_params, len(axes)): axes[idx].axis('off') - + output_file = Path(output_dir) / 'posterior_distributions.png' plt.tight_layout() plt.savefig(output_file, dpi=150, bbox_inches='tight') @@ -232,10 +232,10 @@ def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path) "The 'corner' package is required for corner plots. " "Install it with: pip install corner" ) - + # Prepare data: stack all parameters as columns samples_array = np.column_stack([posterior[param].values.flatten() for param in param_names]) - + # Create corner plot with confidence intervals fig = corner.corner( samples_array, @@ -252,10 +252,10 @@ def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path) truth_color='red', title_kwargs={"fontsize": 11}, ) - + plt.suptitle('Corner Plot: Joint & Marginal Distributions', fontsize=14, fontweight='bold', y=0.995) - + output_file = Path(output_dir) / 'corner_plot.png' plt.savefig(output_file, dpi=150, bbox_inches='tight') plt.close() @@ -278,23 +278,23 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di pip install corner """ shear_params = [p for p in param_names if 'g1' in p.lower() or 'g2' in p.lower() or 'shear' in p.lower()] - + if not shear_params: print("\nNo shear parameters found. Skipping shear analysis.") return - + print("\n" + "="*70) print("Plotting Shear Analysis") print("="*70) print(f"Found shear parameters: {shear_params}") - + g1_param = next((p for p in param_names if 'g1' in p.lower()), None) g2_param = next((p for p in param_names if 'g2' in p.lower()), None) - + if not (g1_param and g2_param): print("Could not identify both g1 and g2 parameters. Skipping.") return - + # Import corner package try: import corner @@ -303,19 +303,19 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di "The 'corner' package is required for shear analysis plots. " "Install it with: pip install corner" ) - + g1_samples = posterior[g1_param].values.flatten() g2_samples = posterior[g2_param].values.flatten() - + # Prepare data: stack g1 and g2 as columns samples_array = np.column_stack([g1_samples, g2_samples]) - + # Calculate statistics g1_mean = np.mean(g1_samples) g1_std = np.std(g1_samples) g2_mean = np.mean(g2_samples) g2_std = np.std(g2_samples) - + # Create corner plot for shear parameters only fig = corner.corner( samples_array, @@ -332,14 +332,14 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di truth_color='red', title_kwargs={"fontsize": 11}, ) - - plt.suptitle('Shear Parameters Corner Plot (g1, g2)', + + plt.suptitle('Shear Parameters Corner Plot (g1, g2)', fontsize=14, fontweight='bold', y=0.995) - + output_file = Path(output_dir) / 'shear_analysis.png' plt.savefig(output_file, dpi=150, bbox_inches='tight') plt.close() - + print(f"\nShear estimates:") print(f" g1 = {g1_mean:.6f} ± {g1_std:.6f}") print(f" g2 = {g2_mean:.6f} ± {g2_std:.6f}") @@ -361,7 +361,7 @@ def print_summary_statistics(posterior: xr.Dataset, param_names: List[str]) -> N print("="*70) print(f"{'Parameter':<20} {'Mean':<12} {'Std':<12} {'Median':<12} {'95% CI':<20}") print("-"*70) - + for param in param_names: samples = posterior[param].values.flatten() mean_val = np.mean(samples) @@ -369,9 +369,9 @@ def print_summary_statistics(posterior: xr.Dataset, param_names: List[str]) -> N median_val = np.median(samples) ci_low = np.percentile(samples, 2.5) ci_high = np.percentile(samples, 97.5) - + print(f"{param:<20} {mean_val:<12.6f} {std_val:<12.6f} {median_val:<12.6f} [{ci_low:.6f}, {ci_high:.6f}]") - + print("="*70) @@ -389,39 +389,39 @@ def plot_trace(posterior: xr.Dataset, param_names: List[str], output_dir: Path) output_dir (Path): Directory where the plot will be saved. """ has_chains = any(dim in posterior.dims for dim in ['chain', 'draw', 'sample']) - + if not has_chains: print("\nNo chain/draw dimensions found - likely MAP or point estimate.") print("Skipping trace plots.") return - + print("\n" + "="*70) print("Plotting Trace Plots") print("="*70) - + n_params = len(param_names) fig, axes = plt.subplots(n_params, 1, figsize=(12, 3*n_params)) - + if n_params == 1: axes = [axes] - + for idx, param in enumerate(param_names): samples = posterior[param].values - + # Trace plot if samples.ndim >= 2: for chain in range(samples.shape[0]): axes[idx].plot(samples[chain], alpha=0.7, label=f'Chain {chain}') else: axes[idx].plot(samples, alpha=0.7) - + axes[idx].set_ylabel(param, fontsize=11) axes[idx].set_xlabel('Iteration', fontsize=11) axes[idx].set_title(f'{param} - Trace', fontsize=12, fontweight='bold') axes[idx].grid(True, alpha=0.3) if samples.ndim >= 2 and samples.shape[0] <= 10: axes[idx].legend(fontsize=8) - + output_file = Path(output_dir) / 'trace_plots.png' plt.tight_layout() plt.savefig(output_file, dpi=150, bbox_inches='tight') @@ -444,37 +444,37 @@ def plot_correlation_matrix(posterior: xr.Dataset, param_names: List[str], outpu if len(param_names) <= 1: print("\nOnly one parameter - skipping correlation matrix.") return - + print("\n" + "="*70) print("Plotting Correlation Matrix") print("="*70) - + # Create correlation matrix data_matrix = np.column_stack([posterior[param].values.flatten() for param in param_names]) corr_matrix = np.corrcoef(data_matrix.T) - + # Plot correlation matrix fig, ax = plt.subplots(figsize=(10, 8)) im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto') - + # Set ticks ax.set_xticks(range(len(param_names))) ax.set_yticks(range(len(param_names))) ax.set_xticklabels(param_names, rotation=45, ha='right') ax.set_yticklabels(param_names) - + # Add colorbar cbar = plt.colorbar(im, ax=ax) cbar.set_label('Correlation', fontsize=12) - + # Add correlation values for i in range(len(param_names)): for j in range(len(param_names)): ax.text(j, i, f'{corr_matrix[i, j]:.2f}', ha="center", va="center", color="black", fontsize=10) - + ax.set_title('Parameter Correlation Matrix', fontsize=14, fontweight='bold', pad=20) - + output_file = Path(output_dir) / 'correlation_matrix.png' plt.tight_layout() plt.savefig(output_file, dpi=150, bbox_inches='tight') @@ -505,34 +505,34 @@ def main() -> None: required=True, help='Directory containing observation.npz and posterior.nc files (and where plots will be saved)' ) - + args = parser.parse_args() - + # Verify output directory exists output_dir = Path(args.output) if not output_dir.exists(): print(f"Error: Directory {output_dir} does not exist") sys.exit(1) - + print("="*70) print("SHINE RESULTS VISUALIZATION") print("="*70) print(f"Output directory: {output_dir.absolute()}") - + # Setup plotting style setup_plot_style() - + # Load data try: obs_data, idata, posterior = load_data(output_dir) except FileNotFoundError as e: print(f"\nError: {e}") sys.exit(1) - + # Get parameter names param_names = list(posterior.data_vars) print(f"\nInferred parameters: {param_names}") - + # Generate all plots plot_observation(obs_data, output_dir) plot_posterior_distributions(posterior, param_names, output_dir) @@ -541,7 +541,7 @@ def main() -> None: print_summary_statistics(posterior, param_names) plot_trace(posterior, param_names, output_dir) plot_correlation_matrix(posterior, param_names, output_dir) - + # Final summary print("\n" + "="*70) print("VISUALIZATION COMPLETE")