Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 10 additions & 26 deletions tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,9 @@ BookVariationJit(const std::vector<std::string> &colNames, std::string_view vari
RDataSource *ds, const RColumnRegister &colRegister, std::shared_ptr<RNodeBase> *upcastNodeOnHeap,
bool isSingleColumn);

std::string JitBuildAction(const ColumnNames_t &bl, std::shared_ptr<RDFDetail::RNodeBase> *prevNode,
const std::type_info &art, const std::type_info &at, void *rOnHeap, TTree *tree,
std::string JitBuildAction(const ColumnNames_t &bl, const std::type_info &art, const std::type_info &at, TTree *tree,
const unsigned int nSlots, const RColumnRegister &colRegister, RDataSource *ds,
std::weak_ptr<RJittedAction> *jittedActionOnHeap, const bool vector2RVec = true);
const bool vector2RVec = true);

// Allocate a weak_ptr on the heap, return a pointer to it. The user is responsible for deleting this weak_ptr.
// This function is meant to be used by RInterface's methods that book code for jitting.
Expand Down Expand Up @@ -473,7 +472,7 @@ void AddDSColumns(const std::vector<std::string> &requiredCols, ROOT::Detail::RD

// this function is meant to be called by the jitted code generated by BookFilterJit
template <typename F, typename PrevNode>
void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name,
void JitFilterHelper(F &&f, const ColumnNames_t &cols, std::string_view name,
std::weak_ptr<RJittedFilter> *wkJittedFilter, std::shared_ptr<PrevNode> *prevNodeOnHeap,
RColumnRegister *colRegister) noexcept
{
Expand All @@ -486,9 +485,6 @@ void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);
delete[] colsPtr;

const auto jittedFilter = wkJittedFilter->lock();

// mock Filter logic -- validity checks and Define-ition of RDataSource columns
Expand Down Expand Up @@ -538,7 +534,7 @@ auto MakeDefineNode(DefineTypes::RDefinePerSampleTag, std::string_view name, std
// This function is meant to be called by jitted code right before starting the event loop.
// If colsPtr is null, build a RDefinePerSample (it has no input columns), otherwise a RDefine.
template <typename RDefineTypeTag, typename F>
void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name, RLoopManager *lm,
void JitDefineHelper(F &&f, const ColumnNames_t &cols, std::string_view name, RLoopManager *lm,
std::weak_ptr<RJittedDefine> *wkJittedDefine, RColumnRegister *colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
{
Expand All @@ -547,7 +543,6 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
delete wkJittedDefine;
delete colRegister;
delete prevNodeOnHeap;
delete[] colsPtr;
};

if (wkJittedDefine->expired()) {
Expand All @@ -557,15 +552,13 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);

auto jittedDefine = wkJittedDefine->lock();

using Callable_t = std::decay_t<F>;
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;

auto ds = lm->GetDataSource();
if (ds != nullptr && colsPtr)
if (ds != nullptr)
AddDSColumns(cols, *lm, *ds, ColTypes_t(), *colRegister);

// will never actually be used (trumped by jittedDefine->GetTypeName()), but we set it to something meaningful
Expand All @@ -580,18 +573,14 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
}

template <bool IsSingleColumn, typename F>
void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const char **variedCols,
std::size_t variedColsSize, const char **variationTags, std::size_t variationTagsSize,
std::string_view variationName, RLoopManager *lm,
std::weak_ptr<RJittedVariation> *wkJittedVariation, RColumnRegister *colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
void JitVariationHelper(F &&f, const ColumnNames_t &inputColNames, const ColumnNames_t &variedColNames,
const char **variationTags, std::size_t variationTagsSize, std::string_view variationName,
RLoopManager *lm, std::weak_ptr<RJittedVariation> *wkJittedVariation,
RColumnRegister *colRegister, std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete[] colsPtr;
delete[] variedCols;
delete[] variationTags;

delete wkJittedVariation;
delete colRegister;
delete prevNodeOnHeap;
Expand All @@ -604,8 +593,6 @@ void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const
return;
}

const ColumnNames_t inputColNames(colsPtr, colsPtr + colsSize);
std::vector<std::string> variedColNames(variedCols, variedCols + variedColsSize);
std::vector<std::string> tags(variationTags, variationTags + variationTagsSize);

auto jittedVariation = wkJittedVariation->lock();
Expand All @@ -628,13 +615,12 @@ void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const

/// Convenience function invoked by jitted code to build action nodes at runtime
template <typename ActionTag, typename... ColTypes, typename PrevNodeType, typename HelperArgType>
void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const char **colsPtr, std::size_t colsSize,
void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const ColumnNames_t &cols,
const unsigned int nSlots, std::shared_ptr<HelperArgType> *helperArgOnHeap,
std::weak_ptr<RJittedAction> *wkJittedActionOnHeap, RColumnRegister *colRegister) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete[] colsPtr;
delete helperArgOnHeap;
delete wkJittedActionOnHeap;
// colRegister must be deleted before prevNodeOnHeap because their dtor needs the RLoopManager to be alive
Expand All @@ -650,8 +636,6 @@ void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const char *
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);

