Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 45 additions & 119 deletions tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -367,28 +367,24 @@ void CheckForNoVariations(const std::string &where, std::string_view definedColV

std::string PrettyPrintAddr(const void *const addr);

std::shared_ptr<RJittedFilter> BookFilterJit(std::shared_ptr<RNodeBase> *prevNodeOnHeap, std::string_view name,
std::shared_ptr<RJittedFilter> BookFilterJit(std::shared_ptr<RNodeBase> prevNode, std::string_view name,
std::string_view expression, const RColumnRegister &colRegister,
TTree *tree, RDataSource *ds);

std::shared_ptr<RJittedDefine> BookDefineJit(std::string_view name, std::string_view expression, RLoopManager &lm,
RDataSource *ds, const RColumnRegister &colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap);
RDataSource *ds, const RColumnRegister &colRegister);

std::shared_ptr<RJittedDefine> BookDefinePerSampleJit(std::string_view name, std::string_view expression,
RLoopManager &lm, const RColumnRegister &colRegister,
std::shared_ptr<RNodeBase> *upcastNodeOnHeap);
RLoopManager &lm, const RColumnRegister &colRegister);

std::shared_ptr<RJittedVariation>
BookVariationJit(const std::vector<std::string> &colNames, std::string_view variationName,
const std::vector<std::string> &variationTags, std::string_view expression, RLoopManager &lm,
RDataSource *ds, const RColumnRegister &colRegister, std::shared_ptr<RNodeBase> *upcastNodeOnHeap,
bool isSingleColumn);
RDataSource *ds, const RColumnRegister &colRegister, bool isSingleColumn);

std::string JitBuildAction(const ColumnNames_t &bl, std::shared_ptr<RDFDetail::RNodeBase> *prevNode,
const std::type_info &art, const std::type_info &at, void *rOnHeap, TTree *tree,
std::string JitBuildAction(const ColumnNames_t &bl, const std::type_info &art, const std::type_info &at, TTree *tree,
const unsigned int nSlots, const RColumnRegister &colRegister, RDataSource *ds,
std::weak_ptr<RJittedAction> *jittedActionOnHeap, const bool vector2RVec = true);
const bool vector2RVec = true);

// Allocate a weak_ptr on the heap, return a pointer to it. The user is responsible for deleting this weak_ptr.
// This function is meant to be used by RInterface's methods that book code for jitting.
Expand Down Expand Up @@ -472,45 +468,32 @@ void AddDSColumns(const std::vector<std::string> &requiredCols, ROOT::Detail::RD
ROOT::Internal::RDF::RColumnRegister &colRegister);

// this function is meant to be called by the jitted code generated by BookFilterJit
template <typename F, typename PrevNode>
void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name,
std::weak_ptr<RJittedFilter> *wkJittedFilter, std::shared_ptr<PrevNode> *prevNodeOnHeap,
RColumnRegister *colRegister) noexcept
template <typename F>
void JitFilterHelper(F &&f, const ColumnNames_t &cols, RColumnRegister &colRegister,
ROOT::Detail::RDF::RLoopManager &lm, ROOT::Detail::RDF::RJittedFilter *jittedFilter) noexcept
{
if (wkJittedFilter->expired()) {
if (!jittedFilter) {
// The branch of the computation graph that needed this jitted code went out of scope between the type
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
delete wkJittedFilter;
delete colRegister;
delete prevNodeOnHeap;
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);
delete[] colsPtr;

const auto jittedFilter = wkJittedFilter->lock();

// mock Filter logic -- validity checks and Define-ition of RDataSource columns
using Callable_t = std::decay_t<F>;
using F_t = RFilter<Callable_t, PrevNode>;
auto prevNode = jittedFilter->MoveOutPrevNode();
using PrevNode_t = typename decltype(prevNode)::element_type;
using F_t = RFilter<Callable_t, PrevNode_t>;
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;
constexpr auto nColumns = ColTypes_t::list_size;
CheckFilter(f);

auto &lm = *jittedFilter->GetLoopManagerUnchecked(); // RLoopManager must exist at this time
auto ds = lm.GetDataSource();

if (ds != nullptr)
AddDSColumns(cols, lm, *ds, ColTypes_t(), *colRegister);
if (ds != nullptr && !cols.empty())
AddDSColumns(cols, lm, *ds, ColTypes_t(), colRegister);

jittedFilter->SetFilter(
std::unique_ptr<RFilterBase>(new F_t(std::forward<F>(f), cols, *prevNodeOnHeap, *colRegister, name)));
// colRegister points to the columns structure in the heap, created before the jitted call so that the jitter can
// share data after it has lazily compiled the code. Here the data has been used and the memory can be freed.
delete colRegister;
delete prevNodeOnHeap;
delete wkJittedFilter;
std::unique_ptr<RFilterBase>(new F_t(std::forward<F>(f), cols, prevNode, colRegister, jittedFilter->GetName())));
}

