Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
57e08eb
add option to treat input systematic histograms as difference with re…
bendavid Apr 6, 2026
c80b409
add test for sparse mode
bendavid Apr 6, 2026
0e17ff6
Support scipy sparse array inputs in TensorWriter and add as_differen…
bendavid Apr 6, 2026
91ce1b3
Add multi-systematic dispatch in add_systematic and use wums.SparseHist
bendavid Apr 7, 2026
dab00bc
Add external likelihood term (gradient + hessian) support
bendavid Apr 7, 2026
47f1f90
Add efficient SparseHist multi-systematic dispatch in TensorWriter
bendavid Apr 7, 2026
08537b8
Speed up TensorWriter for large multi-systematic SparseHist workloads
bendavid Apr 7, 2026
f51e53f
inputdata, parsing: prep for sparse fast path with CSR matvec
bendavid Apr 9, 2026
83afbd8
fitter: dynamic loss/grad/HVP wrappers with jit_compile + hvpMethod
bendavid Apr 9, 2026
b6d7120
fitter: sparse fast path uses CSR matmul, no dense [nbins,nproc]
bendavid Apr 9, 2026
b30f867
fitter: external sparse Hessian via CSR matmul
bendavid Apr 9, 2026
3f41fc6
rabbit_fit, setup.sh: enable XLA multi-threaded Eigen on CPU
bendavid Apr 9, 2026
183f376
fitter, rabbit_fit: skip dense cov allocation under --noHessian
bendavid Apr 9, 2026
8493332
fitter: speed up Fitter.__init__ on large external sparse Hessians
bendavid Apr 9, 2026
6c1c187
unify sparse-Hessian IO path; sort at write time, drop reorder calls
bendavid Apr 9, 2026
db274db
fitter: Hessian-free CG solve for is_linear case under --noHessian
bendavid Apr 9, 2026
d0708b1
fitter, rabbit_fit: edmval + POI/NOI uncertainties under --noHessian
bendavid Apr 10, 2026
fad47bc
external_likelihood: factor out external-term IO + tf build + nll eval
bendavid Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 77 additions & 5 deletions bin/rabbit_fit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
#!/usr/bin/env python3

# Switch on XLA's multi-threaded Eigen backend for CPU before tensorflow is
# imported anywhere (directly or transitively): XLA parses XLA_FLAGS exactly
# once during runtime initialization. Measured ~1.3x speedup on dense
# large-model HVP/loss+grad on a many-core system, with no downside.
# A user-supplied XLA_FLAGS value is kept intact; we only append ours.
import os as _os

_EIGEN_FLAG = "--xla_cpu_multi_thread_eigen=true"
_current_flags = _os.environ.get("XLA_FLAGS", "")
if "xla_cpu_multi_thread_eigen" not in _current_flags:
    if _current_flags:
        _os.environ["XLA_FLAGS"] = f"{_current_flags} {_EIGEN_FLAG}".strip()
    else:
        _os.environ["XLA_FLAGS"] = _EIGEN_FLAG

import copy

import tensorflow as tf
Expand Down Expand Up @@ -349,9 +363,16 @@ def save_hists(args, mappings, fitter, ws, prefit=True, profile=False):
)

if args.computeVariations:
if fitter.cov is None:
raise RuntimeError(
"--computeVariations requires the parameter covariance "
"matrix and so is incompatible with --noHessian."
)
if prefit:
cov_prefit = fitter.cov.numpy()
fitter.cov.assign(fitter.prefit_covariance(unconstrained_err=1.0))
fitter.cov.assign(
fitter.prefit_covariance(unconstrained_err=1.0).to_dense()
)

exp, aux = fitter.expected_events(
mapping,
Expand Down Expand Up @@ -407,6 +428,9 @@ def fit(args, fitter, ws, dofit=True):
ws.add_1D_integer_hist(cb.loss_history, "epoch", "loss")
ws.add_1D_integer_hist(cb.time_history, "epoch", "time")

# prefit variances as the default fallback for add_parms_hist below
parms_variances = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see why this is needed, it follows an if ... else statement in the code where parms_variances is defined either way


if not args.noHessian:
# compute the covariance matrix and estimated distance to minimum
_, grad, hess = fitter.loss_val_grad_hess()
Expand Down Expand Up @@ -446,6 +470,31 @@ def fit(args, fitter, ws, dofit=True):
global_impacts=True,
)

