Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4a03bf9
First version of regularization
davidwalter2 Feb 3, 2026
bf4221e
First implementation of curvature scan for regularization
davidwalter2 Feb 5, 2026
fd228b3
Add 'earlyStopping' feature to stop minimization if no reduction afte…
davidwalter2 Feb 5, 2026
a1a6f4d
Implement curvature scan and support for plotting it
davidwalter2 Feb 5, 2026
93f069d
Add flag to ensure numerical reproducibility
davidwalter2 Feb 5, 2026
1ecfbfa
Fix early stopping functionality
davidwalter2 Feb 5, 2026
fdc931d
work on lcurve optimization
davidwalter2 Feb 5, 2026
11fecdb
Add regularization test in CI; improve 'epoch' plotting script to be …
davidwalter2 Feb 6, 2026
bfa6c00
Few smaller fixes
davidwalter2 Feb 10, 2026
2995dee
Remove kaleido
davidwalter2 Feb 16, 2026
a000171
Fix betavariations in case of flow bins
davidwalter2 Feb 16, 2026
6e80a90
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Feb 16, 2026
b32e538
Run new version of black
davidwalter2 Feb 16, 2026
5041628
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Feb 16, 2026
01ebf6a
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Mar 2, 2026
c801750
Update parser descriptions
davidwalter2 Mar 7, 2026
c532558
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Mar 14, 2026
67c26f3
Putting safeguards for BB full
davidwalter2 Mar 15, 2026
04c3e88
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Mar 17, 2026
4f2436c
Merge branch '260202_regularization' of github.com:davidwalter2/rabbi…
davidwalter2 Mar 17, 2026
a8a32a0
Merge branch '260315_fixBB' of github.com:davidwalter2/rabbit into 26…
davidwalter2 Mar 17, 2026
441f026
Add warning for BB full with gamma
davidwalter2 Mar 17, 2026
f3b456b
Generalize data variances in chisq fit
davidwalter2 Mar 20, 2026
4b07c20
Merge branch '260320_data_variances_chisqfit' of github.com:davidwalt…
davidwalter2 Mar 20, 2026
a4abbce
Fix command line argument
davidwalter2 Mar 23, 2026
f6886c8
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Mar 26, 2026
9cf7ec7
Fix CI
davidwalter2 Mar 26, 2026
6b08932
Merge branch 'main' of github.com:WMass/rabbit into 260202_regulariza…
davidwalter2 Mar 30, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,33 @@ jobs:
--extraTextLoc '0.05' '0.7' --legCols 1 -m Project ch1 a -m Project ch1 b --yscale '1.2'
--subtitle "Work in progress" --config tests/style_config.py

regularization:
runs-on: [self-hosted, linux, x64]
needs: [setenv, make-tensor]
steps:
- env:
RABBIT_OUTDIR: ${{ needs.setenv.outputs.RABBIT_OUTDIR }}
PYTHONPATH: ${{ needs.setenv.outputs.PYTHONPATH }}
PATH: ${{ needs.setenv.outputs.PATH }}
WEB_DIR: ${{ needs.setenv.outputs.WEB_DIR }}
PLOT_DIR: ${{ needs.setenv.outputs.PLOT_DIR }}
run: |
echo "RABBIT_OUTDIR=${RABBIT_OUTDIR}" >> $GITHUB_ENV
echo "PYTHONPATH=${PYTHONPATH}" >> $GITHUB_ENV
echo "PATH=${PATH}" >> $GITHUB_ENV
echo "WEB_DIR=${WEB_DIR}" >> $GITHUB_ENV
echo "PLOT_DIR=${PLOT_DIR}" >> $GITHUB_ENV

- name: lcurve scan
run: >-
rabbit_fit.py $RABBIT_OUTDIR/test_tensor.hdf5 -o $RABBIT_OUTDIR/ --postfix lcurve_scan
-t 0 --unblind -r SVD Select 'ch0_masked' -r SVD Select 'ch0_masked' --lCurveScan --earlyStopping 15