namespace DefineTypes {
Expand Down Expand Up @@ -538,136 +521,79 @@ auto MakeDefineNode(DefineTypes::RDefinePerSampleTag, std::string_view name, std
// This function is meant to be called by jitted code right before starting the event loop.
// If colsPtr is null, build a RDefinePerSample (it has no input columns), otherwise a RDefine.
template <typename RDefineTypeTag, typename F>
void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name, RLoopManager *lm,
std::weak_ptr<RJittedDefine> *wkJittedDefine, RColumnRegister *colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
void JitDefineHelper(F &&f, const ColumnNames_t &cols, RColumnRegister &colRegister,
ROOT::Detail::RDF::RLoopManager &lm, ROOT::Detail::RDF::RJittedDefine *jittedDefine) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete wkJittedDefine;
delete colRegister;
delete prevNodeOnHeap;
delete[] colsPtr;
};

if (wkJittedDefine->expired()) {

if (!jittedDefine) {
// The branch of the computation graph that needed this jitted code went out of scope between the type
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
doDeletes();
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);

auto jittedDefine = wkJittedDefine->lock();

using Callable_t = std::decay_t<F>;
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;

auto ds = lm->GetDataSource();
if (ds != nullptr && colsPtr)
AddDSColumns(cols, *lm, *ds, ColTypes_t(), *colRegister);
auto ds = lm.GetDataSource();
if (ds != nullptr && !cols.empty())
AddDSColumns(cols, lm, *ds, ColTypes_t(), colRegister);

// will never actually be used (trumped by jittedDefine->GetTypeName()), but we set it to something meaningful
// to help devs debugging
const auto dummyType = "jittedCol_t";
// use unique_ptr<RDefineBase> instead of make_unique<NewCol_t> to reduce jit/compile-times
std::unique_ptr<RDefineBase> newCol{
MakeDefineNode(RDefineTypeTag{}, name, dummyType, std::forward<F>(f), cols, *colRegister, *lm)};
MakeDefineNode(RDefineTypeTag{}, jittedDefine->GetName(), dummyType, std::forward<F>(f), cols, colRegister, lm)};
jittedDefine->SetDefine(std::move(newCol));

doDeletes();
}

template <bool IsSingleColumn, typename F>
void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const char **variedCols,
std::size_t variedColsSize, const char **variationTags, std::size_t variationTagsSize,
std::string_view variationName, RLoopManager *lm,
std::weak_ptr<RJittedVariation> *wkJittedVariation, RColumnRegister *colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
void JitVariationHelper(F &&f, const ColumnNames_t &inputColNames, RColumnRegister &colRegister,
ROOT::Detail::RDF::RLoopManager &lm, RJittedVariation *jittedVariation,
const ColumnNames_t &variedColNames, const ColumnNames_t &variationTags) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete[] colsPtr;
delete[] variedCols;
delete[] variationTags;

delete wkJittedVariation;
delete colRegister;
delete prevNodeOnHeap;
};

if (wkJittedVariation->expired()) {

if (!jittedVariation) {
// The branch of the computation graph that needed this jitted variation went out of scope between the type
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
doDeletes();
return;
}

const ColumnNames_t inputColNames(colsPtr, colsPtr + colsSize);
std::vector<std::string> variedColNames(variedCols, variedCols + variedColsSize);
std::vector<std::string> tags(variationTags, variationTags + variationTagsSize);

auto jittedVariation = wkJittedVariation->lock();

using Callable_t = std::decay_t<F>;
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;

auto ds = lm->GetDataSource();
if (ds != nullptr)
AddDSColumns(inputColNames, *lm, *ds, ColTypes_t(), *colRegister);
auto ds = lm.GetDataSource();
if (ds != nullptr && !inputColNames.empty())
AddDSColumns(inputColNames, lm, *ds, ColTypes_t(), colRegister);

// use unique_ptr<RDefineBase> instead of make_unique<NewCol_t> to reduce jit/compile-times
std::unique_ptr<RVariationBase> newVariation{new RVariation<std::decay_t<F>, IsSingleColumn>(
std::move(variedColNames), variationName, std::forward<F>(f), std::move(tags), jittedVariation->GetTypeName(),
*colRegister, *lm, inputColNames)};
variedColNames, jittedVariation->GetVariationName(), std::forward<F>(f), variationTags,
jittedVariation->GetTypeName(), colRegister, lm, inputColNames)};
jittedVariation->SetVariation(std::move(newVariation));

