Skip to content

Commit f776318

Browse files
committed
fix aggregate alias fetch
1 parent b7f3a46 commit f776318

File tree

1 file changed

+118
-30
lines changed

1 file changed

+118
-30
lines changed

src/alerts/mod.rs

Lines changed: 118 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -741,12 +741,32 @@ fn _get_aggregate_projection(plan: &LogicalPlan) -> Result<String, AlertError> {
741741
pub fn extract_aggregate_aliases(plan: &LogicalPlan) -> Vec<(String, Option<String>)> {
742742
let mut aliases = Vec::new();
743743

744-
if let LogicalPlan::Projection(projection) = plan {
745-
// Check if this projection contains aliased aggregates
746-
for expr in &projection.expr {
747-
if let Some((agg_name, alias)) = extract_alias_from_expr(expr) {
748-
aliases.push((agg_name, alias));
744+
// Handle different logical plan node types
745+
match plan {
746+
LogicalPlan::Projection(projection) => {
747+
for expr in &projection.expr {
748+
if let Some((agg_name, alias)) = extract_alias_from_expr(expr) {
749+
aliases.push((agg_name, alias));
750+
}
751+
}
752+
}
753+
LogicalPlan::Aggregate(aggregate) => {
754+
// Check aggregate expressions directly
755+
for expr in &aggregate.aggr_expr {
756+
if let Some((agg_name, alias)) = extract_alias_from_expr(expr) {
757+
aliases.push((agg_name, alias));
758+
}
749759
}
760+
761+
// Also check group expressions in case they contain aggregates
762+
for expr in &aggregate.group_expr {
763+
if let Some((agg_name, alias)) = extract_alias_from_expr(expr) {
764+
aliases.push((agg_name, alias));
765+
}
766+
}
767+
}
768+
_ => {
769+
// For other node types, continue traversal
750770
}
751771
}
752772

@@ -762,42 +782,43 @@ pub fn extract_aggregate_aliases(plan: &LogicalPlan) -> Vec<(String, Option<Stri
762782
fn extract_alias_from_expr(expr: &Expr) -> Option<(String, Option<String>)> {
763783
match expr {
764784
Expr::Alias(alias_expr) => {
765-
// This is an aliased expression
766785
let alias_name = alias_expr.name.clone();
767786

768-
if let Expr::AggregateFunction(agg_func) = alias_expr.expr.as_ref() {
769-
let agg_name = format!("{:?}", agg_func.func);
787+
// Check if the aliased expression is an aggregate
788+
if let Some((agg_name, _)) = extract_alias_from_expr(&alias_expr.expr) {
770789
Some((agg_name, Some(alias_name)))
771790
} else {
772-
// Handle other aggregate expressions like Count, etc.
773-
// Check if the inner expression is an aggregate
774-
let expr_str = format!("{:?}", alias_expr.expr);
775-
if expr_str.contains("count")
776-
|| expr_str.contains("sum")
777-
|| expr_str.contains("avg")
778-
|| expr_str.contains("min")
779-
|| expr_str.contains("max")
780-
{
781-
Some((expr_str, Some(alias_name)))
782-
} else {
783-
None
784-
}
791+
None
785792
}
786793
}
787794
Expr::AggregateFunction(agg_func) => {
788-
// Unaliased aggregate function
789795
let agg_name = format!("{:?}", agg_func.func);
790796
Some((agg_name, None))
791797
}
798+
// Handle specific aggregate function variants
799+
Expr::WindowFunction(window_func) => {
800+
// Some aggregates might appear as window functions
801+
let func_name = format!("{:?}", window_func.fun);
802+
if is_aggregate_function(&func_name) {
803+
Some((func_name, None))
804+
} else {
805+
None
806+
}
807+
}
808+
// Handle built-in aggregate functions that might not be AggregateFunction
809+
Expr::ScalarFunction(scalar_func) => {
810+
let func_name = format!("{:?}", scalar_func.func);
811+
if is_aggregate_function(&func_name) {
812+
Some((func_name, None))
813+
} else {
814+
None
815+
}
816+
}
792817
Expr::Column(column_expr) => {
793-
// This might be an un-aliased aggregate expression
794-
if column_expr.name().contains("count")
795-
|| column_expr.name().contains("sum")
796-
|| column_expr.name().contains("avg")
797-
|| column_expr.name().contains("min")
798-
|| column_expr.name().contains("max")
799-
{
800-
Some((column_expr.name.clone(), None))
818+
// Check if column name suggests it's an aggregate result
819+
let column_name = column_expr.name();
820+
if is_likely_aggregate_column(column_name) {
821+
Some((column_name.to_owned(), None))
801822
} else {
802823
None
803824
}
@@ -806,6 +827,73 @@ fn extract_alias_from_expr(expr: &Expr) -> Option<(String, Option<String>)> {
806827
}
807828
}
808829

830+
/// Helper function to determine if a function name represents an aggregate
831+
fn is_aggregate_function(func_name: &str) -> bool {
832+
let lower_func = func_name.to_lowercase();
833+
matches!(
834+
lower_func.as_str(),
835+
"count"
836+
| "sum"
837+
| "avg"
838+
| "mean"
839+
| "min"
840+
| "max"
841+
| "stddev"
842+
| "variance"
843+
| "first"
844+
| "last"
845+
| "array_agg"
846+
| "string_agg"
847+
| "bit_and"
848+
| "bit_or"
849+
| "bit_xor"
850+
) || lower_func.contains("count")
851+
|| lower_func.contains("sum")
852+
|| lower_func.contains("avg")
853+
|| lower_func.contains("min")
854+
|| lower_func.contains("max")
855+
}
856+
857+
/// Helper function to determine if a column name suggests it's an aggregate result
858+
fn is_likely_aggregate_column(column_name: &str) -> bool {
859+
let lower_name = column_name.to_lowercase();
860+
lower_name.starts_with("count_")
861+
|| lower_name.starts_with("sum_")
862+
|| lower_name.starts_with("avg_")
863+
|| lower_name.starts_with("min_")
864+
|| lower_name.starts_with("max_")
865+
|| lower_name.contains("count(")
866+
|| lower_name.contains("sum(")
867+
|| lower_name.contains("avg(")
868+
|| lower_name.contains("min(")
869+
|| lower_name.contains("max(")
870+
}
871+
872+
/// Alternative approach: Walk the entire plan and collect all expressions
873+
pub fn extract_aggregate_aliases_comprehensive(
874+
plan: &LogicalPlan,
875+
) -> Vec<(String, Option<String>)> {
876+
let mut aliases = Vec::new();
877+
collect_all_expressions(plan, &mut aliases);
878+
aliases.dedup(); // Remove duplicates
879+
aliases
880+
}
881+
882+
fn collect_all_expressions(plan: &LogicalPlan, aliases: &mut Vec<(String, Option<String>)>) {
883+
// Collect expressions from current node
884+
let expressions = plan.expressions();
885+
for expr in expressions {
886+
if let Some((agg_name, alias)) = extract_alias_from_expr(&expr) {
887+
aliases.push((agg_name, alias));
888+
}
889+
}
890+
891+
// Recursively process child plans
892+
for input in plan.inputs() {
893+
collect_all_expressions(input, aliases);
894+
}
895+
}
896+
809897
/// Analyze a logical plan to determine if it represents an aggregate query
810898
///
811899
/// Returns the number of aggregate expressions found in the plan

0 commit comments

Comments
 (0)