parms_variances = tf.linalg.diag_part(fitter.cov)
else:
# --noHessian: avoid the full dense Hessian. Still compute edmval
# and the POI+NOI uncertainties via a Hessian-free conjugate
# gradient solve of H @ v = grad and H @ c_i = e_i, using only
# Hessian-vector products. The CG solves touch O(npar) memory
# per call instead of O(npar^2), so this works on problems
# where the full covariance would be infeasible.
_, grad = fitter.loss_val_grad()
npoi = int(fitter.poi_model.npoi)
noi_idx_in_x = np.asarray(fitter.indata.noiidxs, dtype=np.int64) + npoi
poi_noi_idx = np.concatenate([np.arange(npoi, dtype=np.int64), noi_idx_in_x])
edmval, cov_rows = fitter.edmval_cov_rows_hessfree(grad, poi_noi_idx)
logger.info(f"edmval: {edmval}")

# Build a full-length variance vector with the POI+NOI entries
# populated from the diagonal of the CG-solved rows and the rest
# left as NaN (we did not compute those). add_parms_hist stores
# the vector verbatim into the workspace.
n = int(fitter.x.shape[0])
parms_variances_np = np.full(n, np.nan, dtype=np.float64)
for k, i in enumerate(poi_noi_idx):
parms_variances_np[int(i)] = cov_rows[k, int(i)]
parms_variances = tf.constant(parms_variances_np, dtype=fitter.indata.dtype)

nllvalreduced = fitter.reduced_nll().numpy()

