Commits (18)
57e08eb
add option to treat input systematic histograms as difference with re…
bendavid Apr 6, 2026
c80b409
add test for sparse mode
bendavid Apr 6, 2026
0e17ff6
Support scipy sparse array inputs in TensorWriter and add as_differen…
bendavid Apr 6, 2026
91ce1b3
Add multi-systematic dispatch in add_systematic and use wums.SparseHist
bendavid Apr 7, 2026
dab00bc
Add external likelihood term (gradient + hessian) support
bendavid Apr 7, 2026
47f1f90
Add efficient SparseHist multi-systematic dispatch in TensorWriter
bendavid Apr 7, 2026
08537b8
Speed up TensorWriter for large multi-systematic SparseHist workloads
bendavid Apr 7, 2026
f51e53f
inputdata, parsing: prep for sparse fast path with CSR matvec
bendavid Apr 9, 2026
83afbd8
fitter: dynamic loss/grad/HVP wrappers with jit_compile + hvpMethod
bendavid Apr 9, 2026
b6d7120
fitter: sparse fast path uses CSR matmul, no dense [nbins,nproc]
bendavid Apr 9, 2026
b30f867
fitter: external sparse Hessian via CSR matmul
bendavid Apr 9, 2026
3f41fc6
rabbit_fit, setup.sh: enable XLA multi-threaded Eigen on CPU
bendavid Apr 9, 2026
183f376
fitter, rabbit_fit: skip dense cov allocation under --noHessian
bendavid Apr 9, 2026
8493332
fitter: speed up Fitter.__init__ on large external sparse Hessians
bendavid Apr 9, 2026
6c1c187
unify sparse-Hessian IO path; sort at write time, drop reorder calls
bendavid Apr 9, 2026
db274db
fitter: Hessian-free CG solve for is_linear case under --noHessian
bendavid Apr 9, 2026
d0708b1
fitter, rabbit_fit: edmval + POI/NOI uncertainties under --noHessian
bendavid Apr 10, 2026
fad47bc
external_likelihood: factor out external-term IO + tf build + nll eval
bendavid Apr 10, 2026
14 changes: 14 additions & 0 deletions bin/rabbit_fit.py
@@ -1,5 +1,19 @@
#!/usr/bin/env python3

# Enable XLA's multi-threaded Eigen path on CPU before importing tensorflow.
# This must be set before any TF import (including transitive) because XLA
# parses XLA_FLAGS once during runtime initialization. Measured ~1.3x speedup
# on dense large-model HVP/loss+grad on a many-core system, no downside.
# Users who set their own XLA_FLAGS keep theirs and we append.
import os as _os

_xla_default = "--xla_cpu_multi_thread_eigen=true"
_existing = _os.environ.get("XLA_FLAGS", "")
if "xla_cpu_multi_thread_eigen" not in _existing:
_os.environ["XLA_FLAGS"] = (
f"{_existing} {_xla_default}".strip() if _existing else _xla_default
)

import copy

import tensorflow as tf
283 changes: 232 additions & 51 deletions rabbit/fitter.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions rabbit/h5pyutils_read.py
@@ -1,3 +1,4 @@
import hdf5plugin # noqa: F401 registers Blosc2/LZ4 filter used by the writer
import tensorflow as tf


48 changes: 39 additions & 9 deletions rabbit/h5pyutils_write.py
@@ -1,30 +1,53 @@
import math

import hdf5plugin
import numpy as np

# Compression strategy for the HDF5 write path.
#
# By default dense arrays are written with Blosc2 + LZ4 byte-shuffle. This is
# much faster than gzip (typically ~5x on write) while achieving equal or
# better ratios, and works well for dense tensor buffers that often contain
# lots of structural zeros from sparsity patterns.
#
# Callers that know the data is already densely packed with unstructured
# nonzero values (e.g. the ``values`` payload of an explicitly sparse tensor)
# can pass ``compress=False`` to skip compression entirely. For those inputs
# Blosc2 LZ4 buys only ~4-5% of file size at ~5x the write cost, so turning
# it off is a strict win.
#
# The HDF5 filter pipeline is fundamentally single-threaded per chunk, so
# multi-threaded compression via BLOSC2_NTHREADS does not take effect through
# h5py; the main speedup comes from switching compressor and skipping the
# uncompressible buffers.
#
# Reading requires the hdf5plugin filter to be registered, which happens
# automatically via the ``import hdf5plugin`` at module import time in both
# h5pyutils_write and h5pyutils_read.
_DEFAULT_COMPRESSION_KWARGS = hdf5plugin.Blosc2(cname="lz4", clevel=5)

def writeFlatInChunks(arr, h5group, outname, maxChunkBytes=1024**2):

def writeFlatInChunks(arr, h5group, outname, maxChunkBytes=1024**2, compress=True):
arrflat = arr.reshape(-1)

esize = np.dtype(arrflat.dtype).itemsize
nbytes = arrflat.size * esize

# special handling for empty datasets, which should not use chunked storage or compression
# Empty datasets must not use chunked storage or compression.
if arrflat.size == 0:
chunksize = 1
chunks = None
compression = None
extra_kwargs = {"chunks": None}
else:
chunksize = int(min(arrflat.size, max(1, math.floor(maxChunkBytes / esize))))
chunks = (chunksize,)
compression = "gzip"
extra_kwargs = {"chunks": (chunksize,)}
if compress:
extra_kwargs.update(_DEFAULT_COMPRESSION_KWARGS)

h5dset = h5group.create_dataset(
outname,
arrflat.shape,
chunks=chunks,
dtype=arrflat.dtype,
compression=compression,
**extra_kwargs,
)

# write in chunks, preserving sparsity if relevant
@@ -42,8 +65,15 @@ def writeSparse(indices, values, dense_shape, h5group, outname, maxChunkBytes=10
outgroup = h5group.create_group(outname)

nbytes = 0
# Index arrays compress extremely well (~10x for the tensor-sparse
# structures used by rabbit), so keep the default compression.
nbytes += writeFlatInChunks(indices, outgroup, "indices", maxChunkBytes)
nbytes += writeFlatInChunks(values, outgroup, "values", maxChunkBytes)
# Values of a sparse tensor are already densely packed nonzeros; real
# physics values typically give only ~4% compression gain at 5x the
# write cost, so skip compression here.
nbytes += writeFlatInChunks(
values, outgroup, "values", maxChunkBytes, compress=False
)
outgroup.attrs["dense_shape"] = np.array(dense_shape, dtype="int64")

return nbytes
73 changes: 70 additions & 3 deletions rabbit/inputdata.py
@@ -82,11 +82,25 @@ def __init__(self, filename, pseudodata=None):
self.sparse = not "hnorm" in f

if self.sparse:
print(
"WARNING: The sparse tensor implementation is experimental and probably slower than with a dense tensor!"
)
self.norm = makesparsetensor(f["hnorm_sparse"])
self.logk = makesparsetensor(f["hlogk_sparse"])
# Canonicalize index ordering once at load time. The fitter's
# sparse fast path reduces nonzero entries via row-keyed
# reductions; sorted row-major indices give coalesced memory
# access. tf.sparse.reorder sorts into row-major order.
self.norm = tf.sparse.reorder(self.norm)
self.logk = tf.sparse.reorder(self.logk)
# Pre-build a CSRSparseMatrix view of logk for use in the
# fitter's sparse matvec path via sm.matmul, which dispatches
# to a multi-threaded CSR kernel and is much faster per call
# than the equivalent gather + unsorted_segment_sum. NOTE:
# SparseMatrixMatMul has no XLA kernel, so any tf.function
# that calls sm.matmul must be built with jit_compile=False.
from tensorflow.python.ops.linalg.sparse import (
sparse_csr_matrix_ops as _tf_sparse_csr,
)

self.logk_csr = _tf_sparse_csr.CSRSparseMatrix(self.logk)
else:
self.norm = maketensor(f["hnorm"])
self.logk = maketensor(f["hlogk"])
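The CSR-matvec-versus-gather tradeoff noted in the comment above can be illustrated with scipy on toy shapes (the index pattern and values below are made up, not the fitter's tensors): a CSR matvec computes exactly the same row-keyed reduction as the explicit gather + unsorted segment-sum, in one kernel call.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy stand-in for a [nbins, nproc] nonzero pattern (hypothetical values).
rows = np.array([0, 0, 2, 3])
cols = np.array([1, 3, 0, 2])
vals = np.array([0.5, -1.0, 2.0, 0.25])
shape = (4, 4)

logk = csr_matrix((vals, (rows, cols)), shape=shape)
theta = np.array([1.0, 2.0, 3.0, 4.0])

# Fast path: a single CSR matvec.
out_csr = logk @ theta

# Equivalent gather + row-keyed segment-sum formulation it replaces.
gathered = vals * theta[cols]       # gather per nonzero
out_seg = np.zeros(shape[0])
np.add.at(out_seg, rows, gathered)  # unsorted segment-sum keyed on row

assert np.allclose(out_csr, out_seg)
```

With sorted row-major indices (the point of the tf.sparse.reorder at load time), the per-row reduction touches memory contiguously, which is what makes the single-kernel form fast.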
@@ -182,6 +196,59 @@ def __init__(self, filename, pseudodata=None):

self.axis_procs = hist.axis.StrCategory(self.procs, name="processes")

# Load external likelihood terms (optional).
[Review comment — Collaborator]
Could also maybe go into a standalone function, but not sure about that

[Reply — Collaborator Author]
Yah probably. The init of this is already not very modular so maybe can leave that for subsequent refactoring.

[Reply — Collaborator Author]
Ok actually given the similar pattern in the fitter init I agree this can be modularized.

[Reply — Collaborator Author]
Done in fad47bc -- moved into rabbit.external_likelihood.read_external_terms_from_h5(ext_group). FitInputData calls self.external_terms = read_external_terms_from_h5(f.get("external_terms")).

# Each entry is a dict with keys:
# name: str
# params: 1D ndarray of parameter name strings
# grad_values: 1D float ndarray or None
# hess_dense: 2D float ndarray or None
# hess_sparse: tuple (rows, cols, values) or None
self.external_terms = []
if "external_terms" in f.keys():
names = [
s.decode() if isinstance(s, bytes) else s
for s in f["hexternal_term_names"][...]
]
ext_group = f["external_terms"]
for tname in names:
[Review comment — Collaborator]
Does the following work?

for tname, tg in ext_group.items():

It would be more pythonic IMO and no need for storing the names separately

[Reply — Collaborator Author]
Yes I think that should work

[Reply — Collaborator Author]
Done in fad47bc -- the reader now uses for tname, tg in ext_group.items() (in rabbit.external_likelihood.read_external_terms_from_h5) and the writer-side names list is gone.

tg = ext_group[tname]
raw_params = tg["params"][...]
params = np.array(
[s.decode() if isinstance(s, bytes) else s for s in raw_params]
)
grad_values = (
np.asarray(maketensor(tg["grad_values"]))
if "grad_values" in tg.keys()
else None
)
hess_dense = (
np.asarray(maketensor(tg["hess_dense"]))
if "hess_dense" in tg.keys()
else None
)
hess_sparse = None
if "hess_sparse" in tg.keys():
hg = tg["hess_sparse"]
idx_dset = hg["indices"]
if "original_shape" in idx_dset.attrs:
idx_shape = tuple(idx_dset.attrs["original_shape"])
indices = np.asarray(idx_dset).reshape(idx_shape)
else:
indices = np.asarray(idx_dset)
rows = indices[:, 0]
cols = indices[:, 1]
vals = np.asarray(hg["values"])
hess_sparse = (rows, cols, vals)
self.external_terms.append(
{
"name": tname,
"params": params,
"grad_values": grad_values,
"hess_dense": hess_dense,
"hess_sparse": hess_sparse,
}
)
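For orientation, here is a toy external term following the dict schema documented above, plus a densification of its (rows, cols, values) Hessian triplet; every name and number is invented for illustration.

```python
import numpy as np

# Toy external term matching the schema (all values hypothetical).
term = {
    "name": "aux_constraint",
    "params": np.array(["mu_signal", "norm_bkg"]),
    "grad_values": np.array([0.1, -0.2]),
    "hess_dense": None,
    "hess_sparse": (
        np.array([0, 0, 1]),        # rows
        np.array([0, 1, 1]),        # cols
        np.array([4.0, 0.5, 2.0]),  # values
    ),
}

# Densify the sparse Hessian triplet for inspection.
rows, cols, vals = term["hess_sparse"]
n = len(term["params"])
hess = np.zeros((n, n))
hess[rows, cols] = vals
```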

@tf.function
def expected_events_nominal(self):
rnorm = tf.ones(self.nproc, dtype=self.dtype)
18 changes: 18 additions & 0 deletions rabbit/parsing.py
@@ -202,6 +202,24 @@ def common_parser():
],
help="Minimizer method used in scipy.optimize.minimize for the nominal fit minimization",
)
parser.add_argument(
"--hvpMethod",
default="revrev",
type=str,
choices=["fwdrev", "revrev"],
help="Autodiff mode for the Hessian-vector product. 'revrev' (reverse-over-reverse) "
"is the default and works well in combination with --jitCompile. 'fwdrev' "
"(forward-over-reverse, via tf.autodiff.ForwardAccumulator) is an alternative.",
)
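Schematically, the two --hvpMethod modes correspond to the following autodiff compositions. This is a minimal sketch with a toy loss (Hessian diag(6x)), not the fitter's actual wrappers:

```python
import numpy as np
import tensorflow as tf


def loss(x):
    # Toy loss; its Hessian is diag(6 * x).
    return tf.reduce_sum(x**3)


x = tf.Variable([1.0, 2.0])
v = tf.constant([1.0, 1.0])

# revrev: reverse-over-reverse, i.e. nested gradient tapes.
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        y = loss(x)
    g = inner.gradient(y, x)
    gv = tf.tensordot(g, v, axes=1)  # scalar <grad, v>
hvp_revrev = outer.gradient(gv, x)   # d<grad, v>/dx = H v

# fwdrev: forward-over-reverse via tf.autodiff.ForwardAccumulator.
with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
    with tf.GradientTape() as tape:
        y = loss(x)
    g = tape.gradient(y, x)
hvp_fwdrev = acc.jvp(g)              # JVP of the gradient = H v
```

Both variants evaluate H v without ever materializing the Hessian, which is what makes them usable in the Hessian-free paths added in this PR.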
parser.add_argument(
"--noJitCompile",
[Review comment — Collaborator]
How does the option interplay with --eager? Do we need both options or can this also not be controlled with --eager? At a minimum --eager should also trigger no jit compile

[Reply — Collaborator Author]
If --eager is used I think it will implicitly skip jitCompile because it won't even process the tf.functions.

These are two different things though: tf.function by itself switches from eager to graph mode. jit_compile recompiles the instructions within the graph (allowing things like add-multiply fusion, like what a C++ compiler would do). Otherwise the graph internally still executes the tf operations one by one as written.

So yes we need both options and I think the current behaviour should be fine.

[Reply — Collaborator Author]
No code change here. --eager and --jitCompile control different things (eager mode skips graph building entirely, jit_compile is XLA fusion within an existing graph) so they remain orthogonal. --eager continues to bypass jit by virtue of skipping the tf.function wrappers.

dest="jitCompile",
[Review comment — Collaborator]
Using the dest keyword can be confusing and so far we managed to do without it, could we keep that convention?

[Reply — Collaborator Author]
Yah ok let me see if that can be avoided

[Reply — Collaborator Author]
Done in fad47bc -- replaced --noJitCompile with a tri-state --jitCompile {auto,on,off} (default auto), no dest keyword needed.

default=True,
action="store_false",
help="Disable XLA jit_compile=True on the loss/gradient/HVP tf.functions. "
"jit_compile is enabled by default and substantially speeds up sparse-mode fits "
"with very large numbers of parameters.",
)
parser.add_argument(
"--chisqFit",
default=False,