@@ -741,12 +741,32 @@ fn _get_aggregate_projection(plan: &LogicalPlan) -> Result<String, AlertError> {
741
741
pub fn extract_aggregate_aliases ( plan : & LogicalPlan ) -> Vec < ( String , Option < String > ) > {
742
742
let mut aliases = Vec :: new ( ) ;
743
743
744
- if let LogicalPlan :: Projection ( projection) = plan {
745
- // Check if this projection contains aliased aggregates
746
- for expr in & projection. expr {
747
- if let Some ( ( agg_name, alias) ) = extract_alias_from_expr ( expr) {
748
- aliases. push ( ( agg_name, alias) ) ;
744
+ // Handle different logical plan node types
745
+ match plan {
746
+ LogicalPlan :: Projection ( projection) => {
747
+ for expr in & projection. expr {
748
+ if let Some ( ( agg_name, alias) ) = extract_alias_from_expr ( expr) {
749
+ aliases. push ( ( agg_name, alias) ) ;
750
+ }
751
+ }
752
+ }
753
+ LogicalPlan :: Aggregate ( aggregate) => {
754
+ // Check aggregate expressions directly
755
+ for expr in & aggregate. aggr_expr {
756
+ if let Some ( ( agg_name, alias) ) = extract_alias_from_expr ( expr) {
757
+ aliases. push ( ( agg_name, alias) ) ;
758
+ }
749
759
}
760
+
761
+ // Also check group expressions in case they contain aggregates
762
+ for expr in & aggregate. group_expr {
763
+ if let Some ( ( agg_name, alias) ) = extract_alias_from_expr ( expr) {
764
+ aliases. push ( ( agg_name, alias) ) ;
765
+ }
766
+ }
767
+ }
768
+ _ => {
769
+ // For other node types, continue traversal
750
770
}
751
771
}
752
772
@@ -762,42 +782,43 @@ pub fn extract_aggregate_aliases(plan: &LogicalPlan) -> Vec<(String, Option<Stri
762
782
fn extract_alias_from_expr ( expr : & Expr ) -> Option < ( String , Option < String > ) > {
763
783
match expr {
764
784
Expr :: Alias ( alias_expr) => {
765
- // This is an aliased expression
766
785
let alias_name = alias_expr. name . clone ( ) ;
767
786
768
- if let Expr :: AggregateFunction ( agg_func ) = alias_expr . expr . as_ref ( ) {
769
- let agg_name = format ! ( "{:?}" , agg_func . func ) ;
787
+ // Check if the aliased expression is an aggregate
788
+ if let Some ( ( agg_name, _ ) ) = extract_alias_from_expr ( & alias_expr . expr ) {
770
789
Some ( ( agg_name, Some ( alias_name) ) )
771
790
} else {
772
- // Handle other aggregate expressions like Count, etc.
773
- // Check if the inner expression is an aggregate
774
- let expr_str = format ! ( "{:?}" , alias_expr. expr) ;
775
- if expr_str. contains ( "count" )
776
- || expr_str. contains ( "sum" )
777
- || expr_str. contains ( "avg" )
778
- || expr_str. contains ( "min" )
779
- || expr_str. contains ( "max" )
780
- {
781
- Some ( ( expr_str, Some ( alias_name) ) )
782
- } else {
783
- None
784
- }
791
+ None
785
792
}
786
793
}
787
794
Expr :: AggregateFunction ( agg_func) => {
788
- // Unaliased aggregate function
789
795
let agg_name = format ! ( "{:?}" , agg_func. func) ;
790
796
Some ( ( agg_name, None ) )
791
797
}
798
+ // Handle specific aggregate function variants
799
+ Expr :: WindowFunction ( window_func) => {
800
+ // Some aggregates might appear as window functions
801
+ let func_name = format ! ( "{:?}" , window_func. fun) ;
802
+ if is_aggregate_function ( & func_name) {
803
+ Some ( ( func_name, None ) )
804
+ } else {
805
+ None
806
+ }
807
+ }
808
+ // Handle built-in aggregate functions that might not be AggregateFunction
809
+ Expr :: ScalarFunction ( scalar_func) => {
810
+ let func_name = format ! ( "{:?}" , scalar_func. func) ;
811
+ if is_aggregate_function ( & func_name) {
812
+ Some ( ( func_name, None ) )
813
+ } else {
814
+ None
815
+ }
816
+ }
792
817
Expr :: Column ( column_expr) => {
793
- // This might be an un-aliased aggregate expression
794
- if column_expr. name ( ) . contains ( "count" )
795
- || column_expr. name ( ) . contains ( "sum" )
796
- || column_expr. name ( ) . contains ( "avg" )
797
- || column_expr. name ( ) . contains ( "min" )
798
- || column_expr. name ( ) . contains ( "max" )
799
- {
800
- Some ( ( column_expr. name . clone ( ) , None ) )
818
+ // Check if column name suggests it's an aggregate result
819
+ let column_name = column_expr. name ( ) ;
820
+ if is_likely_aggregate_column ( column_name) {
821
+ Some ( ( column_name. to_owned ( ) , None ) )
801
822
} else {
802
823
None
803
824
}
@@ -806,6 +827,73 @@ fn extract_alias_from_expr(expr: &Expr) -> Option<(String, Option<String>)> {
806
827
}
807
828
}
808
829
830
+ /// Helper function to determine if a function name represents an aggregate
831
+ fn is_aggregate_function ( func_name : & str ) -> bool {
832
+ let lower_func = func_name. to_lowercase ( ) ;
833
+ matches ! (
834
+ lower_func. as_str( ) ,
835
+ "count"
836
+ | "sum"
837
+ | "avg"
838
+ | "mean"
839
+ | "min"
840
+ | "max"
841
+ | "stddev"
842
+ | "variance"
843
+ | "first"
844
+ | "last"
845
+ | "array_agg"
846
+ | "string_agg"
847
+ | "bit_and"
848
+ | "bit_or"
849
+ | "bit_xor"
850
+ ) || lower_func. contains ( "count" )
851
+ || lower_func. contains ( "sum" )
852
+ || lower_func. contains ( "avg" )
853
+ || lower_func. contains ( "min" )
854
+ || lower_func. contains ( "max" )
855
+ }
856
+
857
+ /// Helper function to determine if a column name suggests it's an aggregate result
858
+ fn is_likely_aggregate_column ( column_name : & str ) -> bool {
859
+ let lower_name = column_name. to_lowercase ( ) ;
860
+ lower_name. starts_with ( "count_" )
861
+ || lower_name. starts_with ( "sum_" )
862
+ || lower_name. starts_with ( "avg_" )
863
+ || lower_name. starts_with ( "min_" )
864
+ || lower_name. starts_with ( "max_" )
865
+ || lower_name. contains ( "count(" )
866
+ || lower_name. contains ( "sum(" )
867
+ || lower_name. contains ( "avg(" )
868
+ || lower_name. contains ( "min(" )
869
+ || lower_name. contains ( "max(" )
870
+ }
871
+
872
+ /// Alternative approach: Walk the entire plan and collect all expressions
873
+ pub fn extract_aggregate_aliases_comprehensive (
874
+ plan : & LogicalPlan ,
875
+ ) -> Vec < ( String , Option < String > ) > {
876
+ let mut aliases = Vec :: new ( ) ;
877
+ collect_all_expressions ( plan, & mut aliases) ;
878
+ aliases. dedup ( ) ; // Remove duplicates
879
+ aliases
880
+ }
881
+
882
+ fn collect_all_expressions ( plan : & LogicalPlan , aliases : & mut Vec < ( String , Option < String > ) > ) {
883
+ // Collect expressions from current node
884
+ let expressions = plan. expressions ( ) ;
885
+ for expr in expressions {
886
+ if let Some ( ( agg_name, alias) ) = extract_alias_from_expr ( & expr) {
887
+ aliases. push ( ( agg_name, alias) ) ;
888
+ }
889
+ }
890
+
891
+ // Recursively process child plans
892
+ for input in plan. inputs ( ) {
893
+ collect_all_expressions ( input, aliases) ;
894
+ }
895
+ }
896
+
809
897
/// Analyze a logical plan to determine if it represents an aggregate query
810
898
///
811
899
/// Returns the number of aggregate expressions found in the plan
0 commit comments