- name: plot lcurve
run: >-
python tests/plot_epoch_loss_time.py $RABBIT_OUTDIR/fitresults_lcurve_scan.hdf5 -o $WEB_DIR/$PLOT_DIR
--types lcurve --title Experiment --subtitle 'Work in progress'

bsm:
runs-on: [self-hosted, linux, x64]
needs: [setenv, make-tensor]
Expand Down Expand Up @@ -518,7 +545,7 @@ jobs:

copy-clean:
runs-on: [self-hosted, linux, x64]
needs: [setenv, symmerizations, sparse-fits, alternative-fits, bsm, plotting, likelihoodscans]
needs: [setenv, symmerizations, sparse-fits, alternative-fits, regularization, bsm, plotting, likelihoodscans]
if: always()
steps:
- env:
Expand Down
55 changes: 52 additions & 3 deletions bin/rabbit_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from rabbit.mappings import project
from rabbit.poi_models import helpers as ph
from rabbit.poi_models import poi_model
from rabbit.regularization import helpers as rh
from rabbit.regularization.lcurve import l_curve_optimize_tau, l_curve_scan_tau
from rabbit.tfhelpers import edmval_cov

from wums import output_tools, logging # isort: skip
Expand Down Expand Up @@ -165,6 +167,12 @@ def make_parser():
type=str,
help="Specify result from external postfit file",
)
parser.add_argument(
"--noFit",
default=False,
action="store_true",
help="Do not perform the minimization.",
)
parser.add_argument(
"--noPostfitProfileBB",
default=False,
Expand Down Expand Up @@ -204,6 +212,24 @@ def make_parser():
action="store_true",
help="compute impacts of frozen (non-profiled) systematics",
)
parser.add_argument(
"--lCurveScan",
default=False,
action="store_true",
help="For use with regularization, scan the L curve versus values for tau",
)
parser.add_argument(
"--lCurveOptimize",
default=False,
action="store_true",
help="For use with regularization, find the value of tau that maximizes the curvature",
)
parser.add_argument(
"--regularizationStrength",
default=0.0,
type=float,
help="For use with regularization, set the regularization strength (tau)",
)

return parser

Expand Down Expand Up @@ -352,7 +378,21 @@ def fit(args, fitter, ws, dofit=True):
edmval = None

if args.externalPostfit is not None:
fitter.load_fitresult(args.externalPostfit, args.externalPostfitResult)
fitter.load_fitresult(
args.externalPostfit,
args.externalPostfitResult,
profile=not args.noPostfitProfileBB,
)

if args.lCurveScan:
tau_values, l_curve_values = l_curve_scan_tau(fitter)
ws.add_1D_integer_hist(tau_values, "step", "tau")
ws.add_1D_integer_hist(l_curve_values, "step", "lcurve")

if args.lCurveOptimize:
best_tau, max_curvature = l_curve_optimize_tau(fitter)
ws.add_1D_integer_hist([best_tau], "best", "tau")
ws.add_1D_integer_hist([max_curvature], "best", "lcurve")

if dofit:
cb = fitter.minimize()
Expand All @@ -364,7 +404,8 @@ def fit(args, fitter, ws, dofit=True):
fitter._profile_beta()

if cb is not None:
ws.add_loss_time_hist(cb.loss_history, cb.time_history)
ws.add_1D_integer_hist(cb.loss_history, "epoch", "loss")
ws.add_1D_integer_hist(cb.time_history, "epoch", "time")

if not args.noHessian:
# compute the covariance matrix and estimated distance to minimum
Expand Down Expand Up @@ -558,6 +599,14 @@ def main():
mp.CompositeMapping(mappings),
]

ifitter.tau.assign(args.regularizationStrength)
regularizers = []
for margs in args.regularization:
mapping = mh.load_mapping(margs[1], indata, *margs[2:])
regularizer = rh.load_regularizer(margs[0], mapping, dtype=indata.dtype)
regularizers.append(regularizer)
ifitter.regularizers = regularizers

np.random.seed(args.seed)
tf.random.set_seed(args.seed)

Expand Down Expand Up @@ -641,7 +690,7 @@ def main():

