This document details the software architecture, design principles, and technical specifications for SHINE (SHear INference Environment). It serves as the primary reference for developers and contributors.
SHINE is a fully differentiable, forward-modeling pipeline for gravitational shear inference. It leverages JAX for automatic differentiation and GPU acceleration, NumPyro for probabilistic programming, and JAX-GalSim for astronomical image simulation. The software is designed to handle data from major surveys (Euclid, LSST, MeerKAT) and aims to overcome limitations of traditional shear measurement methods (e.g., metacalibration) by directly inferring shear from pixel data within a Bayesian framework.
SHINE treats shear measurement as a Bayesian inverse problem. Instead of measuring ellipticities and correcting for biases, SHINE generates forward models of the sky, convolved with the instrument response, and compares them to observed data to infer the posterior distribution of shear parameters.
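Under the Gaussian pixel-noise model used in the examples later in this document, the inverse problem can be written out explicitly. The notation is introduced here for illustration only: $d$ is the observed image, $m(g, \theta)$ the PSF-convolved model image, $g = (g_1, g_2)$ the shear, $\theta$ the per-galaxy parameters, $\sigma$ the per-pixel noise standard deviation, and $p$ indexes pixels.

```latex
p(g, \theta \mid d) \;\propto\; p(g)\, p(\theta) \prod_{p} \mathcal{N}\!\left(d_p \,\middle|\, m_p(g, \theta),\, \sigma^2\right)

p(g \mid d) \;=\; \int p(g, \theta \mid d)\, \mathrm{d}\theta
```

The shear posterior is the marginal in the second line. With MCMC, the marginalization amounts to keeping only the $g_1$ and $g_2$ columns of the joint samples.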
- JAX: The backbone for computation, providing `jit` compilation, `vmap` vectorization, and `grad` for HMC inference.
- NumPyro: Provides the probabilistic programming framework for defining hierarchical models and performing MCMC inference (NUTS/HMC).
- JAX-GalSim: A JAX port of GalSim used for differentiable galaxy profile rendering and PSF convolution.
- BlackJAX: (Optional/Alternative) Lower-level inference library if custom samplers are needed beyond NumPyro's offerings.
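As a minimal sketch of how the three JAX transformations listed above combine, the toy example below differentiates a Gaussian negative log-likelihood of a 1D profile. The functions `render` and `neg_log_like` are hypothetical stand-ins and not part of SHINE; a real pipeline would render with JAX-GalSim instead.

```python
import jax
import jax.numpy as jnp


def render(amplitude, width, x):
    """Toy 1D Gaussian 'galaxy' evaluated at pixel centers x."""
    return amplitude * jnp.exp(-0.5 * (x / width) ** 2)


def neg_log_like(params, data, x, sigma):
    """Gaussian negative log-likelihood, up to an additive constant."""
    amplitude, width = params
    residual = data - render(amplitude, width, x)
    return 0.5 * jnp.sum((residual / sigma) ** 2)


x = jnp.linspace(-5.0, 5.0, 32)
truth = jnp.array([2.0, 1.5])
data = render(truth[0], truth[1], x)  # noise-free data keeps the example simple

# grad differentiates with respect to the first argument; jit compiles the result.
grad_fn = jax.jit(jax.grad(neg_log_like))

# vmap evaluates the gradient for a batch of parameter vectors in one call.
batch = jnp.array([[2.0, 1.5], [2.5, 1.0], [1.0, 2.0]])
batched_grads = jax.vmap(grad_fn, in_axes=(0, None, None, None))(batch, data, x, 1.0)
```

The first row of `batched_grads` corresponds to the true parameters, so its gradient is numerically zero, while the other rows are not. Gradients of this kind are what HMC and NUTS consume at every leapfrog step.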
graph TD
Config[YAML Configuration] --> Loader[Config Handler]
Loader --> Scene[Scene Modelling<br/>NumPyro + JAX-GalSim]
subgraph "Forward Model (JAX)"
Priors[Priors<br/>NumPyro] --> Scene
Scene --> Galaxy[Galaxy Generation<br/>Sersic/Morphology]
Scene --> PSF[PSF Modelling]
Galaxy --> Convolve[Convolution]
PSF --> Convolve
Convolve --> Noise[Noise Model]
Noise --> ModelImage[Simulated Image]
end
Data[Observed Data<br/>Fits/HDF5] --> Likelihood
ModelImage --> Likelihood[Likelihood Evaluation]
Likelihood --> Inference[Inference Engine<br/>NumPyro/BlackJAX]
Inference --> Posterior[Shear Posterior]
subgraph "Workflow Management"
WMS[WMS<br/>Slurm/Cluster] --> Config
WMS --> Data
end
Responsibility: Parse YAML configuration files and validate inputs.

Implementation:
- Uses `pydantic` or standard `yaml` libraries for robust validation.
- Converts physical units (e.g., arcsec to radians) for internal consistency.
- Key Config Sections: `Scene`, `PSF`, `Survey` (Euclid/LSST/MeerKAT), `Inference`, `Compute`.
- GalSim Compatibility: Supports standard GalSim YAML structure for defining scene components. Any parameter defined as a distribution (e.g., `Normal`, `LogNormal`) is automatically treated as a latent variable for inference.
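One way to picture the "distribution means latent variable" rule is a recursive walk over the parsed configuration. The sketch below uses only the standard library; `find_latent_parameters`, `DISTRIBUTION_TYPES`, and the dictionary layout are assumptions made for illustration, not the actual Config Handler API.

```python
import math

# Assumed set of type names that mark a parameter as a prior (latent variable).
DISTRIBUTION_TYPES = {"Normal", "LogNormal", "Uniform"}

ARCSEC_TO_RAD = math.pi / (180.0 * 3600.0)


def find_latent_parameters(node, path=()):
    """Return dotted paths of every sub-dict whose 'type' names a distribution."""
    latents = []
    if isinstance(node, dict):
        kind = node.get("type")
        if isinstance(kind, str) and kind in DISTRIBUTION_TYPES:
            latents.append(".".join(path))
        for key, value in node.items():
            latents.extend(find_latent_parameters(value, path + (key,)))
    return latents


# Hand-written stand-in for a parsed YAML file.
config = {
    "image": {"pixel_scale": 0.1},  # arcsec/pixel
    "gal": {
        "type": "Sersic",
        "n": 4.0,  # plain number: stays fixed
        "flux": {"type": "LogNormal", "mean": 1000.0, "sigma": 0.5},
        "shear": {
            "type": "G1G2",
            "g1": {"type": "Normal", "mean": 0.0, "sigma": 0.05},
            "g2": {"type": "Normal", "mean": 0.0, "sigma": 0.05},
        },
    },
}

latents = find_latent_parameters(config)
pixel_scale_rad = config["image"]["pixel_scale"] * ARCSEC_TO_RAD
```

Here `latents` lists `gal.flux`, `gal.shear.g1`, and `gal.shear.g2`, while the fixed Sersic index `n` is left out because it is a plain number rather than a distribution block.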
Responsibility: Define the generative model of the astronomical scene.

Integration:
- NumPyro Integration: The core function is a `numpyro` model. It defines random variables for galaxy properties (flux, size, ellipticity, position) and global parameters (shear, PSF properties).
- Hierarchical Models: Supports partial pooling for galaxy populations (e.g., drawing individual galaxy Sersic indices from a population-level distribution).
- JAX-GalSim Usage: Inside the model, `jax_galsim` objects (`Sersic`, `Gaussian`, `Shear`) are created using the sampled parameters.
  - Limitation Handling: Since `jax_galsim` objects are immutable, the pipeline must be purely functional. No in-place modifications.
  - Vectorization: Uses `jax.vmap` or `numpyro.plate` to efficiently render thousands of galaxies.
Code Structure Example:

import numpyro
import numpyro.distributions as dist
import jax_galsim


def model(data, config):
    # Global Shear
    g1 = numpyro.sample("g1", dist.Normal(0, 0.05))
    g2 = numpyro.sample("g2", dist.Normal(0, 0.05))
    shear = jax_galsim.Shear(g1=g1, g2=g2)

    # Plate for N galaxies
    with numpyro.plate("galaxies", config.n_galaxies):
        flux = numpyro.sample("flux", dist.LogNormal(...))
        # ... other params ...

    # Deterministic transformation (rendering)
    # Note: This needs careful vectorization with JAX-GalSim
    # render_scene and sigma are placeholders for the rendering helper
    # and the per-pixel noise standard deviation
    image = render_scene(flux, ..., shear)
    numpyro.sample("obs", dist.Normal(image, sigma), obs=data)

Responsibility: Perform Bayesian inference to obtain posteriors.

Methods:
- NUTS (No-U-Turn Sampler): Default high-performance sampler provided by NumPyro.
- SVI (Stochastic Variational Inference): For faster, approximate results on large datasets.
- BlackJAX Integration: Wrappers for BlackJAX samplers if specific HMC variations are required.

Optimization:
- Uses `jax.jit` to compile the likelihood and gradient functions.
- Supports reparameterization (e.g., `LocScaleReparam`) to improve geometry for hierarchical models.
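The effect of a location-scale reparameterization can be written out for a single hierarchical parameter. The symbols are introduced here for illustration: $\mu$ and $\tau$ are the population-level location and scale, and $x_i$ is a per-galaxy parameter such as a size or Sersic index.

```latex
\text{Centered:} \qquad x_i \sim \mathcal{N}(\mu, \tau^2)

\text{Non-centered:} \qquad z_i \sim \mathcal{N}(0, 1), \qquad x_i = \mu + \tau\, z_i
```

Both forms define the same distribution for $x_i$, but the sampler explores $(\mu, \tau, z_i)$ in the second. There $z_i$ is a priori independent of $\mu$ and $\tau$, which removes the funnel-shaped coupling that tends to slow HMC when $\tau$ is poorly constrained. In NumPyro, `LocScaleReparam` with `centered=0` corresponds to the fully non-centered form.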
Responsibility: Standardize input data from various sources.

Modules:
- `euclid`: Handles Euclid-specific PSF models and pixel scales.
- `lsst`: Handles LSST bands and observing conditions.
- `radio`: Handles visibility data or dirty images (if operating in image plane) for MeerKAT.

Abstraction:
- All modules return a standardized `Observation` object containing: `image_data`, `noise_map`, `psf_model`, `wcs`.
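A minimal sketch of what such a container could look like, assuming a frozen dataclass so that it matches the purely functional style required elsewhere in the pipeline. Only the four field names come from this document. The nested-list image type and the shape check are illustrative choices; a real implementation would more likely hold `jax.numpy` arrays.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass(frozen=True)
class Observation:
    """Hypothetical container mirroring the fields listed above."""

    image_data: Sequence[Sequence[float]]
    noise_map: Sequence[Sequence[float]]
    psf_model: Any
    wcs: Any

    def __post_init__(self):
        # Reject inputs whose noise map does not match the image, row for row.
        image_shape = (len(self.image_data), len(self.image_data[0]))
        noise_shape = (len(self.noise_map), len(self.noise_map[0]))
        if image_shape != noise_shape:
            raise ValueError(
                f"image {image_shape} and noise map {noise_shape} differ in shape"
            )


obs = Observation(
    image_data=[[0.0, 1.0], [2.0, 3.0]],
    noise_map=[[1.0, 1.0], [1.0, 1.0]],
    psf_model=None,  # placeholder; a survey module would supply a PSF object
    wcs={"pixel_scale_arcsec": 0.1},
)
```

Because the dataclass is frozen, survey modules cannot mutate an `Observation` after construction, so the same object can be shared safely across jitted functions.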
Responsibility: Generate galaxy surface brightness profiles.

Types:
- Parametric: Standard Sersic profiles via `jax_galsim`.
- Non-Parametric:
  - VAE/GAN: Generative models trained on high-resolution data (e.g., COSMOS) to produce realistic morphologies.
  - Basis Functions: Shapelets or other basis sets if supported by JAX-GalSim.
Responsibility: Orchestrate large-scale runs on HPC clusters.

Features:
- Patching: Splits large survey fields into manageable patches.
- Job Submission: Generates SLURM scripts.
- Monitoring: Tracks job status and aggregates results.
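The patching and job-submission features can be sketched with the standard library alone. Everything named below is hypothetical: the helper functions, the `shine.run` entry point, and its flags are placeholders. The `#SBATCH` directives and the `SLURM_ARRAY_TASK_ID` variable are standard SLURM.

```python
def split_into_patches(width, height, patch_size):
    """Tile a width x height pixel field into (x0, y0, x1, y1) boxes."""
    patches = []
    for y0 in range(0, height, patch_size):
        for x0 in range(0, width, patch_size):
            x1 = min(x0 + patch_size, width)
            y1 = min(y0 + patch_size, height)
            patches.append((x0, y0, x1, y1))
    return patches


def make_slurm_script(n_patches, config_path, time_limit="02:00:00"):
    """Render a SLURM array job with one task per patch."""
    lines = [
        "#!/bin/bash",
        "#SBATCH --job-name=shine",
        f"#SBATCH --array=0-{n_patches - 1}",
        "#SBATCH --gres=gpu:1",
        f"#SBATCH --time={time_limit}",
        # Hypothetical entry point; SLURM sets SLURM_ARRAY_TASK_ID for each task.
        f"python -m shine.run --config {config_path} --patch-index $SLURM_ARRAY_TASK_ID",
    ]
    return "\n".join(lines)


patches = split_into_patches(width=4096, height=2048, patch_size=1024)
script = make_slurm_script(len(patches), "configs/euclid_run.yaml")
```

This tiling produces non-overlapping patches. Galaxies that straddle a patch boundary would need overlapping borders or a separate treatment, which the sketch does not attempt.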
- Style: Black, isort.
- Typing: Full type hinting (PEP 484).
- Documentation: Google-style docstrings, Sphinx + ReadTheDocs.
- Testing:
  - `pytest` for unit tests.
  - `chex` for JAX-specific testing (array shapes, types).
- RNG: JAX uses a functional PRNG. `jax_galsim`'s RNG handling must be carefully managed to ensure reproducibility, especially within `numpyro` models which manage their own RNG keys.
- Performance: Avoid Python loops in the rendering path. Use `vmap` or `scan`.
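The functional PRNG point is easiest to see in a few lines of JAX. This is a generic illustration of key splitting, not SHINE code, and the key names are arbitrary.

```python
import jax
import jax.numpy as jnp

root_key = jax.random.PRNGKey(0)

# Split once: one key for simulating noise, one to hand to the sampler.
noise_key, inference_key = jax.random.split(root_key)

noise_a = jax.random.normal(noise_key, shape=(4,))
noise_b = jax.random.normal(noise_key, shape=(4,))  # same key, identical draw
noise_c = jax.random.normal(inference_key, shape=(4,))  # different key, different draw
```

Reusing a key reproduces the same numbers exactly, which is useful for regression tests and harmful if done by accident. A safe pattern is to split the root key once per consumer and pass the sampler its own key, since NumPyro's `MCMC.run` takes an explicit `rng_key` argument.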
- Unit Tests: Verify individual components (e.g., Sersic profile generation matches analytical expectations).
- Integration Tests: End-to-end run on a small synthetic patch.
- Validation:
- Self-Consistency: Generate data with known shear, infer it back, check if truth is within posterior credible intervals.
- Comparison: Compare results with standard GalSim (non-JAX) for identical inputs to ensure numerical accuracy.
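The self-consistency check can be sketched as a coverage test on posterior samples. The example below fakes the posterior with Gaussian draws instead of running MCMC; `credible_interval` and `covers` are hypothetical helpers written for this illustration.

```python
import random


def credible_interval(samples, level=0.68):
    """Equal-tailed credible interval estimated from posterior samples."""
    ordered = sorted(samples)
    last = len(ordered) - 1
    lower_index = int(round((0.5 - level / 2.0) * last))
    upper_index = int(round((0.5 + level / 2.0) * last))
    return ordered[lower_index], ordered[upper_index]


def covers(samples, truth, level=0.68):
    """True if the injected value lies inside the credible interval."""
    low, high = credible_interval(samples, level)
    return low <= truth <= high


# Stand-in for MCMC output: draws centered slightly off the injected shear.
rng = random.Random(0)
true_g1 = 0.02
fake_posterior = [rng.gauss(0.021, 0.005) for _ in range(2000)]

low, high = credible_interval(fake_posterior, level=0.95)
inside = covers(fake_posterior, true_g1, level=0.95)
```

A single run only shows that one truth value was recovered. Repeating the check over many noise realizations and confirming that the fraction of covered truths is close to the nominal level is a stronger test of calibration.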
- Phase 1: Prototype with simple parametric models (Sersic) and constant PSF.
- Phase 2: Integration of realistic PSF models (Euclid/LSST specific).
- Phase 3: Non-parametric galaxy morphology (VAE/Diffusion).
- Phase 4: Large-scale validation on Flagship/CosmoDC2 simulations.
This section demonstrates the complete workflow from configuration to inference, highlighting the interaction between SHINE components.
SHINE adopts the GalSim YAML structure but interprets it in a probabilistic context.
Note
Standard GalSim vs. SHINE Extensions:
- Standard: `image`, `psf`, `gal` top-level keys. Use of `type` to specify profiles (e.g., `Sersic`, `Gaussian`).
- SHINE Extension: In standard GalSim, defining a random distribution (e.g., `type: Normal`) draws a single fixed value for that simulation. In SHINE, this defines a prior distribution for a latent variable that will be inferred.
Example Config (configs/euclid_run.yaml):
# Top-level simulation parameters (Standard GalSim)
image:
  pixel_scale: 0.1  # arcsec/pixel
  size_x: 64
  size_y: 64
  noise:
    type: Gaussian
    sigma: 1.0

# PSF Definition (Fixed for this example)
psf:
  type: Gaussian
  sigma: 0.1

# Galaxy Definition (Probabilistic)
gal:
  type: Sersic
  # Fixed parameter (Standard GalSim behavior)
  n: 4.0
  # Probabilistic parameters (SHINE Interpretation: Latent Variables)
  flux:
    type: LogNormal
    mean: 1000.0
    sigma: 0.5
  half_light_radius:
    type: Uniform
    min: 0.1
    max: 1.0
  shear:
    type: G1G2
    g1:
      type: Normal
      mean: 0.0
      sigma: 0.05
    g2:
      type: Normal
      mean: 0.0
      sigma: 0.05

# shine/main_example.py
import jax
from shine.config import ConfigHandler
from shine.scene import SceneBuilder
from shine.inference import HMCInference
from shine.data import DataLoader


def main():
    # 1. Load Configuration
    # Parses YAML, validates types, and converts units (e.g., arcsec -> radians)
    config = ConfigHandler.load("configs/euclid_run.yaml")

    # 2. Load Data
    # Returns a standardized Observation object containing:
    # - image_data: jax.numpy array of pixel data
    # - noise_map: variance map
    # - psf_model: JAX-GalSim PSF object or kernel
    # - wcs: World Coordinate System info
    observation = DataLoader.load(config.data_path)

    # 3. Build the Probabilistic Model
    # SceneBuilder initializes the NumPyro model function.
    # It pre-compiles JAX-GalSim components based on config.
    scene_builder = SceneBuilder(config)
    model_fn = scene_builder.build_model()

    # 4. Initialize Inference Engine
    # Sets up the NUTS sampler with specified parameters
    inference_engine = HMCInference(
        model=model_fn,
        num_warmup=config.inference.warmup,
        num_samples=config.inference.samples,
        num_chains=config.inference.chains,
        dense_mass=config.inference.dense_mass,  # Use dense mass matrix if correlated
    )

    # 5. Run Inference
    # Executes the MCMC chain.
    # Uses JAX's PRNGKey for reproducibility.
    print("Starting inference...")
    rng_key = jax.random.PRNGKey(42)
    results = inference_engine.run(
        rng_key=rng_key,
        observed_data=observation.image_data,
        extra_args={"psf": observation.psf_model},
    )

    # 6. Analyze and Save Results
    # Results object contains posterior samples and convergence diagnostics (r_hat)
    print(f"Inferred Shear g1: {results.samples['g1'].mean():.5f} ± {results.samples['g1'].std():.5f}")
    print(f"Inferred Shear g2: {results.samples['g2'].mean():.5f} ± {results.samples['g2'].std():.5f}")
    results.save(config.output_path / "posterior.nc")  # Save as NetCDF/ArviZ compatible


if __name__ == "__main__":
    main()

The `SceneBuilder` constructs the `numpyro` model by parsing the GalSim-style YAML.
# Inside shine/scene/builder.py
import jax
import jax.numpy as jnp
import jax_galsim as galsim
import numpyro
import numpyro.distributions as dist


class SceneBuilder:
    def __init__(self, config):
        self.config = config

    def _parse_prior(self, name, param_config):
        """Helper to sample a NumPyro site from a prior defined in the config."""
        kind = param_config['type']
        if kind == 'Normal':
            return numpyro.sample(name, dist.Normal(param_config['mean'], param_config['sigma']))
        elif kind == 'LogNormal':
            # 'mean' sets the location in log space, i.e. the median of the distribution
            return numpyro.sample(name, dist.LogNormal(jnp.log(param_config['mean']), param_config['sigma']))
        elif kind == 'Uniform':
            # Needed for half_light_radius in the example config
            return numpyro.sample(name, dist.Uniform(param_config['min'], param_config['max']))
        # ... other distributions ...
        raise ValueError(f"Unsupported prior type for '{name}': {kind}")

    def build_model(self):
        """Returns a callable model function for NumPyro."""
        def model(observed_data=None, psf=None):
            # --- 1. Global Parameters (Shear) ---
            # Parsed from config.gal.shear
            g1 = self._parse_prior("g1", self.config.gal.shear.g1)
            g2 = self._parse_prior("g2", self.config.gal.shear.g2)
            shear = galsim.Shear(g1=g1, g2=g2)

            # --- 2. Galaxy Population (Hierarchical) ---
            # We use a plate to assume independence between galaxies given global params
            n_galaxies = self.config.image.n_objects  # Standard GalSim field
            with numpyro.plate("galaxies", n_galaxies):
                # Sample morphological parameters from config
                flux = self._parse_prior("flux", self.config.gal.flux)
                hlr = self._parse_prior("hlr", self.config.gal.half_light_radius)
                # Sample positions (pixel coordinates)
                x = numpyro.sample("x", dist.Uniform(0, self.config.image.size_x))
                y = numpyro.sample("y", dist.Uniform(0, self.config.image.size_y))

            # Fixed parameter example (a scalar shared by all galaxies)
            n = self.config.gal.n

            # --- 3. Differentiable Rendering (JAX-GalSim) ---
            # We define a helper to render a single galaxy
            def render_one_galaxy(flux, hlr, n, x, y):
                gal = galsim.Sersic(n=n, half_light_radius=hlr, flux=flux)
                gal = gal.shear(shear)  # Gravitational shear
                gal = galsim.Convolve([gal, psf])
                # drawImage measures offset from the image center, so shift
                # the sampled pixel position accordingly
                dx = x - self.config.image.size_x / 2
                dy = y - self.config.image.size_y / 2
                # Draw onto a stamp or full image (simplified here)
                return gal.drawImage(nx=self.config.image.size_x,
                                     ny=self.config.image.size_y,
                                     scale=self.config.image.pixel_scale,
                                     offset=(dx, dy)).array

            # Vectorize rendering over all galaxies using vmap
            # This is crucial for performance in JAX
            # n is a scalar, so it is broadcast (in_axes=None) instead of mapped
            galaxy_images = jax.vmap(render_one_galaxy, in_axes=(0, 0, None, 0, 0))(flux, hlr, n, x, y)

            # Sum all galaxy images to get the model scene
            model_image = jnp.sum(galaxy_images, axis=0)

            # --- 4. Likelihood ---
            # Compare model image to observed data
            sigma = self.config.image.noise.sigma
            numpyro.sample("obs", dist.Normal(model_image, sigma), obs=observed_data)

        return model
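One detail of the prior parsing deserves a worked check: passing `log(mean)` as the location of a log-normal makes the configured `mean` the median of the distribution, not its expectation. The sketch below verifies this with the standard library, using the values from the example YAML. The variable names are illustrative only.

```python
import math
import random

config_mean, config_sigma = 1000.0, 0.5  # flux prior from the example config

# For LogNormal(loc=log(m), scale=s): median = m, mean = m * exp(s^2 / 2).
analytic_median = math.exp(math.log(config_mean))
analytic_mean = config_mean * math.exp(config_sigma ** 2 / 2.0)

# Monte Carlo cross-check with the standard library's log-normal sampler.
rng = random.Random(1)
draws = sorted(
    rng.lognormvariate(math.log(config_mean), config_sigma) for _ in range(100_000)
)
empirical_median = draws[len(draws) // 2]
empirical_mean = sum(draws) / len(draws)
```

With `sigma: 0.5` the expected flux is roughly 13% above the configured value. If the config key is meant to be the true mean, the location would instead need to be `log(mean) - sigma**2 / 2`.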