Skip to content
Merged
56 changes: 55 additions & 1 deletion bin/rabbit_fit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

import copy

import tensorflow as tf

tf.config.experimental.enable_op_determinism()
Expand All @@ -12,7 +14,9 @@
from rabbit import fitter, inputdata, parsing, workspace
from rabbit.mappings import helpers as mh
from rabbit.mappings import mapping as mp
from rabbit.mappings import project
from rabbit.poi_models import helpers as ph
from rabbit.poi_models import poi_model
from rabbit.tfhelpers import edmval_cov

from wums import output_tools, logging # isort: skip
Expand Down Expand Up @@ -137,6 +141,12 @@ def make_parser():
action="store_true",
help="save postfit histograms with each noi varied up to down",
)
parser.add_argument(
"--computeSaturatedProjectionTests",
default=False,
action="store_true",
help="Compute the saturated likelihood test for Project mappings",
)
parser.add_argument(
"--noChi2",
default=False,
Expand Down Expand Up @@ -248,7 +258,51 @@ def save_hists(args, mappings, fitter, ws, prefit=True, profile=False):
)

if aux[-2] is not None:
ws.add_chi2(aux[-2], aux[-1], prefit, mapping)
# The last two entries of `aux` are used as the chi2 value and its number of
# degrees of freedom; cast them to plain Python scalars before reporting.
chi2val = float(aux[-2])
ndf = int(aux[-1])
# NOTE(review): `chi2` is presumably scipy.stats.chi2, whose `sf` is the
# survival function (upper-tail probability) — confirm against the imports.
p_val = chi2.sf(chi2val, ndf)

logger.info("Linear chi2:")
logger.info(f" ndof: {ndf}")
# Report numerator and denominator so the "chi2/ndf" label matches the value;
# previously only the rounded chi2 was printed under this label.
logger.info(f" chi2/ndf = {round(chi2val, 2)}/{ndf}")
logger.info(rf" p-value: {round(p_val * 100, 2)}%")

ws.add_chi2(chi2val, ndf, prefit, mapping)

# Saturated likelihood test, run only postfit, only when requested on the
# command line, and only for mappings whose type is exactly project.Project
# (exact-type match kept from the original; subclasses are not included).
if (
    not prefit
    and args.computeSaturatedProjectionTests
    and type(mapping) is project.Project
):
    # Extend the nominal POI model with the saturated model built from the
    # mapping's channel info.
    saturated_model = poi_model.SaturatedProjectModel(
        fitter.indata, mapping.channel_info
    )
    composite_model = poi_model.CompositePOIModel(
        [fitter.poi_model, saturated_model]
    )

    # Re-initialise and minimise a deep copy, so the parameter
    # re-initialisation below is applied to the copy rather than `fitter`.
    fitter_saturated = copy.deepcopy(fitter)
    fitter_saturated.init_fit_parms(
        composite_model,
        args.setConstraintMinimum,
        unblind=args.unblind,
        freeze_parameters=args.freezeParameters,
    )
    # The return value of minimize() was never used, so it is not bound.
    fitter_saturated.minimize()
    nllvalreduced = fitter_saturated.reduced_nll().numpy()

    # Test statistic: twice the difference between the reduced NLL stored in
    # the workspace results and the one from the saturated fit, with the
    # number of saturated-model POIs as degrees of freedom.
    ndf = saturated_model.npoi
    chi2val = 2.0 * (ws.results["nllvalreduced"] - nllvalreduced)
    p_val = chi2.sf(chi2val, ndf)

    logger.info("Saturated chi2:")
    logger.info(f" ndof: {ndf}")
    logger.info(f" 2*deltaNLL: {round(chi2val, 2)}")
    logger.info(rf" p-value: {round(p_val * 100, 2)}%")

    ws.add_chi2(chi2val, ndf, prefit, mapping, saturated=True)

if args.saveHistsPerProcess and not mapping.skip_per_process:
logger.info(f"Save processes histogram for {mapping.key}")
Expand Down
78 changes: 56 additions & 22 deletions bin/rabbit_plot_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def make_parser():
default=["charge", "passIso", "passMT", "cosThetaStarll", "qGen"],
help="List of axes where for each bin a separate plot is created",
)
parser.add_argument(
"--dataCovariance",
action="store_true",
help="Use covariance information to plot the data uncertainty",
)
return parser


