@@ -267,10 +267,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
267267                 const  core::QueryConfig& config)
268268                 -> std::unique_ptr<Aggregate> {
269269               if  (auto  func = getAggregateFunctionEntry (name)) {
270+                  core::AggregationNode::Step usedStep{
271+                      core::AggregationNode::Step::kPartial };
270272                 if  (!exec::isRawInput (step)) {
271-                    step  = core::AggregationNode::Step::kIntermediate ;
273+                    usedStep  = core::AggregationNode::Step::kIntermediate ;
272274                 }
273-                  auto  fn = func->factory (step, argTypes, resultType, config);
275+                  auto  fn =
276+                      func->factory (usedStep, argTypes, resultType, config);
274277                 VELOX_CHECK_NOT_NULL (fn);
275278                 return  std::make_unique<
276279                     AggregateCompanionAdapter::PartialFunction>(
@@ -409,26 +412,51 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
409412    const  std::vector<AggregateFunctionSignaturePtr>& signatures,
410413    const  AggregateFunctionMetadata& metadata,
411414    bool  overwrite) {
415+   bool  registered = false ;
412416  if  (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures (
413417          signatures)) {
414-     return   registerMergeExtractFunctionWithSuffix ( 
415-         name, signatures, metadata, overwrite);
418+     registered |= 
419+         registerMergeExtractFunctionWithSuffix ( name, signatures, metadata, overwrite);
416420  }
417421
418422  auto  mergeExtractSignatures =
419423      CompanionSignatures::mergeExtractFunctionSignatures (signatures);
420424  if  (mergeExtractSignatures.empty ()) {
421-     return  false ;
425+     return  registered ;
422426  }
423427
424428  auto  mergeExtractFunctionName =
425429      CompanionSignatures::mergeExtractFunctionName (name);
426-   return  registerMergeExtractFunctionInternal (
427-       name,
428-       mergeExtractFunctionName,
429-       std::move (mergeExtractSignatures),
430-       metadata,
431-       overwrite);
430+   registered |=
431+       exec::registerAggregateFunction (
432+           mergeExtractFunctionName,
433+           std::move (mergeExtractSignatures),
434+           [name, mergeExtractFunctionName](
435+               core::AggregationNode::Step /* step*/ 
436+               const  std::vector<TypePtr>& argTypes,
437+               const  TypePtr& resultType,
438+               const  core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
439+             if  (auto  func = getAggregateFunctionEntry (name)) {
440+               auto  fn = func->factory (
441+                   core::AggregationNode::Step::kFinal ,
442+                   argTypes,
443+                   resultType,
444+                   config);
445+               VELOX_CHECK_NOT_NULL (fn);
446+               return  std::make_unique<
447+                   AggregateCompanionAdapter::MergeExtractFunction>(
448+                   std::move (fn), resultType);
449+             }
450+             VELOX_FAIL (
451+                 " Original aggregation function {} not found: {}" 
452+                 name,
453+                 mergeExtractFunctionName);
454+           },
455+           metadata,
456+           /* registerCompanionFunctions*/ false ,
457+           overwrite)
458+           .mainFunction ;
459+   return  registered;
432460}
433461
434462VectorFunctionFactory getVectorFunctionFactory (
0 commit comments