auto jittedActionOnHeap = wkJittedActionOnHeap->lock();

// if we are here it means we are jitting, if we are jitting the loop manager must be alive
Expand Down
9 changes: 5 additions & 4 deletions tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,11 @@ protected:
fColRegister, proxiedPtr->GetVariations());
auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);

auto toJit = RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType),
typeid(ActionTag), helperArgOnHeap, nullptr, nSlots, fColRegister,
GetDataSource(), jittedActionOnHeap, vector2RVec);
fLoopManager->ToJitExec(toJit);
auto definesCopy = new RDFInternal::RColumnRegister(fColRegister); // deleted in jitted call
auto funcBody = RDFInternal::JitBuildAction(validColumnNames, typeid(HelperArgType), typeid(ActionTag), nullptr,
nSlots, fColRegister, GetDataSource(), vector2RVec);
fLoopManager->RegisterJitHelperCall(funcBody, upcastNodeOnHeap, definesCopy, validColumnNames, jittedActionOnHeap,
helperArgOnHeap);
return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
}

Expand Down
26 changes: 26 additions & 0 deletions tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class RActionBase;
class RVariationBase;
class RDefinesWithReaders;
class RVariationsWithReaders;
class RColumnRegister;

namespace GraphDrawing {
class GraphCreatorHelper;
Expand Down Expand Up @@ -201,6 +202,27 @@ class RLoopManager : public RNodeBase {
std::set<std::pair<std::string_view, std::unique_ptr<ROOT::Internal::RDF::RVariationsWithReaders>>>
fUniqueVariationsWithReaders;

// deferred function calls to Jitted functions
struct DeferredJitCall {
std::string functionId;
std::shared_ptr<RNodeBase> *prevNodeOnHeap;
ROOT::Internal::RDF::RColumnRegister *colRegister;
std::vector<std::string> colNames;
void *wkJittedNode, *argument;
DeferredJitCall(const std::string &id, std::shared_ptr<RNodeBase> *prevNode,
ROOT::Internal::RDF::RColumnRegister *cols, const std::vector<std::string> &colnames,
void *wkNodePtr, void *arg)
: functionId(id),
prevNodeOnHeap(prevNode),
colRegister(cols),
colNames(colnames),
wkJittedNode(wkNodePtr),
argument(arg)
{
}
};
std::vector<DeferredJitCall> fJitHelperCalls;

public:
RLoopManager(const ColumnNames_t &defaultColumns = {});
RLoopManager(TTree *tree, const ColumnNames_t &defaultBranches);
Expand All @@ -217,6 +239,7 @@ public:
~RLoopManager() override;

void Jit();
void RunDeferredCalls();
RLoopManager *GetLoopManagerUnchecked() final { return this; }
void Run(bool jit = true);
const ColumnNames_t &GetDefaultColumnNames() const;
Expand All @@ -240,6 +263,9 @@ public:
void IncrChildrenCount() final { ++fNChildren; }
void StopProcessing() final { ++fNStopsReceived; }
void ToJitExec(const std::string &) const;
void RegisterJitHelperCall(const std::string &funcBody, std::shared_ptr<RNodeBase> *prevNodeOnHeap,
ROOT::Internal::RDF::RColumnRegister *colRegister,
const std::vector<std::string> &colNames, void *wkJittedPtr, void *argument = nullptr);
void RegisterCallback(ULong64_t everyNEvents, std::function<void(unsigned int)> &&f);
unsigned int GetNRuns() const { return fNRuns; }
bool HasDataSourceColumnReaders(std::string_view col, const std::type_info &ti) const;
Expand Down
Loading