if not args.prefitOnly:
ifitter.set_blinding_offsets(blind=blinded_fits[i])
fit(args, ifitter, ws, dofit=ifit >= 0)
fit(args, ifitter, ws, dofit=ifit >= 0 and not args.noFit)
fit_time.append(time.time())

if args.saveHists:
Expand Down
57 changes: 47 additions & 10 deletions rabbit/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ def init_fit_parms(
name="cov",
)

# regularization
self.regularizers = []
# one common regularization strength parameter
self.tau = tf.Variable(1.0, trainable=True, name="tau", dtype=tf.float64)

# constraint minima for nuisance parameters
self.theta0 = tf.Variable(
self.theta0default,
Expand Down Expand Up @@ -372,7 +377,7 @@ def __deepcopy__(self, memo):
setattr(obj, k, copy.deepcopy(v, memo))
return obj

def load_fitresult(self, fitresult_file, fitresult_key):
def load_fitresult(self, fitresult_file, fitresult_key, profile=True):
# load results from external fit and set postfit value and covariance elements for common parameters
cov_ext = None
with h5py.File(fitresult_file, "r") as fext:
Expand Down Expand Up @@ -408,6 +413,9 @@ def load_fitresult(self, fitresult_file, fitresult_key):
covval[np.ix_(idxs, idxs)] = cov_ext[np.ix_(idxs_ext, idxs_ext)]
self.cov.assign(tf.constant(covval))

if profile:
self._profile_beta()

def update_frozen_params(self):
logger.debug(f"Updated list of frozen params: {self.frozen_params}")
new_mask_np = np.isin(self.parms, self.frozen_params)
Expand All @@ -417,12 +425,14 @@ def update_frozen_params(self):
self.floating_indices = np.where(~self.frozen_params_mask)[0]

def freeze_params(self, frozen_parmeter_expressions):
    """Freeze all parameters whose names match the given regex expressions.

    The matched names are appended to ``self.frozen_params`` and the
    derived masks/indices are refreshed via ``update_frozen_params``.
    """
    logger.debug(f"Freeze params with {frozen_parmeter_expressions}")
    matched = match_regexp_params(frozen_parmeter_expressions, self.parms)
    self.frozen_params.extend(matched)
    self.update_frozen_params()

def defreeze_params(self, unfrozen_parmeter_expressions):
logger.debug(f"Defreeze params with {unfrozen_parmeter_expressions}")
unfrozen_parmeter = match_regexp_params(
unfrozen_parmeter_expressions, self.parms
)
Expand All @@ -432,6 +442,7 @@ def defreeze_params(self, unfrozen_parmeter_expressions):
self.update_frozen_params()

def init_blinding_values(self, unblind_parameter_expressions=[]):
logger.debug(f"Unblind parameters with {unblind_parameter_expressions}")
unblind_parameters = match_regexp_params(
unblind_parameter_expressions,
[
Expand Down Expand Up @@ -524,6 +535,9 @@ def get_poi(self):
else:
return poi

def get_x(self):
    """Return the full parameter vector: POIs followed by nuisance parameters."""
    pois = self.get_poi()
    thetas = self.get_theta()
    return tf.concat([pois, thetas], axis=0)

def _default_beta0(self):
if self.binByBinStatType in ["gamma", "normal-multiplicative"]:
return tf.ones(self.beta_shape, dtype=self.indata.dtype)
Expand Down Expand Up @@ -609,6 +623,11 @@ def defaultassign(self):
if self.do_blinding:
self.set_blinding_offsets(False)

xinit = self.get_x()
nexp0 = self.expected_yield(full=True)
for reg in self.regularizers:
reg.set_expectations(xinit, nexp0)

def bayesassign(self):
# FIXME use theta0 as the mean and constraintweight to scale the width
if self.poi_model.npoi == 0:
Expand Down Expand Up @@ -2011,13 +2030,7 @@ def _compute_lbeta(self, beta, full_nll=False):

return None

def _compute_nll_components(self, profile=True, full_nll=False):
nexp, _, beta = self._compute_yields_with_beta(
profile=profile,
compute_norm=False,
full=False,
)

def _compute_ln(self, nexp, full_nll=False):
if self.chisqFit:
ln = 0.5 * tf.reduce_sum((nexp - self.nobs) ** 2 / self.varnobs, axis=-1)
elif self.covarianceFit:
Expand Down Expand Up @@ -2045,22 +2058,46 @@ def _compute_nll_components(self, profile=True, full_nll=False):
ln = tf.reduce_sum(
-self.nobs * (lognexp - self.lognobs) + nexp - self.nobs, axis=-1
)
return ln

def _compute_nll_components(self, profile=True, full_nll=False):
    """Compute the individual terms of the negative log-likelihood.

    Returns a tuple ``(ln, lc, lbeta, lpenalty, beta)`` where ``ln`` is the
    data term, ``lc`` the constraint term, ``lbeta`` the bin-by-bin stat
    term (or None), ``lpenalty`` the summed regularization penalty (or None
    when no regularizers are configured), and ``beta`` the bin-by-bin
    parameters used.
    """
    # NOTE(review): `full` is passed the regularizer count, relying on int
    # truthiness to request the full (incl. masked-channel) yields only when
    # regularizers exist — bool(...) would be clearer; confirm intent.
    nexpfullcentral, _, beta = self._compute_yields_with_beta(
        profile=profile,
        compute_norm=False,
        full=len(self.regularizers),
    )

    # Only the first `nbins` entries correspond to observed bins; the
    # remainder (if any) are the extra bins of the full expectation.
    nexp = nexpfullcentral[: self.indata.nbins]

    ln = self._compute_ln(nexp, full_nll)

    lc = self._compute_lc(full_nll)

    lbeta = self._compute_lbeta(beta, full_nll)

    if len(self.regularizers):
        x = self.get_x()
        # Each penalty is scaled by exp(2*tau) so the common strength
        # parameter tau acts multiplicatively and stays positive.
        penalties = [
            reg.compute_nll_penalty(x, nexpfullcentral) * tf.exp(2 * self.tau)
            for reg in self.regularizers
        ]
        lpenalty = tf.add_n(penalties)
    else:
        lpenalty = None

    return ln, lc, lbeta, lpenalty, beta

def _compute_nll(self, profile=True, full_nll=False):
    """Return the total negative log-likelihood.

    Sums the data and constraint terms, plus the optional bin-by-bin
    stat term and regularization penalty when present.
    """
    ln, lc, lbeta, lpenalty, _beta = self._compute_nll_components(
        profile=profile, full_nll=full_nll
    )
    total = ln + lc
    for optional_term in (lbeta, lpenalty):
        if optional_term is not None:
            total = total + optional_term
    return total

def _compute_loss(self, profile=True):
Expand Down
13 changes: 13 additions & 0 deletions rabbit/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,19 @@ def common_parser():
action="store_true",
help="Make a composite mapping and compute the covariance matrix across all mappings.",
)
parser.add_argument(
"-r",
"--regularization",
nargs="+",
action="append",
default=[],
help="""
apply regularization on the output "nout" of a mapping by including a penalty term P(nout) in the -log(L) of the minimization.
As argument, specify the regularization defined in rabbit/regularization/, followed by a mapping using the same syntax as discussed above.
e.g. '-r SVD Select ch0_masked' to apply SVD regularization on the channel 'ch0_masked' or '-r SVD Project ch0 pt' for the 1D projection to pt.
Custom regularization can be specified with the full path e.g. '-r custom_regularization.MyCustomRegularization Project ch0 pt'.
""",
)

return parser

Expand Down
Empty file.
13 changes: 13 additions & 0 deletions rabbit/regularization/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from rabbit import common

# dictionary with class name and the corresponding filename where it is defined
baseline_regularizations = {
"SVD": "svd",
}


def load_regularizer(class_name, *args, **kwargs):
    """Instantiate a regularizer by class name.

    ``class_name`` is looked up in ``baseline_regularizations`` (or loaded
    from a custom module path); the resolved class is constructed with the
    remaining positional and keyword arguments.
    """
    regularizer_cls = common.load_class_from_module(
        class_name, baseline_regularizations, base_dir="rabbit.regularization"
    )
    return regularizer_cls(*args, **kwargs)
Loading