ndfsat = (
Expand Down Expand Up @@ -476,7 +525,7 @@ def fit(args, fitter, ws, dofit=True):

ws.add_parms_hist(
values=fitter.x,
variances=tf.linalg.diag_part(fitter.cov) if not args.noHessian else None,
variances=parms_variances,
hist_name="parms",
)

Expand Down Expand Up @@ -560,8 +609,31 @@ def main():
if args.eager:
tf.config.run_functions_eagerly(True)

if args.noHessian and args.doImpacts:
raise Exception('option "--noHessian" only works without "--doImpacts"')
# --noHessian skips computing the postfit Hessian, so the dense
# parameter covariance matrix is never available. Any feature that
# needs the covariance is incompatible.
if args.noHessian:
_incompat = []
if args.doImpacts:
_incompat.append("--doImpacts")
if args.computeVariations:
_incompat.append("--computeVariations")
if args.saveHists and not args.noChi2:
_incompat.append("--saveHists (without --noChi2)")
if args.computeHistErrors:
_incompat.append("--computeHistErrors")
if args.computeHistErrorsPerProcess:
_incompat.append("--computeHistErrorsPerProcess")
if args.computeHistCov:
_incompat.append("--computeHistCov")
if args.computeHistImpacts:
_incompat.append("--computeHistImpacts")
if args.computeHistGaussianImpacts:
_incompat.append("--computeHistGaussianImpacts")
if args.externalPostfit is not None:
_incompat.append("--externalPostfit")
if _incompat:
raise Exception("--noHessian is incompatible with: " + ", ".join(_incompat))

global logger
logger = logging.setup_logger(__file__, args.verbose, args.noColorLogger)
Expand Down Expand Up @@ -679,7 +751,7 @@ def main():

ws.add_parms_hist(
values=ifitter.x,
variances=tf.linalg.diag_part(ifitter.cov),
variances=ifitter.var_prefit,
hist_name="parms_prefit",
)

Expand Down
224 changes: 224 additions & 0 deletions rabbit/external_likelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""Helpers for external likelihood terms (linear + quadratic parameter priors).

An "external likelihood term" is an additive contribution to the NLL of
the form

-log L_ext = g^T x_sub + 0.5 * x_sub^T H x_sub

where ``x_sub`` is the subset of the fit parameters the term constrains.
Both the linear (``grad``) and quadratic (``hess_dense`` / ``hess_sparse``)
parts are optional; the sparse Hessian is stored as a
``tf.sparse.SparseTensor`` whose indices are in canonical row-major order.

This module centralizes three things that were previously inlined in
``Fitter.__init__``, ``Fitter._compute_external_nll``, and
``FitInputData.__init__``:

* :func:`read_external_terms_from_h5` — load the raw numpy-level
per-term dicts from an HDF5 group (used by FitInputData)
* :func:`build_tf_external_terms` — turn that list into tf-side per-term
dicts (resolved parameter indices, tf.constant grads, CSRSparseMatrix
Hessians). Used by the Fitter when it takes ownership of the input
data.
* :func:`compute_external_nll` — evaluate the scalar NLL contribution
of a list of tf-side terms at the current ``x``.
"""

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops as tf_sparse_csr

from rabbit.h5pyutils_read import makesparsetensor, maketensor


def read_external_terms_from_h5(ext_group):
    """Decode an HDF5 ``external_terms`` group into a list of raw dicts.

    Each entry has the keys used by the rest of the pipeline:

    * ``name``: term label (str, taken from the h5 subgroup name)
    * ``params``: 1D ndarray of parameter name strings
    * ``grad_values``: 1D float ndarray or ``None``
    * ``hess_dense``: 2D float ndarray or ``None``
    * ``hess_sparse``: :class:`tf.sparse.SparseTensor` or ``None`` (uses
      the same on-disk layout as ``hlogk_sparse`` / ``hnorm_sparse``)

    Parameters
    ----------
    ext_group : h5py.Group
        The ``external_terms`` group in the input HDF5 file, or ``None``.

    Returns
    -------
    list[dict]
        One entry per stored external term, or an empty list if
        ``ext_group`` is ``None``.
    """
    if ext_group is None:
        return []

    terms = []
    for tname, tg in ext_group.items():
        raw_params = tg["params"][...]
        # h5py returns bytes for fixed-length string datasets; normalize
        # everything to str so downstream name resolution can compare.
        params = np.array(
            [s.decode() if isinstance(s, bytes) else s for s in raw_params]
        )
        # All three payloads are optional per term. Membership test directly
        # on the group ("name in tg") instead of the redundant tg.keys().
        grad_values = (
            np.asarray(maketensor(tg["grad_values"])) if "grad_values" in tg else None
        )
        hess_dense = (
            np.asarray(maketensor(tg["hess_dense"])) if "hess_dense" in tg else None
        )
        hess_sparse = (
            makesparsetensor(tg["hess_sparse"]) if "hess_sparse" in tg else None
        )
        terms.append(
            {
                "name": tname,
                "params": params,
                "grad_values": grad_values,
                "hess_dense": hess_dense,
                "hess_sparse": hess_sparse,
            }
        )
    return terms


def build_tf_external_terms(terms, parms, dtype):
    """Convert raw external-term dicts into tf-side dicts for the fitter.

    * Parameter names are resolved against the full fit parameter list
      ``parms`` through a single ``name -> index`` dict, i.e. O(n) total
      instead of the naive per-parameter ``np.where`` scan (O(n^2)) that
      this replaced — the latter cost ~150 s on a 108k-parameter setup
      with a 108k-parameter external term.
    * Gradients and dense Hessians are promoted to ``tf.constant`` in the
      fitter dtype.
    * Sparse Hessians are wrapped in a :class:`CSRSparseMatrix` view so the
      fitter can use fast ``sm.matmul``.

    Parameters
    ----------
    terms : list[dict]
        Raw per-term dicts as returned by :func:`read_external_terms_from_h5`.
    parms : array-like of str
        Full ordered list of fit parameter names (POIs + systematics).
    dtype : tf.DType
        Fitter dtype for gradient / Hessian tensors.

    Returns
    -------
    list[dict]
        One entry per term with keys ``name``, ``indices``, ``grad``,
        ``hess_dense``, ``hess_csr``. Empty if ``terms`` is empty.

    Raises
    ------
    RuntimeError
        If fit parameter names are not unique, or a term references a
        parameter that is not in ``parms``.
    """
    all_names = np.asarray(parms).astype(str)
    name_to_idx = {name: i for i, name in enumerate(all_names)}
    if len(name_to_idx) != len(all_names):
        raise RuntimeError(
            "Duplicate parameter names in fitter parameter list; "
            "external term resolution requires unique names."
        )

    built = []
    for term in terms:
        term_names = np.asarray(term["params"]).astype(str)
        idx_list = []
        for pname in term_names:
            resolved = name_to_idx.get(pname, -1)
            if resolved < 0:
                raise RuntimeError(
                    f"External likelihood term '{term['name']}' parameter "
                    f"'{pname}' not found in fit parameters"
                )
            idx_list.append(resolved)
        tf_indices = tf.constant(
            np.asarray(idx_list, dtype=np.int64), dtype=tf.int64
        )

        grad_values = term["grad_values"]
        tf_grad = None if grad_values is None else tf.constant(grad_values, dtype=dtype)

        tf_hess_dense = None
        tf_hess_csr = None
        if term["hess_dense"] is not None:
            tf_hess_dense = tf.constant(term["hess_dense"], dtype=dtype)
        elif term["hess_sparse"] is not None:
            # Wrap the stored sparse Hessian in a CSRSparseMatrix view for
            # the closed-form external gradient/HVP path via sm.matmul.
            # The Hessian is assumed symmetric, so the loss
            # L = 0.5 x_sub^T H x_sub has gradient H @ x_sub and HVP
            # H @ p_sub, each a single sm.matmul call. NOTE:
            # SparseMatrixMatMul has no XLA kernel, so any tf.function
            # that calls sm.matmul must be built with jit_compile=False.
            # The TensorWriter sorts the indices into canonical row-major
            # order at write time, so the SparseTensor can be fed to the
            # CSR builder directly without an extra reorder step.
            tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(term["hess_sparse"])

        built.append(
            {
                "name": term["name"],
                "indices": tf_indices,
                "grad": tf_grad,
                "hess_dense": tf_hess_dense,
                "hess_csr": tf_hess_csr,
            }
        )
    return built


def compute_external_nll(terms, x, dtype):
    """Evaluate the scalar NLL contribution of a list of external terms.

    Each term contributes ``g^T x_sub + 0.5 * x_sub^T H x_sub``, where
    ``x_sub`` is the gather of ``x`` at the term's resolved indices.
    Sparse-Hessian terms compute ``H @ x_sub`` with ``sm.matmul``, which
    dispatches to a multi-threaded CSR kernel and is much faster per call
    than the earlier element-wise gather-based form. Because the autodiff
    gradient and HVP of ``0.5 x^T H x`` through ``sm.matmul`` are
    themselves single ``sm.matmul`` calls, reverse-over-reverse autodiff
    no longer rematerializes a 2D gather/scatter chain in the second-order
    tape — that was the dominant cost on large external-Hessian problems
    (e.g. jpsi: 329M-nnz prefit Hessian).

    Parameters
    ----------
    terms : list[dict]
        tf-side per-term dicts as returned by :func:`build_tf_external_terms`.
    x : tf.Tensor
        Current full parameter vector.
    dtype : tf.DType
        Dtype for the accumulator.

    Returns
    -------
    tf.Tensor or None
        Scalar contribution to the NLL, or ``None`` if ``terms`` is empty.
    """
    if not terms:
        return None

    nll = tf.zeros([], dtype=dtype)
    for term in terms:
        x_sub = tf.gather(x, term["indices"])

        grad = term["grad"]
        if grad is not None:
            # linear piece: g^T x_sub
            nll = nll + tf.reduce_sum(grad * x_sub)

        hess_dense = term["hess_dense"]
        if hess_dense is not None:
            # quadratic piece: 0.5 * x_sub^T H x_sub (dense matvec)
            hx = tf.linalg.matvec(hess_dense, x_sub)
            nll = nll + 0.5 * tf.reduce_sum(x_sub * hx)
        elif term["hess_csr"] is not None:
            # quadratic piece via CSR matvec (H symmetric)
            hx_col = tf_sparse_csr.matmul(term["hess_csr"], x_sub[:, None])
            nll = nll + 0.5 * tf.reduce_sum(x_sub * tf.squeeze(hx_col, axis=-1))
    return nll
Loading
Loading