Expand All @@ -65,7 +70,7 @@ def plot_matrix(
matrix,
args,
channel=None,
axes=None,
axes_names=None,
cmap="coolwarm",
config={},
meta=None,
Expand Down Expand Up @@ -107,8 +112,8 @@ def plot_matrix(
**opts,
)
if ticklabels is None:
xlabel = plot_tools.get_axis_label(config, axes, args.xlabel, is_bin=True)
ylabel = plot_tools.get_axis_label(config, axes, args.ylabel, is_bin=True)
xlabel = plot_tools.get_axis_label(config, axes_names, args.xlabel, is_bin=True)
ylabel = plot_tools.get_axis_label(config, axes_names, args.ylabel, is_bin=True)

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
Expand All @@ -124,11 +129,14 @@ def plot_matrix(
)

to_join = [f"hist_{'corr' if args.correlation else 'cov'}"]
to_join.append("prefit" if args.prefit else "postfit")
if args.dataCovariance:
to_join.append("data")
else:
to_join.append("prefit" if args.prefit else "postfit")
if channel is not None:
to_join.append(channel)
if axes is not None:
to_join.append("_".join(axes))
if axes_names is not None:
to_join.append("_".join(axes_names))
if suffix is not None:
to_join.append(suffix)
to_join = [*to_join, args.postfix]
Expand Down Expand Up @@ -177,7 +185,7 @@ def main():

if args.params is not None:
h_cov = fitresult["cov"].get()
axes = h_cov.axes.name
axes_names = h_cov.axes.name

if len(args.params) > 0:
h_param = fitresult["parms"].get()
Expand All @@ -203,32 +211,34 @@ def main():
outdir,
h_cov,
args,
axes=axes,
axes_names=axes_names,
config=config,
meta=meta,
suffix="params",
ticklabels=ticklabels,
)

hist_cov_key = f"hist_{'prefit' if args.prefit else 'postfit'}_inclusive_cov"
if args.dataCovariance:
hist_cov_key = "cov_data_obs"
else:
hist_cov_key = f"hist_{'prefit' if args.prefit else 'postfit'}_inclusive_cov"

results = fitresult.get("mappings", fitresult.get("physics_models"))
for margs in args.mapping:

if margs == []:
instance_keys = results.keys()
else:
mapping_key = " ".join(margs)
instance_keys = [k for k in results.keys() if k.startswith(mapping_key)]
if len(instance_keys) == 0:
raise ValueError(
f"No mapping found under {mapping_key}, available mappings are {results.keys()}"
f"No mapping found under {mapping_key}; available mappings are {results.keys()}"
)

for instance_key in instance_keys:
instance = results[instance_key]

h_cov = instance[hist_cov_key].get()

suffix = (
instance_key.replace(" ", "_")
.replace(".", "p")
Expand All @@ -238,20 +248,43 @@ def main():
.replace(")", "")
)

plot_matrix(
outdir,
h_cov,
args,
config=config,
meta=meta,
suffix=suffix,
)
if hist_cov_key in instance.keys():
h_cov = instance[hist_cov_key].get()
plot_matrix(
outdir,
h_cov,
args,
config=config,
meta=meta,
suffix=suffix,
)
else:
h_cov = None

start = 0
for channel, info in instance["channels"].items():
channel_hist = info[f"hist_postfit_inclusive"].get()
channel_hist = info["hist_postfit_inclusive"].get()
axes = [a for a in channel_hist.axes]
if len(instance.get("channels", {}).keys()) > 1:
axes_names = channel_hist.axes.name

if h_cov is None:
# Fail early with the list of available keys when the requested covariance
# entry is missing for this channel.
if hist_cov_key not in info.keys():
    raise ValueError(
        f"No key {hist_cov_key}; available keys are {info.keys()}"
    )
h_cov = info[hist_cov_key].get()

plot_matrix(
outdir,
h_cov,
args,
channel=channel,
config=config,
meta=meta,
suffix=suffix,
axes_names=axes_names,
)
elif len(instance.get("channels", {}).keys()) > 1:
# plot covariance matrix in each channel
nbins = np.prod(channel_hist.shape)
stop = int(start + nbins)
Expand All @@ -268,6 +301,7 @@ def main():
config=config,
meta=meta,
suffix=suffix,
axes_names=axes_names,
)
else:
h_cov_channel = h_cov
Expand Down
Loading