Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bin/rabbit_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ def main():
"procs": ifitter.indata.procs,
"pois": ifitter.poi_model.pois,
"nois": ifitter.parms[ifitter.poi_model.npoi :][indata.noiidxs],
# Persist nuisance/group bookkeeping directly in fitresults
# so downstream plotting can reconstruct groupings without the input card.
"systs": ifitter.indata.systs,
"systgroups": ifitter.indata.systgroups,
"systgroupidxs": ifitter.indata.systgroupidxs,
"systsnoconstraint": ifitter.indata.systsnoconstraint,
}

with workspace.Workspace(
Expand Down
141 changes: 138 additions & 3 deletions bin/rabbit_plot_hists.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,38 @@
logger = None


def _decode_str(x):
return x.decode() if isinstance(x, (bytes, np.bytes_)) else str(x)


def _build_group_to_nuisance_map(meta):
# Prefer mapping persisted directly in fitresult meta, fallback to nested input meta.
systs = meta.get("systs")
groups = meta.get("systgroups")
groupidxs = meta.get("systgroupidxs")

if systs is None or groups is None or groupidxs is None:
nested = meta.get("meta_info_input", {})
systs = nested.get("systs")
groups = nested.get("systgroups")
groupidxs = nested.get("systgroupidxs")

if systs is None or groups is None or groupidxs is None:
return {}

systs = [_decode_str(x) for x in np.array(systs)]
groups = [_decode_str(x) for x in np.array(groups)]

out = {}
for g, idxs in zip(groups, np.array(groupidxs, dtype=object)):
nuis = []
for idx in np.array(idxs, dtype=int):
if 0 <= idx < len(systs):
nuis.append(systs[idx])
out[g] = nuis
return out


def parseArgs():

# choices for legend padding
Expand Down Expand Up @@ -331,6 +363,30 @@ def parseArgs():
Additional varNames can be specified to add variations from the nominal input.
""",
)
parser.add_argument(
"--varGroupNames",
type=str,
nargs="*",
default=None,
help=(
"Variation group names to build on-the-fly from nuisance variations "
"using fitresult meta mapping (systs/systgroups/systgroupidxs)."
),
)
parser.add_argument(
"--varGroupLabels",
type=str,
nargs="*",
default=None,
help="Label(s) for --varGroupNames.",
)
parser.add_argument(
"--varGroupColors",
type=str,
nargs="*",
default=None,
help="Color(s) for --varGroupNames.",
)
parser.add_argument(
"--varLabels",
type=str,
Expand Down Expand Up @@ -1174,6 +1230,9 @@ def make_plots(
varFilesFitTypes=None,
varMarkers=None,
varNames=None,
varGroupNames=None,
varGroupLabels=None,
varGroupColors=None,
varLabels=None,
varColors=None,
binwnorm=None,
Expand Down Expand Up @@ -1214,13 +1273,16 @@ def make_plots(
l if p not in args.suppressProcsLabel else None for l, p in zip(labels, procs)
]

if varNames is not None:
if varNames is not None or varGroupNames is not None:
# take the first variations from the varFiles, empty if no varFiles are provided
if len(varFilesFitTypes) == 1:
varFilesFitTypes = varFilesFitTypes * len(varResults)

hists_down = []
hists_up = []
plot_var_names = [] if varNames is None else list(varNames)
plot_var_labels = [] if varLabels is None else list(varLabels)
plot_var_colors = [] if varColors is None else list(varColors)
for r, t in zip(varResults, varFilesFitTypes):
h = r[f"hist_{t}_inclusive"].get()

Expand All @@ -1237,12 +1299,11 @@ def make_plots(
hists_up.append(hist_up)

# take the next variations from the nominal input file
if len(varNames) > len(varResults):
if varNames is not None and len(varNames) > len(varResults):
# variations from the nominal input file
hist_var = result[
f"hist_{fittype}_inclusive_variations{'_correlated' if args.correlatedVariations else ''}"
].get()

hists_down.extend(
[
hist_var[{"downUpVar": 0, "vars": n}].project(
Expand All @@ -1259,6 +1320,65 @@ def make_plots(
for n in varNames[len(varResults) :]
]
)

# grouped variations built on-the-fly from nuisance-level variations
if varGroupNames is not None and len(varGroupNames):
hist_var = result[
f"hist_{fittype}_inclusive_variations{'_correlated' if args.correlatedVariations else ''}"
].get()
axis_names = [a.name for a in axes]
var_axis_entries = {_decode_str(x) for x in hist_var.axes["vars"]}
group_to_nuis = _build_group_to_nuisance_map(kwopts.get("meta", {}))
h_nom = result[f"hist_{fittype}_inclusive"].get().project(*axis_names)

for ig, gname in enumerate(varGroupNames):
members = group_to_nuis.get(gname, [])
members = [n for n in members if n in var_axis_entries]
if len(members) == 0:
logger.warning(
f"Group '{gname}' has no matching nuisances in variations axis; skipping."
)
continue

sigma2 = np.zeros_like(h_nom.values(), dtype=float)
for n in members:
hdn = hist_var[{"downUpVar": 0, "vars": n}].project(*axis_names)
hup = hist_var[{"downUpVar": 1, "vars": n}].project(*axis_names)
contrib = 0.5 * (hup.values() - hdn.values())
sigma2 += contrib * contrib

sigma = np.sqrt(sigma2)
h_up = h_nom.copy()
h_down = h_nom.copy()
h_up.values()[...] = h_nom.values() + sigma
h_down.values()[...] = h_nom.values() - sigma
hists_up.append(h_up)
hists_down.append(h_down)

plot_var_names.append(gname)
if varGroupLabels is not None and ig < len(varGroupLabels):
plot_var_labels.append(varGroupLabels[ig])
else:
plot_var_labels.append(gname)
if varGroupColors is not None and ig < len(varGroupColors):
plot_var_colors.append(varGroupColors[ig])

# Ensure metadata arrays align with the effective number of plotted variations.
n_total = len(hists_up)
if len(plot_var_names) < n_total:
plot_var_names.extend([f"var{i}" for i in range(len(plot_var_names), n_total)])
if len(plot_var_labels) < n_total:
plot_var_labels.extend(plot_var_names[len(plot_var_labels) : n_total])
if len(plot_var_colors) < n_total:
default_cols = [
colormaps["tab10" if n_total < 10 else "tab20"](i)
for i in range(n_total)
]
plot_var_colors.extend(default_cols[len(plot_var_colors) : n_total])

varNames = plot_var_names
varLabels = plot_var_labels
varColors = plot_var_colors
else:
hists_down = None
hists_up = None
Expand Down Expand Up @@ -1418,6 +1538,7 @@ def main():

varFiles = args.varFiles
varNames = args.varNames
varGroupNames = args.varGroupNames
varLabels = args.varLabels
varColors = args.varColors
if varNames is not None:
Expand All @@ -1434,6 +1555,17 @@ def main():
colormaps["tab10" if len(varNames) < 10 else "tab20"](i)
for i in range(len(varNames))
]
varGroupLabels = args.varGroupLabels
varGroupColors = args.varGroupColors
if varGroupNames is not None:
if varGroupLabels is not None and len(varGroupLabels) != len(varGroupNames):
raise ValueError(
"Must specify the same number of args for --varGroupNames and --varGroupLabels"
)
if varGroupColors is not None and len(varGroupColors) != len(varGroupNames):
raise ValueError(
"Must specify the same number of args for --varGroupNames and --varGroupColors"
)

fittype = "prefit" if args.prefit else "postfit"

Expand Down Expand Up @@ -1470,6 +1602,9 @@ def main():
meta=meta,
fittype=fittype,
varNames=varNames,
varGroupNames=varGroupNames,
varGroupLabels=varGroupLabels,
varGroupColors=varGroupColors,
varLabels=varLabels,
varColors=varColors,
varMarkers=args.varMarkers,
Expand Down
Loading