diff --git a/bindings/pyroot/pythonizations/test/CMakeLists.txt b/bindings/pyroot/pythonizations/test/CMakeLists.txt index 14e864693921f..026587b148380 100644 --- a/bindings/pyroot/pythonizations/test/CMakeLists.txt +++ b/bindings/pyroot/pythonizations/test/CMakeLists.txt @@ -172,13 +172,13 @@ if(roofit) endif() if (dataframe) - # std::string_view in CPyCppyy - ROOT_ADD_PYUNITTEST(pyroot_string_view string_view.py) if(NOT MSVC OR win_broken_tests) # Test wrapping Python callables for use in C++ using numba ROOT_ADD_PYUNITTEST(pyroot_numbadeclare numbadeclare.py PYTHON_DEPS numba) ROOT_ADD_PYUNITTEST(pyroot_rdf_filter_pyz rdf_filter_pyz.py PYTHON_DEPS numba) ROOT_ADD_PYUNITTEST(pyroot_rdf_define_pyz rdf_define_pyz.py PYTHON_DEPS numba) + # std::string_view in CPyCppyy + ROOT_ADD_PYUNITTEST(pyroot_string_view string_view.py) endif() endif() diff --git a/tree/dataframe/CMakeLists.txt b/tree/dataframe/CMakeLists.txt index b60854ea4ff7f..b68217e95c65c 100644 --- a/tree/dataframe/CMakeLists.txt +++ b/tree/dataframe/CMakeLists.txt @@ -49,6 +49,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTDataFrame ROOT/RRootDS.hxx ROOT/RSnapshotOptions.hxx ROOT/RTrivialDS.hxx + ROOT/RTTreeDS.hxx ROOT/RDF/ActionHelpers.hxx ROOT/RDF/ColumnReaderUtils.hxx ROOT/RDF/GraphNode.hxx @@ -107,6 +108,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTDataFrame src/RCutFlowReport.cxx src/RDataFrame.cxx src/RDatasetSpec.cxx + src/RDataSource.cxx src/RDFActionHelpers.cxx src/RDFColumnReaderUtils.cxx src/RDFColumnRegister.cxx @@ -129,6 +131,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTDataFrame src/RSample.cxx src/RTreeColumnReader.cxx src/RResultPtr.cxx + src/RTTreeDS.cxx src/RVariationBase.cxx src/RVariationReader.cxx src/RVariationsDescription.cxx diff --git a/tree/dataframe/inc/ROOT/RDF/ActionHelpers.hxx b/tree/dataframe/inc/ROOT/RDF/ActionHelpers.hxx index 69d7dea2c09b2..6a89a77643dbb 100644 --- a/tree/dataframe/inc/ROOT/RDF/ActionHelpers.hxx +++ b/tree/dataframe/inc/ROOT/RDF/ActionHelpers.hxx @@ -52,6 +52,7 @@ #include "ROOT/RNTupleDS.hxx" #include "ROOT/RNTupleWriter.hxx" // for SnapshotRNTupleHelper #endif +#include "ROOT/RTTreeDS.hxx" #include #include @@ -1530,12 +1531,15 @@ class R__CLING_PTRCHECK(off) SnapshotTTreeHelper : public RActionImpl fBranchAddresses; // Addresses of objects associated to output branches RBranchSet fOutputBranches; std::vector fIsDefine; + ROOT::Detail::RDF::RLoopManager *fOutputLoopManager; + ROOT::RDF::RDataSource *fInputDataSource; public: using ColumnTypes_t = TypeList; SnapshotTTreeHelper(std::string_view filename, std::string_view dirname, std::string_view treename, const ColumnNames_t &vbnames, const ColumnNames_t &bnames, const RSnapshotOptions &options, - std::vector &&isDefine) + std::vector &&isDefine, ROOT::Detail::RDF::RLoopManager *loopManager, + ROOT::RDF::RDataSource *inputDataSource) : fFileName(filename), fDirName(dirname), fTreeName(treename), @@ -1544,7 +1548,9 @@ public: fOutputBranchNames(ReplaceDotWithUnderscore(bnames)), fBranches(vbnames.size(), nullptr), fBranchAddresses(vbnames.size(), nullptr), - fIsDefine(std::move(isDefine)) + fIsDefine(std::move(isDefine)), + fOutputLoopManager(loopManager), + fInputDataSource(inputDataSource) { EnsureValidSnapshotTTreeOutput(fOptions, fTreeName, fFileName); } @@ -1571,6 +1577,8 @@ public: { if (r) fInputTree = r->GetTree(); + else if (auto treeDS = dynamic_cast(fInputDataSource)) + fInputTree = treeDS->GetTree(); fBranchAddressesNeedReset = true; } @@ -1650,6 +1658,10 @@ public: // must destroy the TTree first, otherwise TFile will delete it too leading to a double delete fOutputTree.reset(); fOutputFile->Close(); + + // Now connect the data source to the loop manager so it can be used for further processing + auto fullTreeName = fDirName.empty() ? fTreeName : fDirName + '/' + fTreeName; + fOutputLoopManager->SetDataSource(std::make_unique(fullTreeName, fFileName)); } std::string GetActionName() { return "Snapshot"; } @@ -1673,8 +1685,15 @@ public: SnapshotTTreeHelper MakeNew(void *newName, std::string_view /*variation*/ = "nominal") { const std::string finalName = *reinterpret_cast(newName); - return SnapshotTTreeHelper{ - finalName, fDirName, fTreeName, fInputBranchNames, fOutputBranchNames, fOptions, std::vector(fIsDefine)}; + return SnapshotTTreeHelper{finalName, + fDirName, + fTreeName, + fInputBranchNames, + fOutputBranchNames, + fOptions, + std::vector(fIsDefine), + fOutputLoopManager, + fInputDataSource}; } }; @@ -1699,12 +1718,16 @@ class R__CLING_PTRCHECK(off) SnapshotTTreeHelperMT : public RActionImpl> fBranchAddresses; std::vector fOutputBranches; std::vector fIsDefine; + ROOT::Detail::RDF::RLoopManager *fOutputLoopManager; + ROOT::RDF::RDataSource *fInputDataSource; public: using ColumnTypes_t = TypeList; + SnapshotTTreeHelperMT(const unsigned int nSlots, std::string_view filename, std::string_view dirname, std::string_view treename, const ColumnNames_t &vbnames, const ColumnNames_t &bnames, - const RSnapshotOptions &options, std::vector &&isDefine) + const RSnapshotOptions &options, std::vector &&isDefine, + ROOT::Detail::RDF::RLoopManager *loopManager, ROOT::RDF::RDataSource *inputDataSource) : fNSlots(nSlots), fOutputFiles(fNSlots), fOutputTrees(fNSlots), @@ -1719,7 +1742,9 @@ public: fBranches(fNSlots, std::vector(vbnames.size(), nullptr)), fBranchAddresses(fNSlots, std::vector(vbnames.size(), nullptr)), fOutputBranches(fNSlots), - fIsDefine(std::move(isDefine)) + fIsDefine(std::move(isDefine)), + fOutputLoopManager(loopManager), + fInputDataSource(inputDataSource) { EnsureValidSnapshotTTreeOutput(fOptions, fTreeName, fFileName); } @@ -1766,7 +1791,9 @@ public: if (r) { // not an empty-source RDF fInputTrees[slot] = r->GetTree(); - } + } else if (auto treeDS = dynamic_cast(fInputDataSource)) + fInputTrees[slot] = treeDS->GetTree(); + fBranchAddressesNeedReset[slot] = 1; // reset first event flag for this slot } @@ -1855,6 +1882,10 @@ public: // flush all buffers to disk by destroying the TBufferMerger fOutputFiles.clear(); fMerger.reset(); + + // Now connect the data source to the loop manager so it can be used for further processing + auto fullTreeName = fDirName.empty() ? fTreeName : fDirName + '/' + fTreeName; + fOutputLoopManager->SetDataSource(std::make_unique(fullTreeName, fFileName)); } std::string GetActionName() { return "Snapshot"; } @@ -1878,8 +1909,16 @@ public: SnapshotTTreeHelperMT MakeNew(void *newName, std::string_view /*variation*/ = "nominal") { const std::string finalName = *reinterpret_cast(newName); - return SnapshotTTreeHelperMT{fNSlots, finalName, fDirName, fTreeName, - fInputBranchNames, fOutputBranchNames, fOptions, std::vector(fIsDefine)}; + return SnapshotTTreeHelperMT{fNSlots, + finalName, + fDirName, + fTreeName, + fInputBranchNames, + fOutputBranchNames, + fOptions, + std::vector(fIsDefine), + fOutputLoopManager, + fInputDataSource}; } }; @@ -1907,7 +1946,7 @@ class R__CLING_PTRCHECK(off) SnapshotRNTupleHelper : public RActionImpl fOutputFile{nullptr}; RSnapshotOptions fOptions; - ROOT::Detail::RDF::RLoopManager *fLoopManager; + ROOT::Detail::RDF::RLoopManager *fOutputLoopManager; ColumnNames_t fInputFieldNames; // This contains the resolved aliases ColumnNames_t fOutputFieldNames; std::unique_ptr fWriter{nullptr}; @@ -1925,7 +1964,7 @@ public: fDirName(dirname), fNTupleName(ntuplename), fOptions(options), - fLoopManager(lm), + fOutputLoopManager(lm), fInputFieldNames(vfnames), fOutputFieldNames(ReplaceDotWithUnderscore(fnames)), fIsDefine(std::move(isDefine)) @@ -1939,7 +1978,7 @@ public: SnapshotRNTupleHelper &operator=(SnapshotRNTupleHelper &&) = default; ~SnapshotRNTupleHelper() { - if (!fNTupleName.empty() && !fLoopManager->GetDataSource() && fOptions.fLazy) + if (!fNTupleName.empty() && !fOutputLoopManager->GetDataSource() && fOptions.fLazy) Warning("Snapshot", "A lazy Snapshot action was booked but never triggered."); } @@ -1999,7 +2038,7 @@ public: { fWriter.reset(); // We can now set the data source of the loop manager for the RDataFrame that is returned by the Snapshot call. - fLoopManager->SetDataSource( + fOutputLoopManager->SetDataSource( std::make_unique(fDirName + "/" + fNTupleName, fFileName)); } @@ -2029,7 +2068,7 @@ public: fInputFieldNames, fOutputFieldNames, fOptions, - fLoopManager, + fOutputLoopManager, std::vector(fIsDefine)}; } }; diff --git a/tree/dataframe/inc/ROOT/RDF/ColumnReaderUtils.hxx b/tree/dataframe/inc/ROOT/RDF/ColumnReaderUtils.hxx index 76ef767fb9afb..6042b9e29ae21 100644 --- a/tree/dataframe/inc/ROOT/RDF/ColumnReaderUtils.hxx +++ b/tree/dataframe/inc/ROOT/RDF/ColumnReaderUtils.hxx @@ -31,6 +31,8 @@ #include // for typeid #include +class TTreeReader; + namespace ROOT { namespace Internal { namespace RDF { @@ -56,7 +58,7 @@ struct RColumnReadersInfo { /// Create a group of column readers, one per type in the parameter pack. template std::array -GetColumnReaders(unsigned int slot, TTreeReader *r, TypeList, const RColumnReadersInfo &colInfo, +GetColumnReaders(unsigned int slot, TTreeReader *treeReader, TypeList, const RColumnReadersInfo &colInfo, const std::string &variationName = "nominal") { // see RColumnReadersInfo for why we pass these arguments like this rather than directly as function arguments @@ -65,9 +67,10 @@ GetColumnReaders(unsigned int slot, TTreeReader *r, TypeList, const auto &colRegister = colInfo.fColRegister; int i = -1; + std::array ret{ - (++i, GetColumnReader(slot, colRegister.GetReader(slot, colNames[i], variationName, typeid(ColTypes)), lm, r, - colNames[i], typeid(ColTypes)))...}; + (++i, GetColumnReader(slot, colRegister.GetReader(slot, colNames[i], variationName, typeid(ColTypes)), lm, + treeReader, colNames[i], typeid(ColTypes)))...}; return ret; } diff --git a/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx b/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx index 93080dd7d3a4b..a0fe25c076982 100644 --- a/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx +++ b/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx @@ -250,7 +250,8 @@ struct SnapshotHelperArgs { std::string fTreeName; std::vector fOutputColNames; ROOT::RDF::RSnapshotOptions fOptions; - RDFDetail::RLoopManager *fLoopManager; + ROOT::Detail::RDF::RLoopManager *fLoopManager; + ROOT::RDF::RDataSource *fDataSource; bool fToNTuple; }; @@ -266,6 +267,8 @@ BuildAction(const ColumnNames_t &colNames, const std::shared_ptrfTreeName; const auto &outputColNames = snapHelperArgs->fOutputColNames; const auto &options = snapHelperArgs->fOptions; + const auto &lmPtr = snapHelperArgs->fLoopManager; + const auto &dataSource = snapHelperArgs->fDataSource; auto sz = sizeof...(ColTypes); std::vector isDefine(sz); @@ -280,10 +283,8 @@ BuildAction(const ColumnNames_t &colNames, const std::shared_ptr; using Action_t = RAction; - auto loopManager = snapHelperArgs->fLoopManager; - actionPtr.reset(new Action_t( - Helper_t(filename, dirname, treename, colNames, outputColNames, options, loopManager, std::move(isDefine)), + Helper_t(filename, dirname, treename, colNames, outputColNames, options, lmPtr, std::move(isDefine)), colNames, prevNode, colRegister)); } else { // multi-thread snapshot to RNTuple is not yet supported @@ -302,16 +303,16 @@ BuildAction(const ColumnNames_t &colNames, const std::shared_ptr; using Action_t = RAction; - actionPtr.reset( - new Action_t(Helper_t(filename, dirname, treename, colNames, outputColNames, options, std::move(isDefine)), - colNames, prevNode, colRegister)); + actionPtr.reset(new Action_t(Helper_t(filename, dirname, treename, colNames, outputColNames, options, + std::move(isDefine), lmPtr, dataSource), + colNames, prevNode, colRegister)); } else { // multi-thread snapshot using Helper_t = SnapshotTTreeHelperMT; using Action_t = RAction; - actionPtr.reset(new Action_t( - Helper_t(nSlots, filename, dirname, treename, colNames, outputColNames, options, std::move(isDefine)), - colNames, prevNode, colRegister)); + actionPtr.reset(new Action_t(Helper_t(nSlots, filename, dirname, treename, colNames, outputColNames, options, + std::move(isDefine), lmPtr, dataSource), + colNames, prevNode, colRegister)); } } return actionPtr; @@ -412,8 +413,15 @@ std::vector FindUndefinedDSColumns(const ColumnNames_t &requestedCols, con template void AddDSColumnsHelper(const std::string &colName, RLoopManager &lm, RDataSource &ds, RColumnRegister &colRegister) { - if (colRegister.IsDefineOrAlias(colName) || !ds.HasColumn(colName) || - lm.HasDataSourceColumnReaders(colName, typeid(T))) + + if (colRegister.IsDefineOrAlias(colName)) + return; + + if (lm.HasDataSourceColumnReaders(colName, typeid(T))) + return; + + if (!ds.HasColumn(colName) && + lm.GetSuppressErrorsForMissingBranches().find(colName) == lm.GetSuppressErrorsForMissingBranches().end()) return; const auto nSlots = lm.GetNSlots(); @@ -428,7 +436,8 @@ void AddDSColumnsHelper(const std::string &colName, RLoopManager &lm, RDataSourc } else { // using the new GetColumnReaders mechanism // TODO consider changing the interface so we return all of these for all slots in one go for (auto slot = 0u; slot < lm.GetNSlots(); ++slot) - colReaders.emplace_back(ds.GetColumnReaders(slot, colName, typeid(T))); + colReaders.emplace_back( + ROOT::Internal::RDF::CreateColumnReader(ds, slot, colName, typeid(T), /*treeReader*/ nullptr)); } lm.AddDataSourceColumnReaders(colName, std::move(colReaders), typeid(T)); @@ -540,7 +549,7 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str using ColTypes_t = typename TTraits::CallableTraits::arg_types; auto ds = lm->GetDataSource(); - if (ds != nullptr) + if (ds != nullptr && colsPtr) AddDSColumns(cols, *lm, *ds, ColTypes_t(), *colRegister); // will never actually be used (trumped by jittedDefine->GetTypeName()), but we set it to something meaningful @@ -800,8 +809,8 @@ template using InnerValueType_t = typename InnerValueType::type; std::pair, std::vector> -AddSizeBranches(const std::vector &branches, TTree *tree, std::vector &&colsWithoutAliases, - std::vector &&colsWithAliases); +AddSizeBranches(const std::vector &branches, ROOT::RDF::RDataSource *ds, + std::vector &&colsWithoutAliases, std::vector &&colsWithAliases); void RemoveDuplicates(ColumnNames_t &columnNames); diff --git a/tree/dataframe/inc/ROOT/RDF/RInterface.hxx b/tree/dataframe/inc/ROOT/RDF/RInterface.hxx index 7c95ff9e7fbde..7a03e9a466926 100644 --- a/tree/dataframe/inc/ROOT/RDF/RInterface.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RInterface.hxx @@ -345,7 +345,7 @@ public: // For now disable this functionality in case of an empty data source and // the column name was not defined previously. if (ROOT::Internal::RDF::GetDataSourceLabel(*this) == "EmptyDS") - GetValidatedColumnNames(1, columns); + throw std::runtime_error("Unknown column: \"" + std::string(column) + "\""); using F_t = RDFDetail::RFilterWithMissingValues; auto filterPtr = std::make_shared(/*discardEntry*/ true, fProxiedPtr, fColRegister, columns); CheckAndFillDSColumns(columns, TTraits::TypeList{}); @@ -396,7 +396,7 @@ public: // For now disable this functionality in case of an empty data source and // the column name was not defined previously. if (ROOT::Internal::RDF::GetDataSourceLabel(*this) == "EmptyDS") - GetValidatedColumnNames(1, columns); + throw std::runtime_error("Unknown column: \"" + std::string(column) + "\""); using F_t = RDFDetail::RFilterWithMissingValues; auto filterPtr = std::make_shared(/*discardEntry*/ false, fProxiedPtr, fColRegister, columns); CheckAndFillDSColumns(columns, TTraits::TypeList{}); @@ -1329,9 +1329,8 @@ public: auto colListNoAliases = GetValidatedColumnNames(colListNoPoundSizes.size(), colListNoPoundSizes); RDFInternal::CheckForDuplicateSnapshotColumns(colListNoAliases); // like validCols but with missing size branches required by array branches added in the right positions - const auto pairOfColumnLists = - RDFInternal::AddSizeBranches(fLoopManager->GetBranchNames(), fLoopManager->GetTree(), - std::move(colListNoAliases), std::move(colListNoPoundSizes)); + const auto pairOfColumnLists = RDFInternal::AddSizeBranches( + fLoopManager->GetBranchNames(), GetDataSource(), std::move(colListNoAliases), std::move(colListNoPoundSizes)); const auto &colListNoAliasesWithSizeBranches = pairOfColumnLists.first; const auto &colListWithAliasesAndSizeBranches = pairOfColumnLists.second; @@ -1346,7 +1345,7 @@ public: if (options.fOutputFormat == ESnapshotOutputFormat::kRNTuple) { #ifdef R__HAS_ROOT7 - if (fLoopManager->GetTree()) { + if (RDFInternal::GetDataSourceLabel(*this) == "TTreeDS") { throw std::runtime_error("Snapshotting from TTree to RNTuple is not yet supported. The current recommended " "way to convert TTrees to RNTuple is through the RNTupleImporter."); } @@ -1358,7 +1357,7 @@ public: auto snapHelperArgs = std::make_shared(RDFInternal::SnapshotHelperArgs{ std::string(filename), std::string(dirname), std::string(treename), colListWithAliasesAndSizeBranches, - options, newRDF->GetLoopManager(), true /* fToNTuple */}); + options, newRDF->GetLoopManager(), GetDataSource(), true /* fToNTuple */}); // The Snapshot helper will use colListNoAliasesWithSizeBranches (with aliases resolved) as input columns, and // colListWithAliasesAndSizeBranches (still with aliases in it, passed through snapHelperArgs) as output column @@ -1380,21 +1379,15 @@ public: "RSnapshotOptions. Note that this current default behaviour might change in the future."); } - // The CreateLMFromTTree function by default opens the file passed as input - // to check for the presence of the TTree inside. But at this moment the - // filename we are using here corresponds to a file which does not exist yet, - // i.e. the output file of the Snapshot call. Thus, checkFile=false will - // prevent the function from trying to open a non-existent file. - auto newRDF = std::make_shared>(ROOT::Detail::RDF::CreateLMFromTTree( - fullTreeName, filename, /*defaultColumns=*/colListNoPoundSizes, /*checkFile=*/false)); + // We create an RLoopManager without a data source. This needs to be initialised when the output TTree dataset + // has actually been created and written to TFile, i.e. at the end of the Snapshot execution. + auto newRDF = std::make_shared>( + std::make_shared(colListNoAliasesWithSizeBranches)); auto snapHelperArgs = std::make_shared(RDFInternal::SnapshotHelperArgs{ std::string(filename), std::string(dirname), std::string(treename), colListWithAliasesAndSizeBranches, - options, nullptr, false /* fToNTuple */}); + options, newRDF->GetLoopManager(), GetDataSource(), false /* fToRNTuple */}); - // The Snapshot helper will use colListNoAliasesWithSizeBranches (with aliases resolved) as input columns, and - // colListWithAliasesAndSizeBranches (still with aliases in it, passed through snapHelperArgs) as output column - // names. resPtr = CreateAction( colListNoAliasesWithSizeBranches, newRDF, snapHelperArgs, fProxiedPtr, colListNoAliasesWithSizeBranches.size(), options.fVector2RVec); @@ -1426,7 +1419,7 @@ public: auto *tree = fLoopManager->GetTree(); const auto treeBranchNames = tree != nullptr ? ROOT::Internal::TreeUtils::GetTopLevelBranchNames(*tree) : ColumnNames_t{}; - const auto dsColumns = GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{}; + const auto dsColumns = GetDataSource() ? ROOT::Internal::RDF::GetTopLevelFieldNames(*GetDataSource()) : ColumnNames_t{}; // Ignore R_rdf_sizeof_* columns coming from datasources: we don't want to Snapshot those ColumnNames_t dsColumnsWithoutSizeColumns; std::copy_if(dsColumns.begin(), dsColumns.end(), std::back_inserter(dsColumnsWithoutSizeColumns), @@ -3242,7 +3235,7 @@ private: if (options.fOutputFormat == ESnapshotOutputFormat::kRNTuple) { #ifdef R__HAS_ROOT7 - if (fLoopManager->GetTree()) { + if (RDFInternal::GetDataSourceLabel(*this) == "TTreeDS") { throw std::runtime_error("Snapshotting from TTree to RNTuple is not yet supported. The current recommended " "way to convert TTrees to RNTuple is through the RNTupleImporter."); } @@ -3252,7 +3245,7 @@ private: auto snapHelperArgs = std::make_shared(RDFInternal::SnapshotHelperArgs{ std::string(filename), std::string(dirname), std::string(treename), columnListWithoutSizeColumns, options, - newRDF->GetLoopManager(), true /* fToRNTuple */}); + newRDF->GetLoopManager(), GetDataSource(), true /* fToRNTuple */}); // The Snapshot helper will use validCols (with aliases resolved) as input columns, and // columnListWithoutSizeColumns (still with aliases in it, passed through snapHelperArgs) as output column @@ -3273,17 +3266,14 @@ private: "RSnapshotOptions. Note that this current default behaviour might change in the future."); } - // The CreateLMFromTTree function by default opens the file passed as input - // to check for the presence of the TTree inside. But at this moment the - // filename we are using here corresponds to a file which does not exist yet, - // i.e. the output file of the Snapshot call. Thus, checkFile=false will - // prevent the function from trying to open a non-existent file. - auto newRDF = std::make_shared>(ROOT::Detail::RDF::CreateLMFromTTree( - fullTreeName, filename, /*defaultColumns=*/columnListWithoutSizeColumns, /*checkFile=*/false)); - - auto snapHelperArgs = std::make_shared( - RDFInternal::SnapshotHelperArgs{std::string(filename), std::string(dirname), std::string(treename), - columnListWithoutSizeColumns, options, nullptr, false /* fToRNTuple */}); + // We create an RLoopManager without a data source. This needs to be initialised when the output TTree dataset + // has actually been created and written to TFile, i.e. at the end of the Snapshot execution. + auto newRDF = + std::make_shared>(std::make_shared(columnListWithoutSizeColumns)); + + auto snapHelperArgs = std::make_shared(RDFInternal::SnapshotHelperArgs{ + std::string(filename), std::string(dirname), std::string(treename), columnListWithoutSizeColumns, options, + newRDF->GetLoopManager(), GetDataSource(), false /* fToRNTuple */}); // The Snapshot helper will use validCols (with aliases resolved) as input columns, and // columnListWithoutSizeColumns (still with aliases in it, passed through snapHelperArgs) as output column diff --git a/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx b/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx index cb79330ec8faa..f606df2d3a78d 100644 --- a/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx @@ -44,6 +44,7 @@ class RDataSource; } // ns RDF namespace Internal { +class RSlotStack; namespace RDF { std::vector GetBranchNames(TTree &t, bool allowDuplicates = true); @@ -213,7 +214,6 @@ class RLoopManager : public RNodeBase { public: RLoopManager(const ColumnNames_t &defaultColumns = {}); RLoopManager(TTree *tree, const ColumnNames_t &defaultBranches); - RLoopManager(std::unique_ptr tree, const ColumnNames_t &defaultBranches); RLoopManager(ULong64_t nEmptyEntries); RLoopManager(std::unique_ptr ds, const ColumnNames_t &defaultBranches); RLoopManager(ROOT::RDF::Experimental::RDatasetSpec &&spec); @@ -254,12 +254,14 @@ public: void ToJitExec(const std::string &) const; void RegisterCallback(ULong64_t everyNEvents, std::function &&f); unsigned int GetNRuns() const { return fNRuns; } - bool HasDataSourceColumnReaders(const std::string &col, const std::type_info &ti) const; - void AddDataSourceColumnReaders(const std::string &col, std::vector> &&readers, + bool HasDataSourceColumnReaders(std::string_view col, const std::type_info &ti) const; + void AddDataSourceColumnReaders(std::string_view col, std::vector> &&readers, const std::type_info &ti); - RColumnReaderBase *AddTreeColumnReader(unsigned int slot, const std::string &col, + RColumnReaderBase *AddTreeColumnReader(unsigned int slot, std::string_view col, std::unique_ptr &&reader, const std::type_info &ti); - RColumnReaderBase *GetDatasetColumnReader(unsigned int slot, const std::string &col, const std::type_info &ti) const; + RColumnReaderBase *GetDatasetColumnReader(unsigned int slot, std::string_view col, const std::type_info &ti) const; + RColumnReaderBase *AddDataSourceColumnReader(unsigned int slot, std::string_view col, const std::type_info &ti, + TTreeReader *treeReader); /// End of recursive chain of calls, does nothing void AddFilterName(std::vector &) final {} @@ -312,6 +314,13 @@ public: { return fSuppressErrorsForMissingBranches; } + + /// The task run by every thread on the input entry range, for the generic RDataSource. + void DataSourceThreadTask(const std::pair &entryRange, ROOT::Internal::RSlotStack &slotStack, + std::atomic &entryCount); + /// The task run by every thread on an entry range (known by the input TTreeReader), for the TTree data source. + void + TTreeThreadTask(TTreeReader &treeReader, ROOT::Internal::RSlotStack &slotStack, std::atomic &entryCount); }; /// \brief Create an RLoopManager that reads a TChain. diff --git a/tree/dataframe/inc/ROOT/RDF/Utils.hxx b/tree/dataframe/inc/ROOT/RDF/Utils.hxx index 896139182df89..5b168f53d071f 100644 --- a/tree/dataframe/inc/ROOT/RDF/Utils.hxx +++ b/tree/dataframe/inc/ROOT/RDF/Utils.hxx @@ -123,10 +123,13 @@ struct IsVector_t : public std::false_type {}; template struct IsVector_t> : public std::true_type {}; +std::string GetBranchOrLeafTypeName(TTree &t, const std::string &colName); + const std::type_info &TypeName2TypeID(const std::string &name); std::string TypeID2TypeName(const std::type_info &id); +std::string GetTypeNameWithOpts(const ROOT::RDF::RDataSource &df, std::string_view colName, bool vector2RVec); std::string ColumnName2ColumnTypeName(const std::string &colName, TTree *, RDataSource *, RDefineBase *, bool vector2RVec = true); @@ -314,6 +317,14 @@ struct CallGuaranteedOrder { f(std::forward(args)...); } }; + +template +auto MakeAliasedSharedPtr(T *rawPtr) +{ + const static std::shared_ptr fgRawPtrCtrlBlock; + return std::shared_ptr(fgRawPtrCtrlBlock, rawPtr); +} + } // end NS RDF } // end NS Internal } // end NS ROOT diff --git a/tree/dataframe/inc/ROOT/RDataSource.hxx b/tree/dataframe/inc/ROOT/RDataSource.hxx index 288f05913c844..c457f99e1b503 100644 --- a/tree/dataframe/inc/ROOT/RDataSource.hxx +++ b/tree/dataframe/inc/ROOT/RDataSource.hxx @@ -18,13 +18,28 @@ #include // std::transform #include +#include +#include #include #include +#include +#include #include +#include + +// Need to fwd-declare TTreeReader for CreateColumnReader +class TTreeReader; +namespace ROOT::Detail::RDF { +class RLoopManager; +} namespace ROOT { namespace RDF { class RDataSource; +class RSampleInfo; +namespace Experimental { +class RSample; +} } } @@ -71,6 +86,23 @@ public: }; } // ns TDS + +namespace RDF { +std::string GetTypeNameWithOpts(const ROOT::RDF::RDataSource &ds, std::string_view colName, bool vector2RVec); +const std::vector &GetTopLevelFieldNames(const ROOT::RDF::RDataSource &ds); +const std::vector &GetColumnNamesNoDuplicates(const ROOT::RDF::RDataSource &ds); +void CallInitializeWithOpts(ROOT::RDF::RDataSource &ds, const std::set &suppressErrorsForMissingColumns); +std::string DescribeDataset(ROOT::RDF::RDataSource &ds); +ROOT::RDF::RSampleInfo +CreateSampleInfo(const ROOT::RDF::RDataSource &ds, + const std::unordered_map &sampleMap); +void RunFinalChecks(const ROOT::RDF::RDataSource &ds, bool nodesLeftNotRun); +void ProcessMT(ROOT::RDF::RDataSource &ds, ROOT::Detail::RDF::RLoopManager &lm); +std::unique_ptr +CreateColumnReader(ROOT::RDF::RDataSource &ds, unsigned int slot, std::string_view col, const std::type_info &tid, + TTreeReader *treeReader); +} // namespace RDF + } // ns Internal namespace RDF { @@ -117,6 +149,57 @@ protected: unsigned int fNSlots{}; + std::optional> fGlobalEntryRange{}; + + friend std::string ROOT::Internal::RDF::GetTypeNameWithOpts(const RDataSource &, std::string_view, bool); + virtual std::string GetTypeNameWithOpts(std::string_view colName, bool) const { return GetTypeName(colName); } + + friend const std::vector &ROOT::Internal::RDF::GetTopLevelFieldNames(const ROOT::RDF::RDataSource &); + virtual const std::vector &GetTopLevelFieldNames() const { return GetColumnNames(); } + + friend const std::vector & + ROOT::Internal::RDF::GetColumnNamesNoDuplicates(const ROOT::RDF::RDataSource &); + virtual const std::vector &GetColumnNamesNoDuplicates() const { return GetColumnNames(); } + + friend void ROOT::Internal::RDF::CallInitializeWithOpts(ROOT::RDF::RDataSource &, const std::set &); + virtual void InitializeWithOpts(const std::set &) { Initialize(); } + + friend std::string ROOT::Internal::RDF::DescribeDataset(ROOT::RDF::RDataSource &); + virtual std::string DescribeDataset() { return "Dataframe from datasource " + GetLabel(); } + + friend ROOT::RDF::RSampleInfo + ROOT::Internal::RDF::CreateSampleInfo(const ROOT::RDF::RDataSource &, + const std::unordered_map &); + virtual ROOT::RDF::RSampleInfo + CreateSampleInfo(const std::unordered_map &) const; + + friend void ROOT::Internal::RDF::RunFinalChecks(const ROOT::RDF::RDataSource &, bool); + virtual void RunFinalChecks(bool) const {} + + friend void ROOT::Internal::RDF::ProcessMT(RDataSource &, ROOT::Detail::RDF::RLoopManager &); + virtual void ProcessMT(ROOT::Detail::RDF::RLoopManager &); + + friend std::unique_ptr + ROOT::Internal::RDF::CreateColumnReader(ROOT::RDF::RDataSource &, unsigned int, std::string_view, + const std::type_info &, TTreeReader *); + /** + * \brief Creates a column reader for the requested column + * + * In the general case, this is just a redirect to the right GetColumnReaders overload. The signature notably also + * has a TTreeReader * parameter. This is currently necessary to still allow the TTree-based MT scheduling via + * TTreeProcessorMT. We use the TTreeProcessorMT::Process method to launch the same kernel across all threads. In + * each thread task, TTreeProcessorMT creates a thread-local instance of a TTreeReader which is going to read the + * range of events assigned to that task. That TTreeReader instance is what is passed to this method whenever a + * column reader needs to be created in a thread task. In the future this method might be removed by either allowing + * to request a handle to the thread-local TTreeReader instance programmatically from the TTreeProcessorMT, or + * refactoring the TTreeProcessorMT scheduling into RTTreeDS altogether. + */ + virtual std::unique_ptr + CreateColumnReader(unsigned int slot, std::string_view col, const std::type_info &tid, TTreeReader *) + { + return GetColumnReaders(slot, col, tid); + } + public: RDataSource() = default; // Rule of five @@ -242,6 +325,13 @@ public: /// Concrete datasources can override the default implementation. virtual std::string GetLabel() { return "Custom Datasource"; } + /// \brief Restrict processing to a [begin, end) range of entries. + /// \param entryRange The range of entries to process. + virtual void SetGlobalEntryRange(std::pair entryRange) + { + fGlobalEntryRange = std::move(entryRange); + }; + protected: /// type-erased vector of pointers to pointers to column values - one per slot virtual Record_t GetColumnReadersImpl(std::string_view name, const std::type_info &) = 0; diff --git a/tree/dataframe/inc/ROOT/RTTreeDS.hxx b/tree/dataframe/inc/ROOT/RTTreeDS.hxx new file mode 100644 index 0000000000000..a93880be5cbcc --- /dev/null +++ b/tree/dataframe/inc/ROOT/RTTreeDS.hxx @@ -0,0 +1,154 @@ +/** + \file ROOT/RTTreeDS.hxx + \ingroup dataframe + \author Vincenzo Eduardo Padulano + \date 2024-12 +*/ + +/************************************************************************* + * Copyright (C) 1995-2024, Rene Brun and Fons Rademakers. * + * All rights reserved. * + * * + * For the licensing terms see $ROOTSYS/LICENSE. * + * For the list of contributors see $ROOTSYS/README/CREDITS. * + *************************************************************************/ + +#ifndef ROOT_INTERNAL_RDF_RTTREEDS +#define ROOT_INTERNAL_RDF_RTTREEDS + +#include "ROOT/RDataSource.hxx" + +#include +#include +#include +#include +#include + +// Begin forward decls + +namespace ROOT { +class RDataFrame; +} + +namespace ROOT::Detail::RDF { +class RLoopManager; +} + +namespace ROOT::RDF { +class RSampleInfo; +} + +namespace ROOT::RDF::Experimental { +class RSample; +} + +namespace ROOT::TreeUtils { +struct RFriendInfo; +} + +class TChain; +class TDirectory; +class TTree; +class TTreeReader; + +// End forward decls + +namespace ROOT::Internal::RDF { + +class RTTreeDS final : public ROOT::RDF::RDataSource { + std::vector fBranchNamesWithDuplicates{}; + std::vector fBranchNamesWithoutDuplicates{}; + std::vector fTopLevelBranchNames{}; + + std::shared_ptr fTree; + + std::unique_ptr fTreeReader; + + std::vector> fFriends; + + ROOT::RDF::RSampleInfo + CreateSampleInfo(const std::unordered_map &sampleMap) const final; + + void RunFinalChecks(bool nodesLeftNotRun) const final; + + void Setup(std::shared_ptr &&tree, const ROOT::TreeUtils::RFriendInfo *friendInfo = nullptr); + + std::vector> GetTTreeEntryRange(TTree &tree); + std::vector> GetTChainEntryRange(TChain &chain); + +public: + RTTreeDS(std::shared_ptr tree); + RTTreeDS(std::shared_ptr tree, const ROOT::TreeUtils::RFriendInfo &friendInfo); + RTTreeDS(std::string_view treeName, TDirectory *dirPtr); + RTTreeDS(std::string_view treeName, std::string_view fileNameGlob); + RTTreeDS(std::string_view treeName, const std::vector &fileNameGlobs); + + // Rule of five + RTTreeDS(const RTTreeDS &) = delete; + RTTreeDS &operator=(const RTTreeDS &) = delete; + RTTreeDS(RTTreeDS &&) = delete; + RTTreeDS &operator=(RTTreeDS &&) = delete; + ~RTTreeDS() final; // Define destructor where data member types are defined + + void Initialize() final; + + void Finalize() final; + + std::vector> GetEntryRanges() final; + + const std::vector &GetColumnNames() const final { return fBranchNamesWithDuplicates; } + + bool HasColumn(std::string_view colName) const final + { + return std::find(fBranchNamesWithDuplicates.begin(), fBranchNamesWithDuplicates.end(), colName) != + fBranchNamesWithDuplicates.end(); + } + + std::string GetTypeName(std::string_view colName) const final; + + std::string GetTypeNameWithOpts(std::string_view colName, bool vector2RVec) const final; + + bool SetEntry(unsigned int, ULong64_t entry) final; + + Record_t GetColumnReadersImpl(std::string_view /* name */, const std::type_info & /* ti */) final + { + // This datasource uses the newer GetColumnReaders() API + return {}; + } + + std::unique_ptr + GetColumnReaders(unsigned int, std::string_view, const std::type_info &) final + { + // This data source creates column readers via CreateColumnReader + throw std::runtime_error("GetColumnReaders should not be called on this data source, something wrong happened!"); + } + + std::unique_ptr CreateColumnReader(unsigned int slot, std::string_view col, + const std::type_info &tid, + TTreeReader *treeReader) final; + + std::string GetLabel() final { return "TTreeDS"; } + + TTree *GetTree(); + + const std::vector &GetTopLevelFieldNames() const final { return fTopLevelBranchNames; } + + const std::vector &GetColumnNamesNoDuplicates() const final { return fBranchNamesWithoutDuplicates; } + + void InitializeWithOpts(const std::set &suppressErrorsForMissingBranches) final; + + std::string DescribeDataset() final; + + std::string AsString() final { return "TTree data source"; } + + std::size_t GetNFiles() const final; + + void ProcessMT(ROOT::Detail::RDF::RLoopManager &lm) final; +}; + +ROOT::RDataFrame FromTTree(std::string_view treeName, std::string_view fileNameGlob); +ROOT::RDataFrame FromTTree(std::string_view treeName, const std::vector &fileNameGlobs); + +} // namespace ROOT::Internal::RDF + +#endif diff --git a/tree/dataframe/src/RDFColumnReaderUtils.cxx b/tree/dataframe/src/RDFColumnReaderUtils.cxx index e4a3dc6191be2..686d56f783acd 100644 --- a/tree/dataframe/src/RDFColumnReaderUtils.cxx +++ b/tree/dataframe/src/RDFColumnReaderUtils.cxx @@ -1,36 +1,4 @@ #include "ROOT/RDF/ColumnReaderUtils.hxx" -#include "ROOT/RDF/RTreeColumnReader.hxx" - -namespace { -std::tuple -GetCollectionInfo(const std::string &typeName) -{ - const auto beginType = typeName.substr(0, typeName.find_first_of('<') + 1); - - // Find TYPE from ROOT::RVec - if (auto pos = beginType.find("RVec<"); pos != std::string::npos) { - const auto begin = pos + 5; - const auto end = typeName.find_last_of('>'); - const auto innerTypeName = typeName.substr(begin, end - begin); - if (innerTypeName == "bool") - return {true, innerTypeName, ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVecBool}; - else - return {true, innerTypeName, ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVec}; - } - - // Find TYPE from std::array - if (auto pos = beginType.find("array<"); pos != std::string::npos) { - const auto begin = pos + 6; - const auto end = typeName.find_last_of('>'); - const auto arrTemplArgs = typeName.substr(begin, end - begin); - const auto lastComma = arrTemplArgs.find_last_of(','); - return {true, arrTemplArgs.substr(0, lastComma), - ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kStdArray}; - } - - return {false, "", ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVec}; -} -} // namespace ROOT::Detail::RDF::RColumnReaderBase * ROOT::Internal::RDF::GetColumnReader(unsigned int slot, ROOT::Detail::RDF::RColumnReaderBase *defineOrVariationReader, @@ -46,21 +14,5 @@ ROOT::Internal::RDF::GetColumnReader(unsigned int slot, ROOT::Detail::RDF::RColu if (datasetColReader != nullptr) return datasetColReader; - assert(treeReader != nullptr && - "We could not find a reader for this column, this should never happen at this point."); - - // Make a RTreeColumnReader for this column and insert it in RLoopManager's map - auto createColReader = [&]() -> std::unique_ptr { - if (ti == typeid(void)) - return std::make_unique(*treeReader, colName); - - const auto typeName = ROOT::Internal::RDF::TypeID2TypeName(ti); - if (auto &&[toConvert, innerTypeName, collType] = GetCollectionInfo(typeName); toConvert) - return std::make_unique(*treeReader, colName, - innerTypeName, collType); - else - return std::make_unique(*treeReader, colName, typeName); - }; - - return lm.AddTreeColumnReader(slot, std::string(colName), createColReader(), ti); + return lm.AddDataSourceColumnReader(slot, colName, ti, treeReader); } diff --git a/tree/dataframe/src/RDFInterfaceUtils.cxx b/tree/dataframe/src/RDFInterfaceUtils.cxx index ae88fbb86d4b0..0a7ef06320736 100644 --- a/tree/dataframe/src/RDFInterfaceUtils.cxx +++ b/tree/dataframe/src/RDFInterfaceUtils.cxx @@ -9,6 +9,7 @@ *************************************************************************/ #include +#include #include #include #include @@ -997,9 +998,12 @@ void CheckForDuplicateSnapshotColumns(const ColumnNames_t &cols) /// Return copies of colsWithoutAliases and colsWithAliases with size branches for variable-sized array branches added /// in the right positions (i.e. before the array branches that need them). std::pair, std::vector> -AddSizeBranches(const std::vector &branches, TTree *tree, std::vector &&colsWithoutAliases, - std::vector &&colsWithAliases) +AddSizeBranches(const std::vector &branches, ROOT::RDF::RDataSource *ds, + std::vector &&colsWithoutAliases, std::vector &&colsWithAliases) { + TTree *tree{}; + if (auto treeDS = dynamic_cast(ds)) + tree = treeDS->GetTree(); if (!tree) // nothing to do return {std::move(colsWithoutAliases), std::move(colsWithAliases)}; diff --git a/tree/dataframe/src/RDFUtils.cxx b/tree/dataframe/src/RDFUtils.cxx index 0c0ff777a7fa1..6902ad410ce9a 100644 --- a/tree/dataframe/src/RDFUtils.cxx +++ b/tree/dataframe/src/RDFUtils.cxx @@ -12,6 +12,8 @@ #include "ROOT/RDataSource.hxx" #include "ROOT/RDF/RDefineBase.hxx" #include "ROOT/RDF/RLoopManager.hxx" +#include "ROOT/RDF/RSample.hxx" +#include "ROOT/RDF/RSampleInfo.hxx" #include "ROOT/RDF/Utils.hxx" #include "ROOT/RLogger.hxx" #include "RtypesCore.h" @@ -235,7 +237,7 @@ std::string ColumnName2ColumnTypeName(const std::string &colName, TTree *tree, R if (define) { colType = define->GetTypeName(); } else if (ds && ds->HasColumn(colName)) { - colType = ds->GetTypeName(colName); + colType = ROOT::Internal::RDF::GetTypeNameWithOpts(*ds, colName, vector2RVec); } else if (tree) { colType = GetBranchOrLeafTypeName(*tree, colName); if (vector2RVec && TClassEdit::IsSTLCont(colType) == ROOT::ESTLType::kSTLvector) { @@ -461,3 +463,54 @@ auto RStringCache::Insert(const std::string &string) -> decltype(fStrings)::cons } // end NS RDF } // end NS Internal } // end NS ROOT + +std::string +ROOT::Internal::RDF::GetTypeNameWithOpts(const ROOT::RDF::RDataSource &df, std::string_view colName, bool vector2RVec) +{ + return df.GetTypeNameWithOpts(colName, vector2RVec); +} + +const std::vector &ROOT::Internal::RDF::GetTopLevelFieldNames(const ROOT::RDF::RDataSource &df) +{ + return df.GetTopLevelFieldNames(); +} + +const std::vector &ROOT::Internal::RDF::GetColumnNamesNoDuplicates(const ROOT::RDF::RDataSource &df) +{ + return df.GetColumnNamesNoDuplicates(); +} + +void ROOT::Internal::RDF::CallInitializeWithOpts(ROOT::RDF::RDataSource &ds, + const std::set &suppressErrorsForMissingColumns) +{ + ds.InitializeWithOpts(suppressErrorsForMissingColumns); +} + +std::string ROOT::Internal::RDF::DescribeDataset(ROOT::RDF::RDataSource &ds) +{ + return ds.DescribeDataset(); +} + +ROOT::RDF::RSampleInfo ROOT::Internal::RDF::CreateSampleInfo( + const ROOT::RDF::RDataSource &ds, + const std::unordered_map &sampleMap) +{ + return ds.CreateSampleInfo(sampleMap); +} + +void ROOT::Internal::RDF::RunFinalChecks(const ROOT::RDF::RDataSource &ds, bool nodesLeftNotRun) +{ + ds.RunFinalChecks(nodesLeftNotRun); +} + +void ROOT::Internal::RDF::ProcessMT(ROOT::RDF::RDataSource &ds, ROOT::Detail::RDF::RLoopManager &lm) +{ + ds.ProcessMT(lm); +} + +std::unique_ptr +ROOT::Internal::RDF::CreateColumnReader(ROOT::RDF::RDataSource &ds, unsigned int slot, std::string_view col, + const std::type_info &tid, TTreeReader *treeReader) +{ + return ds.CreateColumnReader(slot, col, tid, treeReader); +} diff --git a/tree/dataframe/src/RDataFrame.cxx b/tree/dataframe/src/RDataFrame.cxx index ee37af91a33e1..877becd9a996f 100644 --- a/tree/dataframe/src/RDataFrame.cxx +++ b/tree/dataframe/src/RDataFrame.cxx @@ -11,6 +11,7 @@ #include "ROOT/InternalTreeUtils.hxx" #include "ROOT/RDataFrame.hxx" #include "ROOT/RDataSource.hxx" +#include "ROOT/RTTreeDS.hxx" #include "ROOT/RDF/RDatasetSpec.hxx" #include "ROOT/RDF/RInterface.hxx" #include "ROOT/RDF/RLoopManager.hxx" @@ -1789,19 +1790,9 @@ using ColumnNamesPtr_t = std::shared_ptr; /// booking of actions or transformations. /// \note see ROOT::RDF::RInterface for the documentation of the methods available. RDataFrame::RDataFrame(std::string_view treeName, TDirectory *dirPtr, const ColumnNames_t &defaultColumns) - : RInterface(std::make_shared(nullptr, defaultColumns)) + : RInterface(std::make_shared( + std::make_unique(treeName, dirPtr), defaultColumns)) { - if (!dirPtr) { - auto msg = "Invalid TDirectory!"; - throw std::runtime_error(msg); - } - const std::string treeNameInt(treeName); - auto tree = static_cast(dirPtr->Get(treeNameInt.c_str())); - if (!tree) { - auto msg = "Tree \"" + treeNameInt + "\" cannot be found!"; - throw std::runtime_error(msg); - } - GetProxiedPtr()->SetTree(std::shared_ptr(tree, [](TTree *) {})); } //////////////////////////////////////////////////////////////////////////// diff --git a/tree/dataframe/src/RDataSource.cxx b/tree/dataframe/src/RDataSource.cxx new file mode 100644 index 0000000000000..702593c283ef8 --- /dev/null +++ b/tree/dataframe/src/RDataSource.cxx @@ -0,0 +1,49 @@ +#include +#include +#include +#include + +#ifdef R__USE_IMT +#include +#include +#endif + +ROOT::RDF::RSampleInfo ROOT::RDF::RDataSource::CreateSampleInfo( + const std::unordered_map &) const +{ + // Currently not implemented for the generic data source, only works correctly for TTree. + // TODO: Implement the feature also for the generic data source. + return ROOT::RDF::RSampleInfo{}; +} + +void ROOT::RDF::RDataSource::ProcessMT(ROOT::Detail::RDF::RLoopManager &lm) +{ +#ifdef R__USE_IMT + ROOT::Internal::RSlotStack slotStack(fNSlots); + std::atomic entryCount(0ull); + ROOT::TThreadExecutor pool; + + auto ranges = GetEntryRanges(); + while (!ranges.empty()) { + pool.Foreach( + [&lm, &slotStack, &entryCount](const std::pair &range) { + lm.DataSourceThreadTask(range, slotStack, entryCount); + }, + ranges); + ranges = GetEntryRanges(); + } + + if (fGlobalEntryRange.has_value()) { + auto &&[begin, end] = fGlobalEntryRange.value(); + auto &&processedEntries = entryCount.load(); + if ((end - begin) > processedEntries) { + Warning("RDataFrame::Run", + "RDataFrame stopped processing after %lld entries, whereas an entry range (begin=%lld,end=%lld) was " + "requested. Consider adjusting the end value of the entry range to a maximum of %lld.", + processedEntries, begin, end, begin + processedEntries); + } + } +#else + (void)lm; +#endif +} diff --git a/tree/dataframe/src/RInterface.cxx b/tree/dataframe/src/RInterface.cxx index 655ed57ee00d7..05ac848fc54cd 100644 --- a/tree/dataframe/src/RInterface.cxx +++ b/tree/dataframe/src/RInterface.cxx @@ -50,8 +50,8 @@ std::string ROOT::Internal::RDF::GetDataSourceLabel(const ROOT::RDF::RNode &node { if (node.fLoopManager->GetTree()) { return "TTreeDS"; - } else if (node.GetDataSource()) { - return node.GetDataSource()->GetLabel(); + } else if (auto ds = node.GetDataSource()) { + return ds->GetLabel(); } else { return "EmptyDS"; } diff --git a/tree/dataframe/src/RInterfaceBase.cxx b/tree/dataframe/src/RInterfaceBase.cxx index cab02602939b2..daaac32d588bc 100644 --- a/tree/dataframe/src/RInterfaceBase.cxx +++ b/tree/dataframe/src/RInterfaceBase.cxx @@ -43,74 +43,9 @@ unsigned int ROOT::RDF::RInterfaceBase::GetNFiles() std::string ROOT::RDF::RInterfaceBase::DescribeDataset() const { - // TTree/TChain as input - const auto tree = fLoopManager->GetTree(); - if (tree) { - const auto treeName = tree->GetName(); - const auto isTChain = dynamic_cast(tree) ? true : false; - const auto treeType = isTChain ? "TChain" : "TTree"; - const auto isInMemory = !isTChain && !tree->GetCurrentFile() ? true : false; - const auto friendInfo = ROOT::Internal::TreeUtils::GetFriendInfo(*tree); - const auto hasFriends = friendInfo.fFriendNames.empty() ? false : true; - std::stringstream ss; - ss << "Dataframe from " << treeType; - if (*treeName != 0) { - ss << " " << treeName; - } - if (isInMemory) { - ss << " (in-memory)"; - } else { - const auto files = ROOT::Internal::TreeUtils::GetFileNamesFromTree(*tree); - const auto numFiles = files.size(); - if (numFiles == 1) { - ss << " in file " << files[0]; - } else { - ss << " in files\n"; - for (auto i = 0u; i < numFiles; i++) { - ss << " " << files[i]; - if (i < numFiles - 1) - ss << '\n'; - } - } - } - if (hasFriends) { - const auto numFriends = friendInfo.fFriendNames.size(); - if (numFriends == 1) { - ss << "\nwith friend\n"; - } else { - ss << "\nwith friends\n"; - } - for (auto i = 0u; i < numFriends; i++) { - const auto nameAlias = friendInfo.fFriendNames[i]; - const auto files = friendInfo.fFriendFileNames[i]; - const auto numFiles = files.size(); - const auto subnames = friendInfo.fFriendChainSubNames[i]; - ss << " " << nameAlias.first; - if (nameAlias.first != nameAlias.second) - ss << " (" << nameAlias.second << ")"; - // case: TTree as friend - if (numFiles == 1) { - ss << " " << files[0]; - } - // case: TChain as friend - else { - ss << '\n'; - for (auto j = 0u; j < numFiles; j++) { - ss << " " << subnames[j] << " " << files[j]; - if (j < numFiles - 1) - ss << '\n'; - } - } - if (i < numFriends - 1) - ss << '\n'; - } - } - return ss.str(); - } // Datasource as input - else if (auto dataSource = GetDataSource()) { - const auto datasourceLabel = dataSource->GetLabel(); - return "Dataframe from datasource " + datasourceLabel; + if (auto ds = GetDataSource()) { + return ROOT::Internal::RDF::DescribeDataset(*ds); } // Trivial/empty datasource else { @@ -168,8 +103,8 @@ ROOT::RDF::ColumnNames_t ROOT::RDF::RInterfaceBase::GetColumnNames() allColumns.emplace(bName); } - if (auto dataSource = GetDataSource()) { - for (const auto &s : dataSource->GetColumnNames()) { + if (auto ds = GetDataSource()) { + for (const auto &s : ROOT::Internal::RDF::GetColumnNamesNoDuplicates(*ds)) { if (s.rfind("R_rdf_sizeof", 0) != 0) allColumns.emplace(s); } @@ -361,7 +296,7 @@ bool ROOT::RDF::RInterfaceBase::HasColumn(std::string_view columnName) return true; } - if (GetDataSource() && GetDataSource()->HasColumn(columnName)) + if (auto ds = GetDataSource(); ds->HasColumn(columnName)) return true; return false; diff --git a/tree/dataframe/src/RLoopManager.cxx b/tree/dataframe/src/RLoopManager.cxx index 696964ecf8b20..9da08088d6d21 100644 --- a/tree/dataframe/src/RLoopManager.cxx +++ b/tree/dataframe/src/RLoopManager.cxx @@ -31,6 +31,8 @@ #include "TTreeReader.h" #include "TTree.h" // For MaxTreeSizeRAII. Revert when #6640 will be solved. +#include "ROOT/RTTreeDS.hxx" + #ifdef R__USE_IMT #include "ROOT/TThreadExecutor.hxx" #include "ROOT/TTreeProcessorMT.hxx" @@ -324,13 +326,13 @@ DatasetLogInfo TreeDatasetLogInfo(const TTreeReader &r, unsigned int slot) return {std::move(what), static_cast(entryRange.first), end, slot}; } -auto MakeDatasetColReadersKey(const std::string &colName, const std::type_info &ti) +auto MakeDatasetColReadersKey(std::string_view colName, const std::type_info &ti) { // We use a combination of column name and column type name as the key because in some cases we might end up // with concrete readers that use different types for the same column, e.g. std::vector and RVec here: // df.Sum>("stdVectorBranch"); // df.Sum("stdVectorBranch"); - return colName + ':' + ti.name(); + return std::string(colName) + ':' + ti.name(); } } // anonymous namespace @@ -377,25 +379,15 @@ ROOT::Detail::RDF::RLoopManager::RLoopManager(const ROOT::Detail::RDF::ColumnNam } RLoopManager::RLoopManager(TTree *tree, const ColumnNames_t &defaultBranches) - : fTree(std::shared_ptr(tree, [](TTree *) {})), - fDefaultColumns(defaultBranches), - fNSlots(RDFInternal::GetNSlots()), - fLoopType(ROOT::IsImplicitMTEnabled() ? ELoopType::kROOTFilesMT : ELoopType::kROOTFiles), - fNewSampleNotifier(fNSlots), - fSampleInfos(fNSlots), - fDatasetColumnReaders(fNSlots) -{ -} - -RLoopManager::RLoopManager(std::unique_ptr tree, const ColumnNames_t &defaultBranches) - : fTree(std::move(tree)), - fDefaultColumns(defaultBranches), + : fDefaultColumns(defaultBranches), fNSlots(RDFInternal::GetNSlots()), - fLoopType(ROOT::IsImplicitMTEnabled() ? ELoopType::kROOTFilesMT : ELoopType::kROOTFiles), + fLoopType(ROOT::IsImplicitMTEnabled() ? ELoopType::kDataSourceMT : ELoopType::kDataSource), + fDataSource(std::make_unique(ROOT::Internal::RDF::MakeAliasedSharedPtr(tree))), fNewSampleNotifier(fNSlots), fSampleInfos(fNSlots), fDatasetColumnReaders(fNSlots) { + fDataSource->SetNSlots(fNSlots); } RLoopManager::RLoopManager(ULong64_t nEmptyEntries) @@ -422,7 +414,7 @@ RLoopManager::RLoopManager(std::unique_ptr ds, const ColumnNames_t RLoopManager::RLoopManager(ROOT::RDF::Experimental::RDatasetSpec &&spec) : fNSlots(RDFInternal::GetNSlots()), - fLoopType(ROOT::IsImplicitMTEnabled() ? ELoopType::kROOTFilesMT : ELoopType::kROOTFiles), + fLoopType(ROOT::IsImplicitMTEnabled() ? ELoopType::kDataSourceMT : ELoopType::kDataSource), fNewSampleNotifier(fNSlots), fSampleInfos(fNSlots), fDatasetColumnReaders(fNSlots) @@ -507,14 +499,11 @@ void RLoopManager::ChangeSpec(ROOT::RDF::Experimental::RDatasetSpec &&spec) #endif } } - SetTree(std::move(chain)); - - // Create friends from the specification and connect them to the main chain - const auto &friendInfo = spec.GetFriendInfo(); - fFriends = ROOT::Internal::TreeUtils::MakeFriends(friendInfo); - for (std::size_t i = 0ul; i < fFriends.size(); i++) { - const auto &thisFriendAlias = friendInfo.fFriendNames[i].second; - fTree->AddFriend(fFriends[i].get(), thisFriendAlias.c_str()); + fDataSource = std::make_unique(std::move(chain), spec.GetFriendInfo()); + fDataSource->SetNSlots(fNSlots); + for (unsigned int slot{}; slot < fNSlots; slot++) { + for (auto &v : fDatasetColumnReaders[slot]) + v.second.reset(); } } @@ -716,7 +705,10 @@ void RLoopManager::RunTreeReader() namespace { struct DSRunRAII { ROOT::RDF::RDataSource &fDS; - DSRunRAII(ROOT::RDF::RDataSource &ds) : fDS(ds) { fDS.Initialize(); } + DSRunRAII(ROOT::RDF::RDataSource &ds, const std::set &suppressErrorsForMissingColumns) : fDS(ds) + { + ROOT::Internal::RDF::CallInitializeWithOpts(fDS, suppressErrorsForMissingColumns); + } ~DSRunRAII() { fDS.Finalize(); } }; } // namespace @@ -724,9 +716,12 @@ struct DSRunRAII { struct ROOT::Internal::RDF::RDSRangeRAII { ROOT::Detail::RDF::RLoopManager &fLM; unsigned int fSlot; - RDSRangeRAII(ROOT::Detail::RDF::RLoopManager &lm, unsigned int slot, ULong64_t firstEntry) : fLM(lm), fSlot(slot) + TTreeReader *fTreeReader; + RDSRangeRAII(ROOT::Detail::RDF::RLoopManager &lm, unsigned int slot, ULong64_t firstEntry, + TTreeReader *treeReader = nullptr) + : fLM(lm), fSlot(slot), fTreeReader(treeReader) { - fLM.InitNodeSlots(nullptr, fSlot); + fLM.InitNodeSlots(fTreeReader, fSlot); fLM.GetDataSource()->InitSlot(fSlot, firstEntry); } ~RDSRangeRAII() { fLM.GetDataSource()->FinalizeSlot(fSlot); } @@ -736,11 +731,31 @@ struct ROOT::Internal::RDF::RDSRangeRAII { void RLoopManager::RunDataSource() { assert(fDataSource != nullptr); - DSRunRAII _{*fDataSource}; + // Shortcut if the entry range would result in not reading anything + if (fBeginEntry == fEndEntry) + return; + // Apply global entry range if necessary + if (fBeginEntry != 0 || fEndEntry != std::numeric_limits::max()) + fDataSource->SetGlobalEntryRange(std::make_pair(fBeginEntry, fEndEntry)); + // Initialize data source and book finalization + DSRunRAII _{*fDataSource, fSuppressErrorsForMissingBranches}; + // Ensure cleanup task is always called at the end. Notably, this also resets the column readers for those data + // sources that need it (currently only TTree). RCallCleanUpTask cleanup(*this); - auto ranges = fDataSource->GetEntryRanges(); - while (!ranges.empty() && fNStopsReceived < fNChildren) { - RDSRangeRAII __{*this, 0u, 0ull}; + + // Main event loop. We start with an empty vector of ranges because we need to initialize the nodes and the data + // source before the first call to GetEntryRanges, since it could trigger reading (currently only happens with + // TTree). + std::uint64_t processedEntries{}; + std::vector> ranges{}; + do { + + ROOT::Internal::RDF::RDSRangeRAII __{*this, 0u, 0ull}; + + ranges = fDataSource->GetEntryRanges(); + + fSampleInfos[0] = ROOT::Internal::RDF::CreateSampleInfo(*fDataSource, fSampleMap); + try { for (const auto &range : ranges) { const auto start = range.first; @@ -750,13 +765,24 @@ void RLoopManager::RunDataSource() if (fDataSource->SetEntry(0u, entry)) { RunAndCheckFilters(0u, entry); } + processedEntries++; } } } catch (...) { std::cerr << "RDataFrame::Run: event loop was interrupted\n"; throw; } - ranges = fDataSource->GetEntryRanges(); + + } while (!ranges.empty() && fNStopsReceived < fNChildren); + + ROOT::Internal::RDF::RunFinalChecks(*fDataSource, (fNStopsReceived < fNChildren)); + + if (fEndEntry != std::numeric_limits::max() && + static_cast(fEndEntry - fBeginEntry) > processedEntries) { + Warning("RDataFrame::Run", + "RDataFrame stopped processing after %lu entries, whereas an entry range (begin=%lld,end=%lld) was " + "requested. Consider adjusting the end value of the entry range to a maximum of %lld.", + processedEntries, fBeginEntry, fEndEntry, fBeginEntry + processedEntries); } } @@ -765,36 +791,17 @@ void RLoopManager::RunDataSourceMT() { #ifdef R__USE_IMT assert(fDataSource != nullptr); - ROOT::Internal::RSlotStack slotStack(fNSlots); - ROOT::TThreadExecutor pool; + // Shortcut if the entry range would result in not reading anything + if (fBeginEntry == fEndEntry) + return; + // Apply global entry range if necessary + if (fBeginEntry != 0 || fEndEntry != std::numeric_limits::max()) + fDataSource->SetGlobalEntryRange(std::make_pair(fBeginEntry, fEndEntry)); - // Each task works on a subrange of entries - auto runOnRange = [this, &slotStack](const std::pair &range) { - ROOT::Internal::RSlotStackRAII slotRAII(slotStack); - const auto slot = slotRAII.fSlot; - RDSRangeRAII _{*this, slot, range.first}; - RCallCleanUpTask cleanup(*this, slot); - const auto start = range.first; - const auto end = range.second; - R__LOG_DEBUG(0, RDFLogChannel()) << LogRangeProcessing({fDataSource->GetLabel(), start, end, slot}); - try { - for (auto entry = start; entry < end; ++entry) { - if (fDataSource->SetEntry(slot, entry)) { - RunAndCheckFilters(slot, entry); - } - } - } catch (...) { - std::cerr << "RDataFrame::Run: event loop was interrupted\n"; - throw; - } - }; + DSRunRAII _{*fDataSource, fSuppressErrorsForMissingBranches}; + + ROOT::Internal::RDF::ProcessMT(*fDataSource, *this); - DSRunRAII _{*fDataSource}; - auto ranges = fDataSource->GetEntryRanges(); - while (!ranges.empty()) { - pool.Foreach(runOnRange, ranges); - ranges = fDataSource->GetEntryRanges(); - } #endif // not implemented otherwise (never called) } @@ -930,7 +937,7 @@ void RLoopManager::CleanUpTask(TTreeReader *r, unsigned int slot) for (auto *ptr : fBookedDefines) ptr->FinalizeSlot(slot); - if (fLoopType == ELoopType::kROOTFiles || fLoopType == ELoopType::kROOTFilesMT) { + if (auto ds = GetDataSource(); ds && ds->GetLabel() == "TTreeDS") { // we are reading from a tree/chain and we need to re-create the RTreeColumnReaders at every task // because the TTreeReader object changes at every task for (auto &v : fDatasetColumnReaders[slot]) @@ -1194,20 +1201,24 @@ const ColumnNames_t &RLoopManager::GetBranchNames() if (fValidBranchNames.empty() && fTree) { fValidBranchNames = RDFInternal::GetBranchNames(*fTree, /*allowRepetitions=*/true); } + if (fValidBranchNames.empty() && fDataSource) { + fValidBranchNames = fDataSource->GetColumnNames(); + } return fValidBranchNames; } /// Return true if AddDataSourceColumnReaders was called for column name col. -bool RLoopManager::HasDataSourceColumnReaders(const std::string &col, const std::type_info &ti) const +bool RLoopManager::HasDataSourceColumnReaders(std::string_view col, const std::type_info &ti) const { const auto key = MakeDatasetColReadersKey(col, ti); assert(fDataSource != nullptr); // since data source column readers are always added for all slots at the same time, // if the reader is present for slot 0 we have it for all other slots as well. - return fDatasetColumnReaders[0].find(key) != fDatasetColumnReaders[0].end(); + auto it = fDatasetColumnReaders[0].find(key); + return (it != fDatasetColumnReaders[0].end() && it->second); } -void RLoopManager::AddDataSourceColumnReaders(const std::string &col, +void RLoopManager::AddDataSourceColumnReaders(std::string_view col, std::vector> &&readers, const std::type_info &ti) { @@ -1223,7 +1234,7 @@ void RLoopManager::AddDataSourceColumnReaders(const std::string &col, // Differently from AddDataSourceColumnReaders, this can be called from multiple threads concurrently /// \brief Register a new RTreeColumnReader with this RLoopManager. /// \return A shared pointer to the inserted column reader. -RColumnReaderBase *RLoopManager::AddTreeColumnReader(unsigned int slot, const std::string &col, +RColumnReaderBase *RLoopManager::AddTreeColumnReader(unsigned int slot, std::string_view col, std::unique_ptr &&reader, const std::type_info &ti) { @@ -1236,12 +1247,25 @@ RColumnReaderBase *RLoopManager::AddTreeColumnReader(unsigned int slot, const st return rptr; } +RColumnReaderBase *RLoopManager::AddDataSourceColumnReader(unsigned int slot, std::string_view col, + const std::type_info &ti, TTreeReader *treeReader) +{ + auto &readers = fDatasetColumnReaders[slot]; + const auto key = MakeDatasetColReadersKey(col, ti); + // if a reader for this column and this slot was already there, we are doing something wrong + assert(readers.find(key) == readers.end() || readers[key] == nullptr); + assert(fDataSource && "Missing RDataSource to add column reader."); + + readers[key] = ROOT::Internal::RDF::CreateColumnReader(*fDataSource, slot, col, ti, treeReader); + + return readers[key].get(); +} + RColumnReaderBase * -RLoopManager::GetDatasetColumnReader(unsigned int slot, const std::string &col, const std::type_info &ti) const +RLoopManager::GetDatasetColumnReader(unsigned int slot, std::string_view col, const std::type_info &ti) const { const auto key = MakeDatasetColReadersKey(col, ti); - auto it = fDatasetColumnReaders[slot].find(key); - if (it != fDatasetColumnReaders[slot].end()) + if (auto it = fDatasetColumnReaders[slot].find(key); it != fDatasetColumnReaders[slot].end() && it->second) return it->second.get(); else return nullptr; @@ -1324,11 +1348,9 @@ ROOT::Detail::RDF::CreateLMFromTTree(std::string_view datasetName, std::string_v if (checkFile) { OpenFileWithSanityChecks(fileNameGlob); } - std::string datasetNameInt{datasetName}; - std::string fileNameGlobInt{fileNameGlob}; - auto chain = ROOT::Internal::TreeUtils::MakeChainForMT(datasetNameInt.c_str()); - chain->Add(fileNameGlobInt.c_str()); - auto lm = std::make_shared(std::move(chain), defaultColumns); + + auto dataSource = std::make_unique(datasetName, fileNameGlob); + auto lm = std::make_shared(std::move(dataSource), defaultColumns); return lm; } @@ -1344,11 +1366,8 @@ ROOT::Detail::RDF::CreateLMFromTTree(std::string_view datasetName, const std::ve if (checkFile) { OpenFileWithSanityChecks(fileNameGlobs[0]); } - std::string treeNameInt(datasetName); - auto chain = ROOT::Internal::TreeUtils::MakeChainForMT(treeNameInt); - for (auto &f : fileNameGlobs) - chain->Add(f.c_str()); - auto lm = std::make_shared(std::move(chain), defaultColumns); + auto dataSource = std::make_unique(datasetName, fileNameGlobs); + auto lm = std::make_shared(std::move(dataSource), defaultColumns); return lm; } @@ -1420,3 +1439,81 @@ void ROOT::Detail::RDF::RLoopManager::SetDataSource(std::unique_ptr &entryRange, + ROOT::Internal::RSlotStack &slotStack, + std::atomic &entryCount) +{ +#ifdef R__USE_IMT + ROOT::Internal::RSlotStackRAII slotRAII(slotStack); + const auto &slot = slotRAII.fSlot; + + const auto &[start, end] = entryRange; + const auto nEntries = end - start; + entryCount.fetch_add(nEntries); + + RCallCleanUpTask cleanup(*this, slot); + RDSRangeRAII _{*this, slot, start}; + + R__LOG_DEBUG(0, RDFLogChannel()) << LogRangeProcessing({fDataSource->GetLabel(), start, end, slot}); + + try { + for (auto entry = start; entry < end; ++entry) { + if (fDataSource->SetEntry(slot, entry)) { + RunAndCheckFilters(slot, entry); + } + } + } catch (...) { + std::cerr << "RDataFrame::Run: event loop was interrupted\n"; + throw; + } + fDataSource->FinalizeSlot(slot); +#else + (void)entryRange; + (void)slotStack; + (void)entryCount; +#endif +} + +void ROOT::Detail::RDF::RLoopManager::TTreeThreadTask(TTreeReader &treeReader, ROOT::Internal::RSlotStack &slotStack, + std::atomic &entryCount) +{ +#ifdef R__USE_IMT + ROOT::Internal::RSlotStackRAII slotRAII(slotStack); + const auto &slot = slotRAII.fSlot; + + const auto entryRange = treeReader.GetEntriesRange(); // we trust TTreeProcessorMT to call SetEntriesRange + const auto &[start, end] = entryRange; + const auto nEntries = end - start; + auto count = entryCount.fetch_add(nEntries); + + RDSRangeRAII _{*this, slot, static_cast(start), &treeReader}; + RCallCleanUpTask cleanup(*this, slot, &treeReader); + + R__LOG_DEBUG(0, RDFLogChannel()) << LogRangeProcessing( + {fDataSource->GetLabel(), static_cast(start), static_cast(end), slot}); + try { + // recursive call to check filters and conditionally execute actions + while (validTTreeReaderRead(treeReader)) { + if (fNewSampleNotifier.CheckFlag(slot)) { + UpdateSampleInfo(slot, treeReader); + } + RunAndCheckFilters(slot, count++); + } + } catch (...) { + std::cerr << "RDataFrame::Run: event loop was interrupted\n"; + throw; + } + // fNStopsReceived < fNChildren is always true at the moment as we don't support event loop early quitting in + // multi-thread runs, but it costs nothing to be safe and future-proof in case we add support for that later. + if (treeReader.GetEntryStatus() != TTreeReader::kEntryBeyondEnd && fNStopsReceived < fNChildren) { + // something went wrong in the TTreeReader event loop + throw std::runtime_error("An error was encountered while processing the data. TTreeReader status code is: " + + std::to_string(treeReader.GetEntryStatus())); + } +#else + (void)treeReader; + (void)slotStack; + (void)entryCount; +#endif +} diff --git a/tree/dataframe/src/RTTreeDS.cxx b/tree/dataframe/src/RTTreeDS.cxx new file mode 100644 index 0000000000000..cf2a4c5c4c8ab --- /dev/null +++ b/tree/dataframe/src/RTTreeDS.cxx @@ -0,0 +1,434 @@ +#include + +#include // GetTopLevelBranchNames +#include +#include +#include +#include // GetBranchNames +#include +#include // GetBranchOrLeafTypeName + +#include +#include +#include +#include + +#ifdef R__USE_IMT +#include +#include +#include +#include +#endif + +namespace { +bool ValidRead(TTreeReader::EEntryStatus entryStatus) +{ + switch (entryStatus) { + case TTreeReader::kEntryValid: return true; + case TTreeReader::kIndexedFriendNoMatch: return true; + case TTreeReader::kMissingBranchWhenSwitchingTree: return true; + default: return false; + } +} + +std::tuple +GetCollectionInfo(const std::string &typeName) +{ + const auto beginType = typeName.substr(0, typeName.find_first_of('<') + 1); + + // Find TYPE from ROOT::RVec + if (auto pos = beginType.find("RVec<"); pos != std::string::npos) { + const auto begin = typeName.find_first_of('<', pos) + 1; + const auto end = typeName.find_last_of('>'); + const auto innerTypeName = typeName.substr(begin, end - begin); + if (innerTypeName.find("bool") != std::string::npos) + return {true, innerTypeName, ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVecBool}; + else + return {true, innerTypeName, ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVec}; + } + + // Find TYPE from std::array + if (auto pos = beginType.find("array<"); pos != std::string::npos) { + const auto begin = typeName.find_first_of('<', pos) + 1; + const auto end = typeName.find_last_of('>'); + const auto arrTemplArgs = typeName.substr(begin, end - begin); + const auto lastComma = arrTemplArgs.find_last_of(','); + return {true, arrTemplArgs.substr(0, lastComma), + ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kStdArray}; + } + + return {false, "", ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVec}; +} +} // namespace + +// Destructor is defined here, where the data member types are actually available +ROOT::Internal::RDF::RTTreeDS::~RTTreeDS() = default; + +void ROOT::Internal::RDF::RTTreeDS::Setup(std::shared_ptr &&tree, const ROOT::TreeUtils::RFriendInfo *friendInfo) +{ + fTree = tree; + + if (friendInfo) { + fFriends = ROOT::Internal::TreeUtils::MakeFriends(*friendInfo); + for (std::size_t i = 0ul; i < fFriends.size(); i++) { + const auto &thisFriendAlias = friendInfo->fFriendNames[i].second; + fTree->AddFriend(fFriends[i].get(), thisFriendAlias.c_str()); + } + } + + if (fBranchNamesWithDuplicates.empty()) + fBranchNamesWithDuplicates = ROOT::Internal::RDF::GetBranchNames(*fTree); + if (fBranchNamesWithoutDuplicates.empty()) + fBranchNamesWithoutDuplicates = ROOT::Internal::RDF::GetBranchNames(*fTree, /*allowDuplicates*/ false); + if (fTopLevelBranchNames.empty()) + fTopLevelBranchNames = ROOT::Internal::TreeUtils::GetTopLevelBranchNames(*fTree); +} + +ROOT::Internal::RDF::RTTreeDS::RTTreeDS(std::shared_ptr tree) +{ + assert(tree && "No tree passed to the constructor of RTTreeDS!"); + Setup(std::move(tree)); +} + +ROOT::Internal::RDF::RTTreeDS::RTTreeDS(std::shared_ptr tree, const ROOT::TreeUtils::RFriendInfo &friendInfo) +{ + assert(tree && "No tree passed to the constructor of RTTreeDS!"); + Setup(std::move(tree), &friendInfo); +} + +ROOT::Internal::RDF::RTTreeDS::RTTreeDS(std::string_view treeName, TDirectory *dirPtr) +{ + if (!dirPtr) { + throw std::runtime_error("RDataFrame: invalid TDirectory when constructing the data source."); + } + const std::string treeNameInt(treeName); + auto tree = dirPtr->Get(treeName.data()); + if (!tree) { + throw std::runtime_error("RDataFrame: TTree dataset '" + std::string(treeName) + "' cannot be found in '" + + dirPtr->GetName() + "'."); + } + Setup(ROOT::Internal::RDF::MakeAliasedSharedPtr(tree)); +} + +ROOT::Internal::RDF::RTTreeDS::RTTreeDS(std::string_view treeName, std::string_view fileNameGlob) +{ + std::string treeNameInt{treeName}; + std::string fileNameGlobInt{fileNameGlob}; + auto chain = ROOT::Internal::TreeUtils::MakeChainForMT(treeNameInt.c_str()); + chain->Add(fileNameGlobInt.c_str()); + + Setup(std::move(chain)); +} + +ROOT::Internal::RDF::RTTreeDS::RTTreeDS(std::string_view treeName, const std::vector &fileNameGlobs) +{ + std::string treeNameInt(treeName); + auto chain = ROOT::Internal::TreeUtils::MakeChainForMT(treeNameInt); + for (auto &&f : fileNameGlobs) + chain->Add(f.c_str()); + + Setup(std::move(chain)); +} + +ROOT::RDataFrame ROOT::Internal::RDF::FromTTree(std::string_view treeName, std::string_view fileNameGlob) +{ + return ROOT::RDataFrame(std::make_unique(treeName, fileNameGlob)); +} + +ROOT::RDataFrame +ROOT::Internal::RDF::FromTTree(std::string_view treeName, const std::vector &fileNameGlobs) +{ + return ROOT::RDataFrame(std::make_unique(treeName, fileNameGlobs)); +} + +ROOT::RDF::RSampleInfo ROOT::Internal::RDF::RTTreeDS::CreateSampleInfo( + const std::unordered_map &sampleMap) const +{ + // one GetTree to retrieve the TChain, another to retrieve the underlying TTree + auto *tree = fTreeReader->GetTree()->GetTree(); + // tree might be missing e.g. when a file in a chain does not exist + if (!tree) + return ROOT::RDF::RSampleInfo{}; + + const std::string treename = ROOT::Internal::TreeUtils::GetTreeFullPaths(*tree)[0]; + auto *file = tree->GetCurrentFile(); + const std::string fname = file != nullptr ? file->GetName() : "#inmemorytree#"; + + std::pair range = fTreeReader->GetEntriesRange(); + R__ASSERT(range.first >= 0); + if (range.second == -1) { + range.second = tree->GetEntries(); // convert '-1', i.e. 'until the end', to the actual entry number + } + // If the tree is stored in a subdirectory, treename will be the full path to it starting with the root directory '/' + const std::string &id = fname + (treename.rfind('/', 0) == 0 ? "" : "/") + treename; + if (sampleMap.empty()) { + return RSampleInfo(id, range); + } else { + if (sampleMap.find(id) == sampleMap.end()) + throw std::runtime_error("Full sample identifier '" + id + "' cannot be found in the available samples."); + return RSampleInfo(id, range, sampleMap.at(id)); + } +} + +void ROOT::Internal::RDF::RTTreeDS::ProcessMT(ROOT::Detail::RDF::RLoopManager &lm) +{ +#ifdef R__USE_IMT + ROOT::Internal::RSlotStack slotStack(fNSlots); + std::atomic entryCount(0ull); + + const auto &entryList = fTree->GetEntryList() ? *fTree->GetEntryList() : TEntryList(); + const auto &suppressErrorsForMissingBranches = lm.GetSuppressErrorsForMissingBranches(); + auto tp{fGlobalEntryRange.has_value() + ? std::make_unique(*fTree, fNSlots, fGlobalEntryRange.value(), + suppressErrorsForMissingBranches) + : std::make_unique(*fTree, entryList, fNSlots, suppressErrorsForMissingBranches)}; + + tp->Process([&lm, &slotStack, &entryCount](TTreeReader &treeReader) { + lm.TTreeThreadTask(treeReader, slotStack, entryCount); + }); + + if (fGlobalEntryRange.has_value()) { + auto &&[begin, end] = fGlobalEntryRange.value(); + auto &&processedEntries = entryCount.load(); + if ((end - begin) > processedEntries) { + Warning("RDataFrame::Run", + "RDataFrame stopped processing after %lld entries, whereas an entry range (begin=%lld,end=%lld) was " + "requested. Consider adjusting the end value of the entry range to a maximum of %lld.", + processedEntries, begin, end, begin + processedEntries); + } + } +#else + (void)lm; +#endif +} + +std::size_t ROOT::Internal::RDF::RTTreeDS::GetNFiles() const +{ + assert(fTree && "The internal TTree is not available, something went wrong."); + if (dynamic_cast(fTree.get())) + return ROOT::Internal::TreeUtils::GetFileNamesFromTree(*fTree).size(); + + return fTree->GetCurrentFile() ? 1 : 0; +} + +std::string ROOT::Internal::RDF::RTTreeDS::DescribeDataset() +{ + assert(fTree && "The internal TTree is not available, something went wrong."); + const auto treeName = fTree->GetName(); + const auto isTChain = dynamic_cast(fTree.get()) ? true : false; + const auto treeType = isTChain ? "TChain" : "TTree"; + const auto isInMemory = !isTChain && !fTree->GetCurrentFile() ? true : false; + const auto friendInfo = ROOT::Internal::TreeUtils::GetFriendInfo(*fTree); + const auto hasFriends = friendInfo.fFriendNames.empty() ? false : true; + std::stringstream ss; + ss << "Dataframe from " << treeType; + if (*treeName != 0) { + ss << " " << treeName; + } + if (isInMemory) { + ss << " (in-memory)"; + } else { + const auto files = ROOT::Internal::TreeUtils::GetFileNamesFromTree(*fTree); + const auto numFiles = files.size(); + if (numFiles == 1) { + ss << " in file " << files[0]; + } else { + ss << " in files\n"; + for (auto i = 0u; i < numFiles; i++) { + ss << " " << files[i]; + if (i < numFiles - 1) + ss << '\n'; + } + } + } + if (hasFriends) { + const auto numFriends = friendInfo.fFriendNames.size(); + if (numFriends == 1) { + ss << "\nwith friend\n"; + } else { + ss << "\nwith friends\n"; + } + for (auto i = 0u; i < numFriends; i++) { + const auto nameAlias = friendInfo.fFriendNames[i]; + const auto files = friendInfo.fFriendFileNames[i]; + const auto numFiles = files.size(); + const auto subnames = friendInfo.fFriendChainSubNames[i]; + ss << " " << nameAlias.first; + if (nameAlias.first != nameAlias.second) + ss << " (" << nameAlias.second << ")"; + // case: TTree as friend + if (numFiles == 1) { + ss << " " << files[0]; + } + // case: TChain as friend + else { + ss << '\n'; + for (auto j = 0u; j < numFiles; j++) { + ss << " " << subnames[j] << " " << files[j]; + if (j < numFiles - 1) + ss << '\n'; + } + } + if (i < numFriends - 1) + ss << '\n'; + } + } + return ss.str(); +} + +std::unique_ptr +ROOT::Internal::RDF::RTTreeDS::CreateColumnReader(unsigned int /*slot*/, std::string_view col, const std::type_info &ti, + TTreeReader *treeReader) +{ + // In a single thread run, use the TTreeReader data member. + if (fTreeReader) { + treeReader = fTreeReader.get(); + } + + // The TTreeReader might still not be available if CreateColumnReader was called before the start of the computation + // graph execution, e.g. in AddDSColumns. + if (!treeReader) + return nullptr; + + if (ti == typeid(void)) + return std::make_unique(*treeReader, col); + + const auto typeName = ROOT::Internal::RDF::TypeID2TypeName(ti); + if (auto &&[toConvert, innerTypeName, collType] = GetCollectionInfo(typeName); toConvert) + return std::make_unique(*treeReader, col, innerTypeName, + collType); + else + return std::make_unique(*treeReader, col, typeName); +} + +bool ROOT::Internal::RDF::RTTreeDS::SetEntry(unsigned int, ULong64_t entry) +{ + // The first entry of each tree in a chain is read in GetEntryRanges, we avoid repeating it here + if (fTreeReader->GetCurrentEntry() != static_cast(entry)) + fTreeReader->SetEntry(entry); + return ValidRead(fTreeReader->GetEntryStatus()); +} + +std::string ROOT::Internal::RDF::RTTreeDS::GetTypeNameWithOpts(std::string_view colName, bool vector2RVec) const +{ + auto colTypeName = ROOT::Internal::RDF::GetBranchOrLeafTypeName(*fTree, std::string(colName)); + if (vector2RVec && TClassEdit::IsSTLCont(colTypeName) == ROOT::ESTLType::kSTLvector) { + std::vector split; + int dummy; + TClassEdit::GetSplit(colTypeName.c_str(), split, dummy); + auto &valueType = split[1]; + colTypeName = "ROOT::VecOps::RVec<" + valueType + ">"; + } + return colTypeName; +} + +std::string ROOT::Internal::RDF::RTTreeDS::GetTypeName(std::string_view colName) const +{ + auto colTypeName = ROOT::Internal::RDF::GetBranchOrLeafTypeName(*fTree, std::string(colName)); + if (TClassEdit::IsSTLCont(colTypeName) == ROOT::ESTLType::kSTLvector) { + std::vector split; + int dummy; + TClassEdit::GetSplit(colTypeName.c_str(), split, dummy); + auto &valueType = split[1]; + colTypeName = "ROOT::VecOps::RVec<" + valueType + ">"; + } + return colTypeName; +} + +std::vector> ROOT::Internal::RDF::RTTreeDS::GetTTreeEntryRange(TTree &tree) +{ + // Restrict the range to the global range if available + const ULong64_t rangeBegin = fGlobalEntryRange.has_value() ? std::max(0ull, fGlobalEntryRange->first) : 0ull; + const ULong64_t rangeEnd = fGlobalEntryRange.has_value() + ? std::min(static_cast(tree.GetEntries()), fGlobalEntryRange->second) + : static_cast(tree.GetEntries()); + return std::vector>{{rangeBegin, rangeEnd}}; +} + +std::vector> ROOT::Internal::RDF::RTTreeDS::GetTChainEntryRange(TChain &chain) +{ + // We are either at a complete new beginning (entry == -1) or at the + // end of processing of the previous tree in the chain. Go to the next + // entry, which should always be the first entry in a tree. This allows + // to get the proper tree offset for the range. + fTreeReader->Next(); + if (!ValidRead(fTreeReader->GetEntryStatus())) + return {}; + auto treeOffsets = chain.GetTreeOffset(); + auto treeNumber = chain.GetTreeNumber(); + const ULong64_t thisTreeBegin = treeOffsets[treeNumber]; + const ULong64_t thisTreeEnd = treeOffsets[treeNumber + 1]; + // Restrict the range to the global range if available + const ULong64_t rangeBegin = + fGlobalEntryRange.has_value() ? std::max(thisTreeBegin, fGlobalEntryRange->first) : thisTreeBegin; + const ULong64_t rangeEnd = + fGlobalEntryRange.has_value() ? std::min(thisTreeEnd, fGlobalEntryRange->second) : thisTreeEnd; + return std::vector>{{rangeBegin, rangeEnd}}; +} + +std::vector> ROOT::Internal::RDF::RTTreeDS::GetEntryRanges() +{ + assert(fTreeReader && "TTreeReader is not available, this should never happen."); + auto treeOrChain = fTreeReader->GetTree(); + assert(treeOrChain && "Could not retrieve TTree from TTreeReader, something went wrong."); + + // End of dataset or entry range + if (fTreeReader->GetCurrentEntry() >= treeOrChain->GetEntriesFast() - 1 || + (fGlobalEntryRange.has_value() && + (static_cast(fTreeReader->GetCurrentEntry()) >= fGlobalEntryRange->first && + static_cast(fTreeReader->GetCurrentEntry()) == fGlobalEntryRange->second - 1))) { + // Place the TTreeReader beyond the end of the dataset, so RunFinalChecks can work properly + fTreeReader->Next(); + return {}; + } + + if (auto chain = dynamic_cast(treeOrChain)) { + return GetTChainEntryRange(*chain); + } else { + return GetTTreeEntryRange(*treeOrChain); + } +} + +void ROOT::Internal::RDF::RTTreeDS::Finalize() +{ + // At the end of the event loop, reset the TTreeReader to be ready for + // a possible new run. + if (fTreeReader) + fTreeReader.reset(); +} + +void ROOT::Internal::RDF::RTTreeDS::Initialize() +{ + if (fNSlots == 1) { + assert(!fTreeReader); + fTreeReader = std::make_unique(fTree.get(), fTree->GetEntryList(), /*warnAboutLongerFriends*/ true); + if (fGlobalEntryRange.has_value() && fGlobalEntryRange->first <= std::numeric_limits::max() && + fGlobalEntryRange->second <= std::numeric_limits::max() && fTreeReader && + fTreeReader->SetEntriesRange(fGlobalEntryRange->first, fGlobalEntryRange->second) != + TTreeReader::kEntryValid) { + throw std::logic_error("Something went wrong in initializing the TTreeReader."); + } + } +} + +void ROOT::Internal::RDF::RTTreeDS::InitializeWithOpts(const std::set &suppressErrorsForMissingBranches) +{ + Initialize(); + if (fTreeReader) + fTreeReader->SetSuppressErrorsForMissingBranches(suppressErrorsForMissingBranches); +} + +void ROOT::Internal::RDF::RTTreeDS::RunFinalChecks(bool nodesLeftNotRun) const +{ + if (fTreeReader->GetEntryStatus() != TTreeReader::kEntryBeyondEnd && nodesLeftNotRun) { + // something went wrong in the TTreeReader event loop + throw std::runtime_error("An error was encountered while processing the data. TTreeReader status code is: " + + std::to_string(fTreeReader->GetEntryStatus())); + } +} + +TTree *ROOT::Internal::RDF::RTTreeDS::GetTree() +{ + assert(fTree); + return fTree.get(); +} diff --git a/tree/dataframe/test/dataframe_datasetspec.cxx b/tree/dataframe/test/dataframe_datasetspec.cxx index f7f95d8023217..35878c1326300 100644 --- a/tree/dataframe/test/dataframe_datasetspec.cxx +++ b/tree/dataframe/test/dataframe_datasetspec.cxx @@ -566,7 +566,7 @@ TEST_P(RDatasetSpecTest, SaveGraph) static const std::string expectedGraph( "digraph {\n" "\t1 [label=\"Sum\", style=\"filled\", fillcolor=\"#e47c7e\", shape=\"box\"];\n" - "\t0 [label=\"TChain\", style=\"filled\", fillcolor=\"#f4b400\", shape=\"ellipse\"];\n" + "\t0 [label=\"TTreeDS\", style=\"filled\", fillcolor=\"#f4b400\", shape=\"ellipse\"];\n" "\t2 [label=\"Sum\", style=\"filled\", fillcolor=\"#e47c7e\", shape=\"box\"];\n" "\t0 -> 1;\n" "\t0 -> 2;\n" diff --git a/tree/dataframe/test/dataframe_helpers.cxx b/tree/dataframe/test/dataframe_helpers.cxx index cd258e50cd3de..6c2af519c3f83 100755 --- a/tree/dataframe/test/dataframe_helpers.cxx +++ b/tree/dataframe/test/dataframe_helpers.cxx @@ -358,9 +358,9 @@ TEST(RDFHelpers, SaveGraphRootFromTree) t.Write(); f.Close(); - static const std::string expectedGraph( - "digraph {\n\t1 [label=\"Count\", style=\"filled\", fillcolor=\"#e47c7e\", shape=\"box\"];\n\t0 [label=\"t\", " - "style=\"filled\", fillcolor=\"#f4b400\", shape=\"ellipse\"];\n\t0 -> 1;\n}"); + static const std::string expectedGraph("digraph {\n\t1 [label=\"Count\", style=\"filled\", fillcolor=\"#e47c7e\", " + "shape=\"box\"];\n\t0 [label=\"TTreeDS\", " + "style=\"filled\", fillcolor=\"#f4b400\", shape=\"ellipse\"];\n\t0 -> 1;\n}"); ROOT::RDataFrame df("t", "savegraphrootfromtree.root"); auto c = df.Count(); @@ -381,9 +381,9 @@ TEST(RDFHelpers, SaveGraphToFile) t.Write(); f.Close(); - static const std::string expectedGraph( - "digraph {\n\t1 [label=\"Count\", style=\"filled\", fillcolor=\"#e47c7e\", shape=\"box\"];\n\t0 [label=\"t\", " - "style=\"filled\", fillcolor=\"#f4b400\", shape=\"ellipse\"];\n\t0 -> 1;\n}"); + static const std::string expectedGraph("digraph {\n\t1 [label=\"Count\", style=\"filled\", fillcolor=\"#e47c7e\", " + "shape=\"box\"];\n\t0 [label=\"TTreeDS\", " + "style=\"filled\", fillcolor=\"#f4b400\", shape=\"ellipse\"];\n\t0 -> 1;\n}"); ROOT::RDataFrame df("t", "savegraphtofile.root"); auto c = df.Count(); diff --git a/tree/dataframe/test/dataframe_interface.cxx b/tree/dataframe/test/dataframe_interface.cxx index 1120887071d98..be4043251eb3b 100644 --- a/tree/dataframe/test/dataframe_interface.cxx +++ b/tree/dataframe/test/dataframe_interface.cxx @@ -922,7 +922,7 @@ TEST(RDataFrameInterface, PrintValueFromTree) TTree t("t", "t"); RDataFrame df(t); auto printValue = cling::printValue(&df); - EXPECT_EQ(printValue, "A data frame built on top of the t dataset."); + EXPECT_EQ(printValue, "A data frame associated to the data source \"TTree data source\""); } TEST(RDataFrameInterface, PrintValueNoData) @@ -960,7 +960,10 @@ TEST(RDataFrameInterface, GetNFilesFromTChain) { std::vector filenames{"GetNFilesFromTChain1.root", "GetNFilesFromTChain2.root", "GetNFilesFromTChain3.root"}; - TChain c{"chain"}; + TreeInFileRAII r1{filenames[0]}; + TreeInFileRAII r2{filenames[1]}; + TreeInFileRAII r3{filenames[2]}; + TChain c{"t"}; for (const auto &fn : filenames) c.Add(fn.c_str()); ROOT::RDataFrame df{c}; diff --git a/tree/treeplayer/inc/TTreeReader.h b/tree/treeplayer/inc/TTreeReader.h index 5ba3ed17e3690..d2e35ccf82e5d 100644 --- a/tree/treeplayer/inc/TTreeReader.h +++ b/tree/treeplayer/inc/TTreeReader.h @@ -260,6 +260,11 @@ class TTreeReader : public TObject { /// Return an iterator beyond the last TTree entry. Iterator_t end() { return Iterator_t(*this, -1); } + void SetSuppressErrorsForMissingBranches(const std::set &suppressErrorsForMissingBranches) + { + fSuppressErrorsForMissingBranches = suppressErrorsForMissingBranches; + } + protected: using NamedProxies_t = std::unordered_map>; void Initialize();