@@ -267,10 +267,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
267267 const core::QueryConfig& config)
268268 -> std::unique_ptr<Aggregate> {
269269 if (auto func = getAggregateFunctionEntry (name)) {
270+ core::AggregationNode::Step usedStep{
271+ core::AggregationNode::Step::kPartial };
270272 if (!exec::isRawInput (step)) {
271- step = core::AggregationNode::Step::kIntermediate ;
273+ usedStep = core::AggregationNode::Step::kIntermediate ;
272274 }
273- auto fn = func->factory (step, argTypes, resultType, config);
275+ auto fn =
276+ func->factory (usedStep, argTypes, resultType, config);
274277 VELOX_CHECK_NOT_NULL (fn);
275278 return std::make_unique<
276279 AggregateCompanionAdapter::PartialFunction>(
@@ -409,26 +412,51 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
409412 const std::vector<AggregateFunctionSignaturePtr>& signatures,
410413 const AggregateFunctionMetadata& metadata,
411414 bool overwrite) {
415+ bool registered = false ;
412416 if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures (
413417 signatures)) {
414- return registerMergeExtractFunctionWithSuffix (
415- name, signatures, metadata, overwrite);
418+ registered |=
419+ registerMergeExtractFunctionWithSuffix ( name, signatures, metadata, overwrite);
416420 }
417421
418422 auto mergeExtractSignatures =
419423 CompanionSignatures::mergeExtractFunctionSignatures (signatures);
420424 if (mergeExtractSignatures.empty ()) {
421- return false ;
425+ return registered ;
422426 }
423427
424428 auto mergeExtractFunctionName =
425429 CompanionSignatures::mergeExtractFunctionName (name);
426- return registerMergeExtractFunctionInternal (
427- name,
428- mergeExtractFunctionName,
429- std::move (mergeExtractSignatures),
430- metadata,
431- overwrite);
430+ registered |=
431+ exec::registerAggregateFunction (
432+ mergeExtractFunctionName,
433+ std::move (mergeExtractSignatures),
434+ [name, mergeExtractFunctionName](
435+ core::AggregationNode::Step /* step*/ ,
436+ const std::vector<TypePtr>& argTypes,
437+ const TypePtr& resultType,
438+ const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
439+ if (auto func = getAggregateFunctionEntry (name)) {
440+ auto fn = func->factory (
441+ core::AggregationNode::Step::kFinal ,
442+ argTypes,
443+ resultType,
444+ config);
445+ VELOX_CHECK_NOT_NULL (fn);
446+ return std::make_unique<
447+ AggregateCompanionAdapter::MergeExtractFunction>(
448+ std::move (fn), resultType);
449+ }
450+ VELOX_FAIL (
451+ " Original aggregation function {} not found: {}" ,
452+ name,
453+ mergeExtractFunctionName);
454+ },
455+ metadata,
456+ /* registerCompanionFunctions*/ false ,
457+ overwrite)
458+ .mainFunction ;
459+ return registered;
432460}
433461
434462VectorFunctionFactory getVectorFunctionFactory (
0 commit comments