diff --git a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h index 39030540f0697..22967ba2ced98 100644 --- a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h +++ b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h @@ -39,6 +39,9 @@ class RooJSONFactoryWSTool { public: static constexpr bool useListsInsteadOfDicts = true; static bool allowExportInvalidNames; + static bool allowSanitizeNames; + static RooWorkspace sanitizeWS(const RooWorkspace &ws); + static RooWorkspace cleanWS(const RooWorkspace &ws, bool onlyModelConfig = false); struct CombinedData { std::string name; @@ -52,11 +55,14 @@ class RooJSONFactoryWSTool { static std::string name(const RooFit::Detail::JSONNode &n); static bool isValidName(const std::string &str); static bool testValidName(const std::string &str, bool forcError); + static std::string sanitizeName(const std::string str); + static void rebuildModelConfigInWorkspace(RooStats::ModelConfig *mc, RooWorkspace &ws); static RooFit::Detail::JSONNode &appendNamedChild(RooFit::Detail::JSONNode &node, std::string const &name); static RooFit::Detail::JSONNode const *findNamedChild(RooFit::Detail::JSONNode const &node, std::string const &name); static void fillSeq(RooFit::Detail::JSONNode &node, RooAbsCollection const &coll, size_t nMax = -1); + static void fillSeqSanitizedName(RooFit::Detail::JSONNode &node, RooAbsCollection const &coll, size_t nMax = -1); template T *request(const std::string &objname, const std::string &requestAuthor) @@ -199,7 +205,6 @@ class RooJSONFactoryWSTool { private: template T *requestImpl(const std::string &objname); - void exportObject(RooAbsArg const &func, std::set &exportedObjectNames); // To export multiple objects sorted alphabetically @@ -230,7 +235,8 @@ class RooJSONFactoryWSTool { void exportAllObjects(RooFit::Detail::JSONNode &n); void exportModelConfig(RooFit::Detail::JSONNode &rootnode, RooStats::ModelConfig const &mc, - const std::vector &d); + const std::vector &combined, + const std::vector &single); void exportSingleModelConfig(RooFit::Detail::JSONNode &rootnode, RooStats::ModelConfig const &mc, std::string const &analysisName, diff --git a/roofit/hs3/src/JSONFactories_HistFactory.cxx b/roofit/hs3/src/JSONFactories_HistFactory.cxx index 7dd65826f2899..2634b752f3535 100644 --- a/roofit/hs3/src/JSONFactories_HistFactory.cxx +++ b/roofit/hs3/src/JSONFactories_HistFactory.cxx @@ -263,7 +263,6 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa } return *constraint; } else { - std::cout << "creating new constraint for " << param << std::endl; std::string constraint_type = "Gauss"; if (auto constrType = mod.find("constraint_type")) { constraint_type = constrType->val(); diff --git a/roofit/hs3/src/JSONFactories_RooFitCore.cxx b/roofit/hs3/src/JSONFactories_RooFitCore.cxx index 3dce98b344497..9c87e3add0056 100644 --- a/roofit/hs3/src/JSONFactories_RooFitCore.cxx +++ b/roofit/hs3/src/JSONFactories_RooFitCore.cxx @@ -12,12 +12,18 @@ #include +#include #include +#include #include #include #include #include +#include +#include #include +#include +#include #include #include #include @@ -33,7 +39,10 @@ #include #include #include +#include #include +#include +#include #include #include @@ -127,8 +136,23 @@ class RooAddPdfFactory : public RooFit::JSONIO::Importer { bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override { std::string name(RooJSONFactoryWSTool::name(p)); - tool->wsEmplace(name, tool->requestArgList(p, "summands"), - tool->requestArgList(p, "coefficients")); + if (!tool->requestArgList(p, "coefficients").empty()) { + tool->wsEmplace(name, tool->requestArgList(p, "summands"), + tool->requestArgList(p, "coefficients")); + return true; + } + tool->wsEmplace(name, tool->requestArgList(p, "summands")); + return true; + } +}; + +class RooAddModelFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + tool->wsEmplace(name, tool->requestArgList(p, "summands"), + tool->requestArgList(p, "coefficients")); return true; } }; @@ -240,6 +264,44 @@ class RooPoissonFactory : public RooFit::JSONIO::Importer { } }; +class RooDecayFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooRealVar *t = tool->requestArg(p, "t"); + RooAbsReal *tau = tool->requestArg(p, "tau"); + RooResolutionModel *model = dynamic_cast(tool->requestArg(p, "resolutionModel")); + RooDecay::DecayType decayType = static_cast(p["decayType"].val_int()); + tool->wsEmplace(name, *t, *tau, *model, decayType); + return true; + } +}; + +class RooTruthModelFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooRealVar *x = tool->requestArg(p, "x"); + tool->wsEmplace(name, *x); + return true; + } +}; + +class RooGaussModelFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooRealVar *x = tool->requestArg(p, "x"); + RooRealVar *mean = tool->requestArg(p, "mean"); + RooRealVar *sigma = tool->requestArg(p, "sigma"); + tool->wsEmplace(name, *x, *mean, *sigma); + return true; + } +}; + class RooRealIntegralFactory : public RooFit::JSONIO::Importer { public: bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override @@ -265,6 +327,62 @@ class RooRealIntegralFactory : public RooFit::JSONIO::Importer { } }; +class RooDerivativeFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooAbsReal *func = tool->requestArg(p, "function"); + RooRealVar *x = tool->requestArg(p, "x"); + Int_t order = p["order"].val_int(); + double eps = p["eps"].val_double(); + if (p.has_child("normalization")) { + RooArgSet normSet; + normSet.add(tool->requestArgSet(p, "normalization")); + tool->wsEmplace(name, *func, *x, normSet, order, eps); + return true; + } + tool->wsEmplace(name, *func, *x, order, eps); + return true; + } +}; + +class RooFFTConvPdfFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooRealVar *convVar = tool->requestArg(p, "conv_var"); + Int_t order = p["ipOrder"].val_int(); + RooAbsPdf *pdf1 = tool->requestArg(p, "pdf1"); + RooAbsPdf *pdf2 = tool->requestArg(p, "pdf2"); + if (p.has_child("conv_func")) { + RooAbsReal *convFunc = tool->requestArg(p, "conv_func"); + tool->wsEmplace(name, *convFunc, *convVar, *pdf1, *pdf2, order); + return true; + } + tool->wsEmplace(name, *convVar, *pdf1, *pdf2, order); + return true; + } +}; + +class RooExtendPdfFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooAbsPdf *pdf = tool->requestArg(p, "pdf"); + RooAbsReal *norm = tool->requestArg(p, "norm"); + if (p.has_child("range")) { + std::string rangeName = p["range"].val(); + tool->wsEmplace(name, *pdf, *norm, rangeName.c_str()); + return true; + } + tool->wsEmplace(name, *pdf, *norm); + return true; + } +}; + class RooLogNormalFactory : public RooFit::JSONIO::Importer { public: bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override @@ -417,17 +535,23 @@ class RooMultiVarGaussianFactory : public RooFit::JSONIO::Importer { /////////////////////////////////////////////////////////////////////////////////////////////////////// // specialized exporter implementations /////////////////////////////////////////////////////////////////////////////////////////////////////// - +template class RooAddPdfStreamer : public RooFit::JSONIO::Exporter { public: std::string const &key() const override; bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override { - const RooAddPdf *pdf = static_cast(func); + const RooArg_t *pdf = static_cast(func); elem["type"] << key(); + std::string name = elem["name"].val(); + /*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name); + RooJSONFactoryWSTool::fillSeqSanitizedName(elem["summands"], pdf->pdfList()); + RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList()); + */ + elem["name"] << name; RooJSONFactoryWSTool::fillSeq(elem["summands"], pdf->pdfList()); RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList()); - elem["extended"] << (pdf->extendMode() != RooAbsPdf::CanNotBeExtended); + elem["extended"] << (pdf->extendMode() != RooArg_t::CanNotBeExtended); return true; } }; @@ -439,6 +563,12 @@ class RooRealSumPdfStreamer : public RooFit::JSONIO::Exporter { { const RooRealSumPdf *pdf = static_cast(func); elem["type"] << key(); + std::string name = elem["name"].val(); + /*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name); + RooJSONFactoryWSTool::fillSeqSanitizedName(elem["samples"], pdf->funcList()); + RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList()); + */ + elem["name"] << name; RooJSONFactoryWSTool::fillSeq(elem["samples"], pdf->funcList()); RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList()); elem["extended"] << (pdf->extendMode() != RooAbsPdf::CanNotBeExtended); @@ -453,6 +583,12 @@ class RooRealSumFuncStreamer : public RooFit::JSONIO::Exporter { { const RooRealSumFunc *pdf = static_cast(func); elem["type"] << key(); + std::string name = elem["name"].val(); + /*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name); + RooJSONFactoryWSTool::fillSeqSanitizedName(elem["samples"], pdf->funcList()); + RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList()); + */ + elem["name"] << name; RooJSONFactoryWSTool::fillSeq(elem["samples"], pdf->funcList()); RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList()); return true; @@ -625,6 +761,55 @@ class RooPoissonStreamer : public RooFit::JSONIO::Exporter { } }; +class RooDecayStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + elem["t"] << pdf->getT().GetName(); + elem["tau"] << pdf->getTau().GetName(); + elem["resolutionModel"] << pdf->getModel().GetName(); + elem["decayType"] << pdf->getDecayType(); + + return true; + } +}; + +class RooTruthModelStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + std::string name = elem["name"].val(); + // elem["name"] << RooJSONFactoryWSTool::sanitizeName(name); + elem["name"] << name; + elem["x"] << pdf->convVar().GetName(); + + return true; + } +}; + +class RooGaussModelStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + std::string name = elem["name"].val(); + // elem["name"] << RooJSONFactoryWSTool::sanitizeName(name); + elem["name"] << name; + elem["x"] << pdf->convVar().GetName(); + elem["mean"] << pdf->getMean().GetName(); + elem["sigma"] << pdf->getSigma().GetName(); + return true; + } +}; + class RooLogNormalStreamer : public RooFit::JSONIO::Exporter { public: std::string const &key() const override; @@ -704,6 +889,24 @@ class RooTFnBindingStreamer : public RooFit::JSONIO::Exporter { } }; +class RooDerivativeStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + elem["x"] << pdf->getX().GetName(); + elem["function"] << pdf->getFunc().GetName(); + if (!pdf->getNset().empty()) { + RooJSONFactoryWSTool::fillSeq(elem["normalization"], pdf->getNset()); + } + elem["order"] << pdf->order(); + elem["eps"] << pdf->eps(); + return true; + } +}; + class RooRealIntegralStreamer : public RooFit::JSONIO::Exporter { public: std::string const &key() const override; @@ -711,15 +914,13 @@ class RooRealIntegralStreamer : public RooFit::JSONIO::Exporter { { auto *integral = static_cast(func); std::string name = elem["name"].val(); - for (char& c : name ) { - if (c == '[' || c == '|' || c==',') { - c = '_'; - } - } - name.erase(std::remove(name.begin(), name.end(), ']'), name.end()); + // elem["name"] << RooJSONFactoryWSTool::sanitizeName(name); elem["name"] << name; + elem["type"] << key(); - elem["integrand"] << integral->integrand().GetName(); + std::string integrand = integral->integrand().GetName(); + // elem["integrand"] << RooJSONFactoryWSTool::sanitizeName(integrand); + elem["integrand"] << integrand; if (integral->intRange()) { elem["domain"] << integral->intRange(); } @@ -731,14 +932,50 @@ class RooRealIntegralStreamer : public RooFit::JSONIO::Exporter { } }; +class RooFFTConvPdfStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + if (auto convFunc = pdf->getPdfConvVar()) { + elem["conv_func"] << convFunc->GetName(); + } + elem["conv_var"] << pdf->getConvVar().GetName(); + elem["pdf1"] << pdf->getPdf1().GetName(); + elem["pdf2"] << pdf->getPdf2().GetName(); + elem["ipOrder"] << pdf->getInterpolationOrder(); + return true; + } +}; + +class RooExtendPdfStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + if (auto rangeName = pdf->getRangeName()) { + elem["range"] << rangeName->GetName(); + } + elem["pdf"] << pdf->pdf().GetName(); + elem["norm"] << pdf->getN().GetName(); + return true; + } +}; + #define DEFINE_EXPORTER_KEY(class_name, name) \ std::string const &class_name::key() const \ { \ const static std::string keystring = name; \ return keystring; \ } - -DEFINE_EXPORTER_KEY(RooAddPdfStreamer, "mixture_dist"); +template <> +DEFINE_EXPORTER_KEY(RooAddPdfStreamer, "mixture_dist"); +template <> +DEFINE_EXPORTER_KEY(RooAddPdfStreamer, "mixture_model"); DEFINE_EXPORTER_KEY(RooBinSamplingPdfStreamer, "binsampling"); DEFINE_EXPORTER_KEY(RooBinWidthFunctionStreamer, "binwidth"); DEFINE_EXPORTER_KEY(RooLegacyExpPolyStreamer, "legacy_exp_poly_dist"); @@ -752,6 +989,9 @@ DEFINE_EXPORTER_KEY(RooHistPdfStreamer, "histogram_dist"); DEFINE_EXPORTER_KEY(RooLogNormalStreamer, "lognormal_dist"); DEFINE_EXPORTER_KEY(RooMultiVarGaussianStreamer, "multivariate_normal_dist"); DEFINE_EXPORTER_KEY(RooPoissonStreamer, "poisson_dist"); +DEFINE_EXPORTER_KEY(RooDecayStreamer, "decay_dist"); +DEFINE_EXPORTER_KEY(RooTruthModelStreamer, "truth_model_function"); +DEFINE_EXPORTER_KEY(RooGaussModelStreamer, "gauss_model_function"); template <> DEFINE_EXPORTER_KEY(RooPolynomialStreamer, "polynomial_dist"); template <> @@ -760,6 +1000,9 @@ DEFINE_EXPORTER_KEY(RooRealSumFuncStreamer, "weighted_sum"); DEFINE_EXPORTER_KEY(RooRealSumPdfStreamer, "weighted_sum_dist"); DEFINE_EXPORTER_KEY(RooTFnBindingStreamer, "generic_function"); DEFINE_EXPORTER_KEY(RooRealIntegralStreamer, "integral"); +DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative"); +DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf"); +DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf"); /////////////////////////////////////////////////////////////////////////////////////////////////////// // instantiate all importers and exporters @@ -769,6 +1012,7 @@ STATIC_EXECUTE([]() { using namespace RooFit::JSONIO; registerImporter("mixture_dist", false); + registerImporter("mixture_model", false); registerImporter("binsampling_dist", false); registerImporter("binwidth", false); registerImporter("legacy_exp_poly_dist", false); @@ -780,13 +1024,20 @@ STATIC_EXECUTE([]() { registerImporter("lognormal_dist", false); registerImporter("multivariate_normal_dist", false); registerImporter("poisson_dist", false); + registerImporter("decay_dist", false); + registerImporter("truth_model_function", false); + registerImporter("gauss_model_function", false); registerImporter>("polynomial_dist", false); registerImporter>("polynomial", false); registerImporter("weighted_sum_dist", false); registerImporter("weighted_sum", false); registerImporter("integral", false); + registerImporter("derivative", false); + registerImporter("fft_conv_pdf", false); + registerImporter("extend_pdf", false); - registerExporter(RooAddPdf::Class(), false); + registerExporter>(RooAddPdf::Class(), false); + registerExporter>(RooAddModel::Class(), false); registerExporter(RooBinSamplingPdf::Class(), false); registerExporter(RooBinWidthFunction::Class(), false); registerExporter(RooLegacyExpPoly::Class(), false); @@ -798,12 +1049,18 @@ STATIC_EXECUTE([]() { registerExporter(RooLognormal::Class(), false); registerExporter(RooMultiVarGaussian::Class(), false); registerExporter(RooPoisson::Class(), false); + registerExporter(RooDecay::Class(), false); + registerExporter(RooTruthModel::Class(), false); + registerExporter(RooGaussModel::Class(), false); registerExporter>(RooPolynomial::Class(), false); registerExporter>(RooPolyVar::Class(), false); registerExporter(RooRealSumFunc::Class(), false); registerExporter(RooRealSumPdf::Class(), false); registerExporter(RooTFnBinding::Class(), false); registerExporter(RooRealIntegral::Class(), false); + registerExporter(RooDerivative::Class(), false); + registerExporter(RooFFTConvPdf::Class(), false); + registerExporter(RooExtendPdf::Class(), false); }); } // namespace diff --git a/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx b/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx index eeccc0c44f266..dacbf6dbf6456 100644 --- a/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx +++ b/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx @@ -108,6 +108,14 @@ auto RooFitHS3_wsexportkeys = R"({ "pdfs": "factors" } }, + "RooProjectedPdf": { + "type": "projected_dist", + "proxies": { + "IntegratedPdf": "input_pdf", + "IntegrationObservables": "observables", + "Dependents": "" + } + }, "RooProduct": { "type": "product", "proxies": { diff --git a/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx b/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx index 3230fa1ab9d5c..62c27a0038fca 100644 --- a/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx +++ b/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx @@ -105,6 +105,13 @@ auto RooFitHS3_wsfactoryexpressions = R"({ "factors" ] }, + "projected_dist": { + "class": "RooProjectedPdf", + "arguments": [ + "input_pdf", + "observables" + ] + }, "step": { "class": "ParamHistFunc", "arguments": [ diff --git a/roofit/hs3/src/RooJSONFactoryWSTool.cxx b/roofit/hs3/src/RooJSONFactoryWSTool.cxx index bd33647776271..e599ca8ec1b29 100644 --- a/roofit/hs3/src/RooJSONFactoryWSTool.cxx +++ b/roofit/hs3/src/RooJSONFactoryWSTool.cxx @@ -27,6 +27,7 @@ #include #include #include +#include #include "JSONIOUtils.h" #include "Domains.h" @@ -609,8 +610,6 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons auto *mc = static_cast(workspace.obj(mcname)); mc->SetWS(workspace); - std::vector nllDataNames; - auto *nllNode = RooJSONFactoryWSTool::findNamedChild(likelihoodsNode, analysisNode["likelihood"].val()); if (!nllNode) { throw std::runtime_error("likelihood node not found!"); @@ -630,12 +629,15 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons } RooArgSet observables; for (auto &nameNode : (*nllNode)["data"].children()) { - nllDataNames.push_back(nameNode.val()); + bool found = false; for (const auto &d : datasets) { if (d->GetName() == nameNode.val()) { + found = true; observables.add(*d->get()); } } + if (nameNode.val() != "0" && !found) + throw std::runtime_error("dataset '" + nameNode.val() + "' cannot be found!"); } JSONNode const *pdfNameNode = mcAuxNode ? mcAuxNode->find("pdfName") : nullptr; @@ -711,6 +713,7 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons nps.add(*p); } } + mc->SetGlobalObservables(globs); mc->SetNuisanceParameters(nps); @@ -719,6 +722,10 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons pdf->setStringAttribute("combined_data_name", found->val().c_str()); } } + + if (analysisNode.has_child("init") && workspace.getSnapshot(analysisNode["init"].val().c_str())) { + mc->SetSnapshot(*workspace.getSnapshot(analysisNode["init"].val().c_str())); + } } void combinePdfs(const JSONNode &rootnode, RooWorkspace &ws) @@ -833,6 +840,26 @@ void RooJSONFactoryWSTool::fillSeq(JSONNode &node, RooAbsCollection const &coll, } } +void RooJSONFactoryWSTool::fillSeqSanitizedName(JSONNode &node, RooAbsCollection const &coll, size_t nMax) +{ + const size_t old_children = node.num_children(); + node.set_seq(); + size_t n = 0; + for (RooAbsArg const *arg : coll) { + if (n >= nMax) + break; + if (isLiteralConstVar(*arg)) { + node.append_child() << static_cast(arg)->getVal(); + } else { + node.append_child() << sanitizeName(arg->GetName()); + } + ++n; + } + if (node.num_children() != old_children + coll.size()) { + error("unable to stream collection " + std::string(coll.GetName()) + " to " + node.key()); + } +} + JSONNode &RooJSONFactoryWSTool::appendNamedChild(JSONNode &node, std::string const &name) { if (!useListsInsteadOfDicts) { @@ -889,11 +916,13 @@ bool RooJSONFactoryWSTool::isValidName(const std::string &str) } bool RooJSONFactoryWSTool::allowExportInvalidNames(true); +bool RooJSONFactoryWSTool::allowSanitizeNames(true); bool RooJSONFactoryWSTool::testValidName(const std::string &name, bool forceError) { if (!RooJSONFactoryWSTool::isValidName(name)) { std::stringstream ss; - ss << "RooJSONFactoryWSTool() name '" << name << "' is not valid!" << std::endl; + ss << "RooJSONFactoryWSTool() name '" << name << "' is not valid!" << std::endl + << "Sanitize names by setting RooJSONFactoryWSTool::allowSanitizeNames = True." << std::endl; if (RooJSONFactoryWSTool::allowExportInvalidNames && !forceError) { RooJSONFactoryWSTool::warning(ss.str()); return false; @@ -1022,6 +1051,7 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node) void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n) { // export a list of RooRealVar objects + n.set_seq(); for (RooAbsArg *arg : allElems) { exportVariable(arg, n); } @@ -1051,7 +1081,8 @@ std::string RooJSONFactoryWSTool::exportTransformed(const RooAbsReal *original, */ void RooJSONFactoryWSTool::exportObject(RooAbsArg const &func, std::set &exportedObjectNames) { - const std::string name = func.GetName(); + // const std::string name = sanitizeName(func.GetName()); + std::string name = func.GetName(); // if this element was already exported, skip if (exportedObjectNames.find(name) != exportedObjectNames.end()) @@ -1327,9 +1358,9 @@ void RooJSONFactoryWSTool::exportHisto(RooArgSet const &vars, std::size_t n, dou auto &observablesNode = output["axes"].set_seq(); // axes have to be ordered to get consistent bin indices for (auto *var : static_range_cast(vars)) { - JSONNode &obsNode = observablesNode.append_child().set_map(); std::string name = var->GetName(); RooJSONFactoryWSTool::testValidName(name, false); + JSONNode &obsNode = observablesNode.append_child().set_map(); obsNode["name"] << name; if (var->getBinning().isUniform()) { obsNode["min"] << var->getMin(); @@ -1496,6 +1527,7 @@ RooJSONFactoryWSTool::CombinedData RooJSONFactoryWSTool::exportCombinedData(RooA void RooJSONFactoryWSTool::exportData(RooAbsData const &data) { // find category observables + RooAbsCategory *cat = nullptr; for (RooAbsArg *obs : *data.get()) { if (dynamic_cast(obs)) { @@ -1511,14 +1543,10 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) return; JSONNode &output = appendNamedChild((*_rootnodeOutput)["data"], data.GetName()); - - // this is a binned dataset - if (auto dh = dynamic_cast(&data)) { - output["type"] << "binned"; - return exportHisto(*dh->get(), dh->numEntries(), dh->weightArray(), output); - } - - // this is a regular unbinned dataset + /*std::ofstream file("/home/scello/Data/ZvvH126_5.txt", std::ios::app); + if (!file.is_open()) { + std::cerr << "Error: Could not open file for writing.\n"; + }*/ // This works around a problem in RooStats/HistFactory that was only fixed // in ROOT 6.30: until then, the weight variable of the observed dataset, @@ -1531,6 +1559,15 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) variables.remove(*weightVar); } + // this is a regular binned dataset + if (auto dh = dynamic_cast(&data)) { + output["type"] << "binned"; + for (auto *var : static_range_cast(variables)) { + _domains->readVariable(*var); + } + return exportHisto(variables, dh->numEntries(), dh->weightArray(), output); + } + // Check if this actually represents a binned dataset, and then import it // like a RooDataHist. This happens frequently when people create combined // RooDataSets from binned data to fit HistFactory models. In this case, it @@ -1552,21 +1589,27 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) isBinnedData = true; if (isBinnedData) { output["type"] << "binned"; + for (auto *var : static_range_cast(variables)) { + _domains->readVariable(*var); + } return exportHisto(variables, data.numEntries(), contents.data(), output); } } + // this really is an unbinned dataset output["type"] << "unbinned"; - - for (RooAbsArg *arg : variables) { - exportVariable(arg, output["axes"]); - } + exportVariables(variables, output["axes"]); auto &coords = output["entries"].set_seq(); std::vector weightVals; bool hasNonUnityWeights = false; for (int i = 0; i < data.numEntries(); ++i) { data.get(i); coords.append_child().fill_seq(variables, [](auto x) { return static_cast(x)->getVal(); }); + std::string datasetName = data.GetName(); + /*if (datasetName.find("combData_ZvvH126.5") != std::string::npos) { + file << dynamic_cast(data.get(i)->find("atlas_invMass_PttEtaConvVBFCat1"))->getVal() << + std::endl; + }*/ if (data.isWeighted()) { weightVals.push_back(data.weight()); if (data.weight() != 1.) @@ -1576,6 +1619,7 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) if (data.isWeighted() && hasNonUnityWeights) { output["weights"].fill_seq(weightVals); } + // file.close(); } /** @@ -1736,24 +1780,37 @@ void RooJSONFactoryWSTool::importDependants(const JSONNode &n) } void RooJSONFactoryWSTool::exportModelConfig(JSONNode &rootnode, RooStats::ModelConfig const &mc, - const std::vector &combDataSets) -{ - auto pdf = dynamic_cast(mc.GetPdf()); - if (pdf == nullptr) { - warning("RooFitHS3 only supports ModelConfigs with RooSimultaneous! Skipping ModelConfig."); - return; - } - - for (std::size_t i = 0; i < std::max(combDataSets.size(), std::size_t(1)); ++i) { - const bool hasdata = i < combDataSets.size(); - if (hasdata && !matches(combDataSets.at(i), pdf)) - continue; + const std::vector &combDataSets, + const std::vector &singleDataSets) +{ + auto pdf = mc.GetPdf(); + auto simpdf = dynamic_cast(pdf); + if (simpdf) { + for (std::size_t i = 0; i < std::max(combDataSets.size(), std::size_t(1)); ++i) { + const bool hasdata = i < combDataSets.size(); + if (hasdata && !matches(combDataSets.at(i), simpdf)) + continue; - std::string analysisName(pdf->GetName()); - if (hasdata) - analysisName += "_" + combDataSets[i].name; + std::string analysisName(simpdf->GetName()); + if (hasdata) + analysisName += "_" + combDataSets[i].name; - exportSingleModelConfig(rootnode, mc, analysisName, hasdata ? &combDataSets[i].components : nullptr); + exportSingleModelConfig(rootnode, mc, analysisName, hasdata ? &combDataSets[i].components : nullptr); + } + } else { + RooArgSet observables(*mc.GetObservables()); + int founddata = 0; + for (auto *data : singleDataSets) { + if (observables.equals(*(data->get()))) { + std::map mapping; + mapping[pdf->GetName()] = data->GetName(); + exportSingleModelConfig(rootnode, mc, std::string(pdf->GetName()) + "_" + data->GetName(), &mapping); + ++founddata; + } + } + if (founddata == 0) { + exportSingleModelConfig(rootnode, mc, pdf->GetName(), nullptr); + } } } @@ -1761,7 +1818,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats: std::string const &analysisName, std::map const *dataComponents) { - auto pdf = static_cast(mc.GetPdf()); + auto pdf = mc.GetPdf(); JSONNode &analysisNode = appendNamedChild(rootnode["analyses"], analysisName); @@ -1774,11 +1831,22 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats: nllNode["data"].set_seq(); if (dataComponents) { - for (auto const &item : pdf->indexCat()) { - const auto &dataComp = dataComponents->find(item.first); - nllNode["distributions"].append_child() << pdf->getPdf(item.first)->GetName(); - nllNode["data"].append_child() << dataComp->second; + auto simPdf = static_cast(pdf); + if (simPdf) { + for (auto const &item : simPdf->indexCat()) { + const auto &dataComp = dataComponents->find(item.first); + nllNode["distributions"].append_child() << simPdf->getPdf(item.first)->GetName(); + nllNode["data"].append_child() << dataComp->second; + } + } else { + for (auto it : *dataComponents) { + nllNode["distributions"].append_child() << it.first; + nllNode["data"].append_child() << it.second; + } } + } else { + nllNode["distributions"].append_child() << pdf->GetName(); + nllNode["data"].append_child() << 0; } if (mc.GetExternalConstraints()) { @@ -1790,7 +1858,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats: } auto writeList = [&](const char *name, RooArgSet const *args) { - if (!args) + if (!args || !args->size()) return; std::vector names; @@ -1805,7 +1873,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats: auto &domainsNode = rootnode["domains"]; - if (mc.GetNuisanceParameters()) { + if (mc.GetNuisanceParameters() && mc.GetNuisanceParameters()->size() > 0) { std::string npDomainName = analysisName + "_nuisance_parameters"; domains.append_child() << npDomainName; RooFit::JSONIO::Detail::Domains::ProductDomain npDomain; @@ -1815,7 +1883,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats: npDomain.writeJSON(appendNamedChild(domainsNode, npDomainName)); } - if (mc.GetGlobalObservables()) { + if (mc.GetGlobalObservables() && mc.GetGlobalObservables()->size() > 0) { std::string globDomainName = analysisName + "_global_observables"; domains.append_child() << globDomainName; RooFit::JSONIO::Detail::Domains::ProductDomain globDomain; @@ -1825,7 +1893,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats: globDomain.writeJSON(appendNamedChild(domainsNode, globDomainName)); } - if (mc.GetParametersOfInterest()) { + if (mc.GetParametersOfInterest() && mc.GetParametersOfInterest()->size() > 0) { std::string poiDomainName = analysisName + "_parameters_of_interest"; domains.append_child() << poiDomainName; RooFit::JSONIO::Detail::Domains::ProductDomain poiDomain; @@ -1886,20 +1954,23 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n) exportAttributes(arg, n); } - // export all datasets + // collect all datasets std::vector alldata; for (auto &d : _workspace.allData()) { alldata.push_back(d); } sortByName(alldata); // first, take care of combined datasets + std::vector singleData; std::vector combData; for (auto &d : alldata) { auto data = this->exportCombinedData(*d); if (!data.components.empty()) combData.push_back(data); + else + singleData.push_back(d); } - // next, take care of regular datasets + // next, take care datasets for (auto &d : alldata) { this->exportData(*d); } @@ -1907,7 +1978,7 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n) // export all ModelConfig objects and attached Pdfs for (TObject *obj : _workspace.allGenericObjects()) { if (auto mc = dynamic_cast(obj)) { - exportModelConfig(n, *mc, combData); + exportModelConfig(n, *mc, combData, singleData); } } @@ -2368,3 +2439,250 @@ void RooJSONFactoryWSTool::error(const char *s) RooMsgService::instance().log(nullptr, RooFit::MsgLevel::ERROR, RooFit::IO) << s << std::endl; throw std::runtime_error(s); } + +/** + * @brief Cleans up names to the HS3 standard + * + * @param str The string to be sanitized. + * @return std::string + */ +std::string RooJSONFactoryWSTool::sanitizeName(const std::string str) +{ + std::string result; + if (RooJSONFactoryWSTool::allowSanitizeNames) { + for (char c : str) { + switch (c) { + case '[': + case '|': + case ',': + case '(': result += '_'; break; + case ']': + case ')': + // skip these characters entirely + break; + case '.': result += "_dot_"; break; + case '@': result += "at"; break; + case '-': result += "minus"; break; + case '/': result += "_div_"; break; + + default: result += c; break; + } + } + return result; + } + return str; +} + +RooWorkspace RooJSONFactoryWSTool::cleanWS(const RooWorkspace &ws, bool onlyModelConfig) +{ + // Variables + + RooWorkspace tmpWS = RooWorkspace(); + if (onlyModelConfig) { + for (auto *obj : ws.allGenericObjects()) { + if (auto *mc = dynamic_cast(obj)) { + tmpWS.import(*mc->GetPdf(), RooFit::RecycleConflictNodes(true)); + } + } + + } else { + + for (auto *pdf : ws.allPdfs()) { + if (!pdf->hasClients()) { + tmpWS.import(*pdf, RooFit::RecycleConflictNodes(true)); + } + } + + for (auto *func : ws.allFunctions()) { + if (!func->hasClients()) { + tmpWS.import(*func, RooFit::RecycleConflictNodes(true)); + } + } + } + + for (auto *data : ws.allData()) { + tmpWS.import(*data); + } + + for (auto *obj : ws.allGenericObjects()) { + tmpWS.import(*obj); + } + + /* + if (auto* mc = dynamic_cast(obj)) { + // Import the PDF + tmpWS.import(*mc->GetPdf()); + + // Import all observables + RooArgSet* obs = (RooArgSet*)mc->GetObservables()->snapshot(); + tmpWS.import(*obs); + + // Import global observables + RooArgSet* globObs = (RooArgSet*)mc->GetGlobalObservables()->snapshot(); + tmpWS.import(*globObs); + + // Import POIs + RooArgSet* pois = (RooArgSet*)mc->GetParametersOfInterest()->snapshot(); + tmpWS.import(*pois); + + // Import nuisance parameters + RooArgSet* nuis = (RooArgSet*)mc->GetNuisanceParameters()->snapshot(); + tmpWS.import(*nuis); + + + RooStats::ModelConfig* mc_new = new RooStats::ModelConfig(mc->GetName(), mc->GetName()); + + mc_new->SetPdf(*tmpWS.pdf(mc->GetPdf()->GetName())); + mc_new->SetObservables(*tmpWS.set(obs->GetName())); + mc_new->SetGlobalObservables(*tmpWS.set(globObs->GetName())); + mc_new->SetParametersOfInterest(*tmpWS.set(pois->GetName())); + mc_new->SetNuisanceParameters(*tmpWS.set(nuis->GetName())); + + // Import the ModelConfig into the new workspace + tmpWS.import(*mc_new); + }else { + + tmpWS.import(*obj); + } + */ + + for (auto *snsh : ws.getSnapshots()) { + auto *snshSet = dynamic_cast(snsh); + if (snshSet) { + tmpWS.saveSnapshot(snshSet->GetName(), *snshSet, true); + } + } + + return tmpWS; +} + +// Sanitize all names in the workspace to be HS3 compliant +RooWorkspace RooJSONFactoryWSTool::sanitizeWS(const RooWorkspace &ws) +{ + // Variables + + RooWorkspace tmpWS = cleanWS(ws, false); + + for (auto *obj : tmpWS.allVars()) { + if (!isValidName(obj->GetName())) { + obj->SetName(sanitizeName(obj->GetName()).c_str()); + } + } + + // Functions + for (auto *obj : tmpWS.allFunctions()) { + if (!isValidName(obj->GetName())) { + obj->SetName(sanitizeName(obj->GetName()).c_str()); + } + } + + // PDFs + for (auto *obj : tmpWS.allPdfs()) { + if (!isValidName(obj->GetName())) { + obj->SetName(sanitizeName(obj->GetName()).c_str()); + } + } + + // Datasets + for (auto *data : tmpWS.allData()) { + // Sanitize dataset name + if (!isValidName(data->GetName())) { + data->SetName(sanitizeName(data->GetName()).c_str()); + } + for (auto *obj : *data->get()) { + obj->SetName(sanitizeName(obj->GetName()).c_str()); + } + } + /* // Sanitize dataset observables + const RooArgSet* obsSet = data->get(); + if (obsSet) { + RooArgSet* mutableObs = const_cast(obsSet); + std::string oldSetName = mutableObs->GetName(); + std::string newSetName = sanitizeName(oldSetName); + if (oldSetName != newSetName) { + mutableObs->setName(newSetName.c_str()); + } + } + + for (auto* arg : *obsSet) { + std::string oldObsName = arg->GetName(); + std::string newObsName = sanitizeName(oldObsName); + if (oldObsName != newObsName) { + arg->SetName(newObsName.c_str()); + data->changeObservableName(arg->GetName(), newObsName.c_str()); + } + } + */ + for (auto *data : tmpWS.allEmbeddedData()) { + // Sanitize dataset name + data->SetName(sanitizeName(data->GetName()).c_str()); + for (auto *obj : *data->get()) { + obj->SetName(sanitizeName(obj->GetName()).c_str()); + } + } + for (auto *snshObj : tmpWS.getSnapshots()) { + // Snapshots are stored as TObject*, but really they are RooArgSet* + auto *snsh = dynamic_cast(snshObj); + if (!snsh) { + std::cerr << "Warning: found snapshot that is not a RooArgSet, skipping\n"; + continue; + } + + // Sanitize snapshot name + if (!isValidName(snsh->GetName())) { + snsh->setName(sanitizeName(snsh->GetName()).c_str()); + } + + // Sanitize the variables inside the snapshot + for (auto *arg : *snsh) { + if (!isValidName(arg->GetName())) { + arg->SetName(sanitizeName(arg->GetName()).c_str()); + } + } + } + + // Generic objects (ModelConfigs, attributes, etc.) + for (auto *obj : tmpWS.allGenericObjects()) { + if (!isValidName(obj->GetName())) { + if (auto *named = dynamic_cast(obj)) { + named->SetName(sanitizeName(named->GetName()).c_str()); + } else { + std::cerr << "Warning: object " << obj->GetName() << " is not TNamed, cannot rename.\n"; + } + } + + if (auto *mc = dynamic_cast(obj)) { + // Sanitize ModelConfig name + if (!isValidName(mc->GetName())) { + mc->SetName(sanitizeName(mc->GetName()).c_str()); + } + + // Sanitize the sets inside ModelConfig + for (auto *obs : mc->GetObservables()->get()) { + if (obs) { + obs->SetName(sanitizeName(obs->GetName()).c_str()); + } + } + for (auto *poi : mc->GetParametersOfInterest()->get()) { + if (poi) { + poi->SetName(sanitizeName(poi->GetName()).c_str()); + } + } + for (auto *nuis : mc->GetNuisanceParameters()->get()) { + if (nuis) { + nuis->SetName(sanitizeName(nuis->GetName()).c_str()); + } + } + for (auto *glob : mc->GetGlobalObservables()->get()) { + if (glob) { + glob->SetName(sanitizeName(glob->GetName()).c_str()); + } + } + } + } + std::string wsName = std::string{ws.GetName()} + "_sanitized"; + RooWorkspace newWS = cleanWS(tmpWS, false); + newWS.SetName(wsName.c_str()); + + return newWS; +} diff --git a/roofit/hs3/test/testRooFitHS3.cxx b/roofit/hs3/test/testRooFitHS3.cxx index 2358a78f9c0d3..82e78feefb5cf 100644 --- a/roofit/hs3/test/testRooFitHS3.cxx +++ b/roofit/hs3/test/testRooFitHS3.cxx @@ -24,6 +24,7 @@ #include #include #include +#include #include @@ -32,7 +33,7 @@ namespace { // If the JSON files should be written out for debugging purpose. -const bool writeJsonFiles = false; +const bool writeJsonFiles = true; // Validate the JSON IO for a given RooAbsReal in a RooWorkspace. The workspace // will be written out and read back, and then the values of the old and new @@ -42,19 +43,24 @@ int validate(RooWorkspace &ws1, std::string const &argName, bool exact = true) { RooWorkspace ws2; + ws1.Print(); + const std::string json1 = RooJSONFactoryWSTool{ws1}.exportJSONtoString(); + + if (writeJsonFiles) { + RooJSONFactoryWSTool{ws1}.exportJSON(argName + "_1.json"); + } + RooJSONFactoryWSTool{ws2}.importJSONfromString(json1); + if (writeJsonFiles) { + RooJSONFactoryWSTool{ws2}.exportJSON(argName + "_2.json"); + } // Export the re-imported workspace back to JSON, and compare the first JSON // with the second one. They should be identical. const std::string json2 = RooJSONFactoryWSTool{ws2}.exportJSONtoString(); EXPECT_EQ(json2, json1) << argName; - if (writeJsonFiles) { - RooJSONFactoryWSTool{ws1}.exportJSON(argName + "_1.json"); - RooJSONFactoryWSTool{ws2}.exportJSON(argName + "_2.json"); - } - // It would be nice to do a similar closure check for the original and for // the re-imported workspace. However, there is no way to compare workspaces // for equality. But we can still check that the objects in the workspace @@ -70,32 +76,49 @@ int validate(RooWorkspace &ws1, std::string const &argName, bool exact = true) EXPECT_STREQ(comps1[i]->GetName(), comps2[i]->GetName()); } - RooRealVar &x1 = *ws1.var("x"); - RooRealVar &x2 = *ws2.var("x"); + RooRealVar *x1 = ws1.var("x"); + RooRealVar *x2 = ws2.var("x"); + + if (!x1 || !x2) + return 1; - RooAbsReal &arg1 = *ws1.function(argName); - RooAbsReal &arg2 = *ws2.function(argName); + TObject *arg1 = ws1.obj(argName); + TObject *arg2 = ws2.obj(argName); - RooArgSet nset1{x1}; - RooArgSet nset2{x2}; + if (!arg1 || !arg2) + return 1; - bool allGood = true; - for (int i = 0; i < x1.numBins(); ++i) { - x1.setBin(i); - x2.setBin(i); - const double val1 = arg1.getVal(nset1); - const double val2 = arg2.getVal(nset2); - allGood &= (exact ? (val1 == val2) : std::abs(val1 - val2) < 1e-10); + RooArgSet nset1{*x1}; + RooArgSet nset2{*x2}; + + RooAbsReal *r1 = dynamic_cast(arg1); + RooAbsReal *r2 = dynamic_cast(arg2); + + if (r1 && !r2) + return 1; + + if (r1 && r2) { + bool allGood = true; + for (int i = 0; i < x1->numBins(); ++i) { + x1->setBin(i); + x2->setBin(i); + const double val1 = r1->getVal(nset1); + const double val2 = r1->getVal(nset2); + allGood &= (exact ? (val1 == val2) : std::abs(val1 - val2) < 1e-10); + } + + return allGood ? 0 : 1; } - return allGood ? 0 : 1; + return 0; } int validate(std::vector const &expressions, bool exact = true) { RooWorkspace ws; for (std::size_t iExpr = 0; iExpr < expressions.size() - 1; ++iExpr) { - ws.factory(expressions[iExpr]); + std::cout << expressions[iExpr] << std::endl; + ws.factory(expressions[iExpr])->Print(); } const std::string argName = ws.factory(expressions.back())->GetName(); return validate(ws, argName, exact); @@ -437,3 +460,76 @@ TEST(RooFitHS3, ScientificNotation) RooJSONFactoryWSTool t2(newws); ASSERT_TRUE(t2.importJSONfromString(jsonStr)); } + +// Workspace with ONLY a dataset (here: RooDataHist to avoid extra includes). +// ----------------------------------------------------------------------------- +TEST(RooFitHS3, WorkspaceOnlyDataset_RooDataHist) +{ + RooWorkspace ws1{"ws_dataset_only"}; + + // Observable with explicit binning + RooRealVar x{"x", "x", 0.0, 1.0}; + x.setBins(3); + // Build a tiny RooDataHist + RooDataHist dh{"dh", "dataset-only (hist)", RooArgList{x}}; + // Fill deterministic contents + x.setVal(0.1666667); + dh.set(0, 10.0, 0.0); // bin 0 + x.setVal(0.5000000); + dh.set(1, 20.0, 0.0); // bin 1 + x.setVal(0.8333333); + dh.set(2, 15.0, 0.0); // bin 2 + + ws1.import(dh, RooFit::Silence()); + + // Round-trip and strict checks (no numeric comparison needed here) + // Use the dataset name for object tracking + const int status = validate(ws1, "dh"); + EXPECT_EQ(status, 0); +} + +// ----------------------------------------------------------------------------- +// Workspace with ONLY a function (no dataset, no pdfs). +// ----------------------------------------------------------------------------- +TEST(RooFitHS3, WorkspaceOnlyFunction) +{ + int status = validate({std::string("x[-3, 3]"), std::string("RooFormulaVar::myfunc(\"sin(x) + 0.5*x*x\",x)")}); + EXPECT_EQ(status, 0); +} + +// ----------------------------------------------------------------------------- +// Workspace with a ModelConfig that points to a multivariate Gaussian pdf. +// ----------------------------------------------------------------------------- +TEST(RooFitHS3, ModelConfigWithMultiVarGaussian) +{ + using RooFit::RooConst; + + // Observables + RooRealVar x{"x", "x", -5.0, 5.0}; + RooRealVar y{"y", "y", -5.0, 5.0}; + + // Means + RooRealVar mx{"mx", "mx", 0.5}; + RooRealVar my{"my", "my", -0.3}; + + // Covariance + TMatrixDSym cov{2}; + cov(0, 0) = 1.2; + cov(0, 1) = 0.25; + cov(1, 0) = 0.25; + cov(1, 1) = 0.9; + + RooMultiVarGaussian mv{"mvgauss", "mvgauss", RooArgList{x, y}, RooArgList{mx, my}, cov}; + + RooWorkspace ws1{"ws_mc"}; + ws1.import(mv, RooFit::Silence(), RooFit::RecycleConflictNodes()); + + // Build a ModelConfig referencing the pdf and its observables + RooStats::ModelConfig mc{"mc", &ws1}; + mc.SetPdf(*ws1.pdf("mvgauss")); + mc.SetObservables("x,y"); + ws1.import(mc); + + int status = validate(ws1, "mc"); + EXPECT_EQ(status, 0); +} diff --git a/roofit/roofit/inc/RooDecay.h b/roofit/roofit/inc/RooDecay.h index 525379554cd9d..8331238095617 100644 --- a/roofit/roofit/inc/RooDecay.h +++ b/roofit/roofit/inc/RooDecay.h @@ -33,6 +33,15 @@ class RooDecay : public RooAbsAnaConvPdf { double coefficient(Int_t basisIndex) const override ; + /// Get the cnvolution variable. + RooAbsReal const &getT() const { return _t.arg(); } + + /// Get the decay constant. + RooAbsReal const &getTau() const { return _tau.arg(); } + + /// Get the decay type. + DecayType getDecayType() const { return _type; } + Int_t getGenerator(const RooArgSet& directVars, RooArgSet &generateVars, bool staticInitOK=true) const override; void generateEvent(Int_t code) override; diff --git a/roofit/roofit/inc/RooGaussModel.h b/roofit/roofit/inc/RooGaussModel.h index 3e3d21661abd1..a53bca4f13d84 100644 --- a/roofit/roofit/inc/RooGaussModel.h +++ b/roofit/roofit/inc/RooGaussModel.h @@ -50,7 +50,13 @@ class RooGaussModel : public RooResolutionModel { bool canComputeBatchWithCuda() const override; -protected: + /// Get the mean parameter. + RooAbsReal const &getMean() const { return mean.arg(); } + + /// Get the sigma parameter. + RooAbsReal const &getSigma() const { return sigma.arg(); } + + protected: double evaluate() const override ; static double evaluate(double x, double mean, double sigma, double param1, double param2, int basisCode); diff --git a/roofit/roofitcore/inc/RooAbsAnaConvPdf.h b/roofit/roofitcore/inc/RooAbsAnaConvPdf.h index 9bdb5b18ca424..13fa4f5511ab6 100644 --- a/roofit/roofitcore/inc/RooAbsAnaConvPdf.h +++ b/roofit/roofitcore/inc/RooAbsAnaConvPdf.h @@ -80,6 +80,9 @@ class RooAbsAnaConvPdf : public RooAbsPdf { return const_cast(this)->convVar(); } + /// Get the resolution model. + RooAbsReal const &getModel() const { return _model.arg(); } + std::unique_ptr compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override; protected: diff --git a/roofit/roofitcore/inc/RooDerivative.h b/roofit/roofitcore/inc/RooDerivative.h index 52099602983ed..96307fb08c7a0 100644 --- a/roofit/roofitcore/inc/RooDerivative.h +++ b/roofit/roofitcore/inc/RooDerivative.h @@ -39,7 +39,15 @@ class RooDerivative : public RooAbsReal { TObject* clone(const char* newname=nullptr) const override { return new RooDerivative(*this, newname); } Int_t order() const { return _order ; } + double eps() const { return _eps ; } + + RooArgSet const &getNset() const { return _nset; } + + RooAbsReal const &getX() const { return *_x; } + + RooAbsReal const &getFunc() const { return *_func; } + void setEps(double e) { _eps = e ; } bool redirectServersHook(const RooAbsCollection& /*newServerList*/, bool /*mustReplaceAll*/, bool /*nameChange*/, bool /*isRecursive*/) override ; diff --git a/roofit/roofitcore/inc/RooExtendPdf.h b/roofit/roofitcore/inc/RooExtendPdf.h index 912a34dcdcb41..21706c59dd46c 100644 --- a/roofit/roofitcore/inc/RooExtendPdf.h +++ b/roofit/roofitcore/inc/RooExtendPdf.h @@ -50,7 +50,11 @@ class RooExtendPdf : public RooAbsPdf { RooAbsPdf const& pdf() const { return *_pdf; } -protected: + RooAbsReal const &getN() const { return *_n; } + + TNamed const *getRangeName() const { return _rangeName; } + + protected: RooTemplateProxy _pdf; ///< Input p.d.f RooTemplateProxy _n; ///< Number of expected events diff --git a/roofit/roofitcore/inc/RooFFTConvPdf.h b/roofit/roofitcore/inc/RooFFTConvPdf.h index 305b93fca5190..b544ad033b895 100644 --- a/roofit/roofitcore/inc/RooFFTConvPdf.h +++ b/roofit/roofitcore/inc/RooFFTConvPdf.h @@ -61,8 +61,15 @@ class RooFFTConvPdf : public RooAbsCachedPdf { Int_t getMaxVal(const RooArgSet& vars) const override { return _pdf1.arg().getMaxVal(vars) ; } double maxVal(Int_t code) const override { return _pdf1.arg().maxVal(code) ; } + RooAbsReal const &getConvVar() const { return *_x; } -protected: + RooAbsReal const &getPdf1() const { return *_pdf1; } + + RooAbsReal const &getPdf2() const { return *_pdf2; } + + RooAbsReal const *getPdfConvVar() const { return dynamic_cast(_xprime.absArg()); } + + protected: RooRealProxy _x ; ///< Convolution observable RooRealProxy _xprime ; ///< Input function representing value of convolution observable diff --git a/roofit/roofitcore/src/RooFFTConvPdf.cxx b/roofit/roofitcore/src/RooFFTConvPdf.cxx index cb57f2d473b05..bf7d882a9de28 100644 --- a/roofit/roofitcore/src/RooFFTConvPdf.cxx +++ b/roofit/roofitcore/src/RooFFTConvPdf.cxx @@ -241,8 +241,9 @@ RooFFTConvPdf::RooFFTConvPdf(const char *name, const char *title, RooRealVar &co //////////////////////////////////////////////////////////////////////////////// /// \copydoc RooFFTConvPdf(const char*, const char*, RooRealVar&, RooAbsPdf&, RooAbsPdf&, Int_t) -/// \param[in] pdfConvVar If the variable used for convolution is a PDF, itself, pass the PDF here, and pass the convolution variable to -/// `convVar`. See also rf210_angularconv.C in the roofit tutorials +/// \param[in] pdfConvVar If the variable used for convolution is a function, itself, pass the function here, and pass +/// the convolution variable to `convVar`. See also rf210_angularconv.C in the roofit tutorials RooFFTConvPdf::RooFFTConvPdf(const char *name, const char *title, RooAbsReal &pdfConvVar, RooRealVar &convVar, RooAbsPdf &pdf1, RooAbsPdf &pdf2, Int_t ipOrder)