doDeletes();
}

/// Convenience function invoked by jitted code to build action nodes at runtime
template <typename ActionTag, typename... ColTypes, typename PrevNodeType, typename HelperArgType>
void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const char **colsPtr, std::size_t colsSize,
const unsigned int nSlots, std::shared_ptr<HelperArgType> *helperArgOnHeap,
std::weak_ptr<RJittedAction> *wkJittedActionOnHeap, RColumnRegister *colRegister) noexcept
template <typename ActionTag, typename... ColTypes, typename HelperArgType>
void CallBuildAction(const ColumnNames_t &cols, RColumnRegister &colRegister, ROOT::Detail::RDF::RLoopManager &lm,
RJittedAction *jittedAction, unsigned int nSlots,
std::shared_ptr<HelperArgType> *helperArg) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete[] colsPtr;
delete helperArgOnHeap;
delete wkJittedActionOnHeap;
// colRegister must be deleted before prevNodeOnHeap because their dtor needs the RLoopManager to be alive
// and prevNodeOnHeap is what keeps it alive if the rest of the computation graph is already out of scope
delete colRegister;
delete prevNodeOnHeap;
};

if (wkJittedActionOnHeap->expired()) {
if (!jittedAction) {
// The branch of the computation graph that needed this jitted variation went out of scope between the type
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
doDeletes();
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);

auto jittedActionOnHeap = wkJittedActionOnHeap->lock();

// if we are here it means we are jitting, if we are jitting the loop manager must be alive
auto &prevNodePtr = *prevNodeOnHeap;
auto &loopManager = *prevNodePtr->GetLoopManagerUnchecked();
using ColTypes_t = TypeList<ColTypes...>;
constexpr auto nColumns = ColTypes_t::list_size;
auto ds = loopManager.GetDataSource();
if (ds != nullptr)
AddDSColumns(cols, loopManager, *ds, ColTypes_t(), *colRegister);

auto actionPtr = BuildAction<ColTypes...>(cols, std::move(*helperArgOnHeap), nSlots, std::move(prevNodePtr),
ActionTag{}, *colRegister);
jittedActionOnHeap->SetAction(std::move(actionPtr));
auto ds = lm.GetDataSource();
if (ds != nullptr && !cols.empty())
AddDSColumns(cols, lm, *ds, ColTypes_t(), colRegister);

doDeletes();
auto actionPtr =
BuildAction<ColTypes...>(cols, *helperArg, nSlots, jittedAction->MoveOutPrevNode(), ActionTag{}, colRegister);
jittedAction->SetAction(std::move(actionPtr));
}

/// The contained `type` alias is `double` if `T == RInferredType`, `U` if `T == std::container<U>`, `T` otherwise.
Expand Down
23 changes: 6 additions & 17 deletions tree/dataframe/inc/ROOT/RDF/RInterface.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,8 @@ public:
/// ~~~
RInterface<RDFDetail::RJittedFilter> Filter(std::string_view expression, std::string_view name = "")
{
// deleted by the jitted call to JitFilterHelper
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
using BaseNodeType_t = typename std::remove_pointer_t<decltype(upcastNodeOnHeap)>::element_type;
RInterface<BaseNodeType_t> upcastInterface(*upcastNodeOnHeap, *fLoopManager, fColRegister);
const auto jittedFilter =
RDFInternal::BookFilterJit(upcastNodeOnHeap, name, expression, fColRegister, nullptr, GetDataSource());
const auto jittedFilter = RDFInternal::BookFilterJit(RDFInternal::UpcastNode(fProxiedPtr), name, expression,
fColRegister, nullptr, GetDataSource());

return RInterface<RDFDetail::RJittedFilter>(std::move(jittedFilter), *fLoopManager, fColRegister);
}
Expand Down Expand Up @@ -538,9 +534,7 @@ public:
RDFInternal::CheckForRedefinition(where, name, fColRegister,
GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});

auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
auto jittedDefine =
RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister, upcastNodeOnHeap);
auto jittedDefine = RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister);

RDFInternal::RColumnRegister newCols(fColRegister);
newCols.AddDefine(std::move(jittedDefine));
Expand Down Expand Up @@ -628,9 +622,7 @@ public:
GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});
RDFInternal::CheckForNoVariations(where, name, fColRegister);

auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
auto jittedDefine =
RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister, upcastNodeOnHeap);
auto jittedDefine = RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister);

RDFInternal::RColumnRegister newCols(fColRegister);
newCols.AddDefine(std::move(jittedDefine));
Expand Down Expand Up @@ -805,9 +797,7 @@ public:
RDFInternal::CheckForRedefinition("DefinePerSample", name, fColRegister,
GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});

auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
auto jittedDefine =
RDFInternal::BookDefinePerSampleJit(name, expression, *fLoopManager, fColRegister, upcastNodeOnHeap);
auto jittedDefine = RDFInternal::BookDefinePerSampleJit(name, expression, *fLoopManager, fColRegister);

RDFInternal::RColumnRegister newCols(fColRegister);
newCols.AddDefine(std::move(jittedDefine));
Expand Down Expand Up @@ -3415,10 +3405,9 @@ private:
throw std::logic_error("A column name was passed to the same Vary invocation multiple times.");
}

auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
auto jittedVariation =
RDFInternal::BookVariationJit(colNames, variationName, variationTags, expression, *fLoopManager,
GetDataSource(), fColRegister, upcastNodeOnHeap, isSingleColumn);
GetDataSource(), fColRegister, isSingleColumn);

RDFInternal::RColumnRegister newColRegister(fColRegister);
newColRegister.AddVariation(std::move(jittedVariation));
Expand Down
20 changes: 8 additions & 12 deletions tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,14 @@ protected:
const auto validColumnNames = GetValidatedColumnNames(realNColumns, columns);
const unsigned int nSlots = fLoopManager->GetNSlots();

auto *helperArgOnHeap = RDFInternal::MakeSharedOnHeap(helperArg);

auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(proxiedPtr));

const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(*fLoopManager, validColumnNames,
fColRegister, proxiedPtr->GetVariations());
auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);

auto toJit = RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType),
typeid(ActionTag), helperArgOnHeap, nullptr, nSlots, fColRegister,
GetDataSource(), jittedActionOnHeap, vector2RVec);
fLoopManager->ToJitExec(toJit);
const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(
*fLoopManager, validColumnNames, fColRegister, proxiedPtr->GetVariations(), proxiedPtr);

auto funcBody = RDFInternal::JitBuildAction(validColumnNames, typeid(HelperArgType), typeid(ActionTag), nullptr,
nSlots, fColRegister, GetDataSource(), vector2RVec);
fLoopManager->RegisterJitHelperCall(funcBody,
std::make_unique<ROOT::Internal::RDF::RColumnRegister>(fColRegister),
validColumnNames, jittedAction, helperArg);
return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
}

Expand Down
6 changes: 5 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/RJittedAction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace ROOT {
namespace Detail {
namespace RDF {
class RMergeableValueBase;
class RNodeBase;
} // namespace RDF
} // namespace Detail
} // namespace ROOT
Expand All @@ -39,10 +40,12 @@ class GraphNode;
class RJittedAction : public RActionBase {
private:
std::unique_ptr<RActionBase> fConcreteAction;
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> fPrevNode;

public:
RJittedAction(RLoopManager &lm, const ROOT::RDF::ColumnNames_t &columns, const RColumnRegister &colRegister,
const std::vector<std::string> &prevVariations);
const std::vector<std::string> &prevVariations,
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> prevNode = nullptr);
~RJittedAction();

void SetAction(std::unique_ptr<RActionBase> a) { fConcreteAction = std::move(a); }
Expand All @@ -67,6 +70,7 @@ public:

std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&results) final;
std::unique_ptr<ROOT::Internal::RDF::RActionBase> CloneAction(void *newResult) final;
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> MoveOutPrevNode();
};

} // ns RDF
Expand Down
Loading
Loading