@@ -136,12 +136,18 @@ impl StatsRewriteRule for BoundedBinaryStatsRewrite {
136136 fn falsify (
137137 & self ,
138138 expr : & Expression ,
139- _ctx : & StatsRewriteCtx < ' _ > ,
139+ ctx : & StatsRewriteCtx < ' _ > ,
140140 ) -> VortexResult < Option < Expression > > {
141141 let operator = expr. as_ :: < Binary > ( ) ;
142142 let lhs = expr. child ( 0 ) ;
143143 let rhs = expr. child ( 1 ) ;
144144
145+ if !supports_bounded_rewrite ( & ctx. return_dtype ( lhs) ?)
146+ || !supports_bounded_rewrite ( & ctx. return_dtype ( rhs) ?)
147+ {
148+ return Ok ( None ) ;
149+ }
150+
145151 Ok ( match operator {
146152 Operator :: Eq => {
147153 let left = gt ( bounded_min ( lhs) , bounded_max ( rhs) ) ;
@@ -432,6 +438,10 @@ fn has_nans(dtype: &DType) -> bool {
432438 matches ! ( dtype, DType :: Primitive ( ptype, _) if ptype. is_float( ) )
433439}
434440
441+ fn supports_bounded_rewrite ( dtype : & DType ) -> bool {
442+ matches ! ( dtype, DType :: Binary ( _) | DType :: Utf8 ( _) )
443+ }
444+
435445fn stat_expr ( expr : & Expression , stat : Stat ) -> Option < Expression > {
436446 if let Some ( literal) = literal_stat ( expr, stat) {
437447 return Some ( literal) ;
@@ -570,6 +580,8 @@ mod tests {
570580 StructFields :: from_iter ( [
571581 ( "a" , DType :: Primitive ( PType :: I32 , Nullability :: NonNullable ) ) ,
572582 ( "b" , DType :: Primitive ( PType :: I32 , Nullability :: NonNullable ) ) ,
583+ ( "s" , DType :: Utf8 ( Nullability :: NonNullable ) ) ,
584+ ( "t" , DType :: Utf8 ( Nullability :: NonNullable ) ) ,
573585 ] ) ,
574586 Nullability :: NonNullable ,
575587 )
@@ -592,29 +604,38 @@ mod tests {
592604 let expr = gt ( col ( "a" ) , lit ( 10 ) ) ;
593605 assert_eq ! (
594606 falsify( & expr) ?,
595- Some ( or(
596- and(
597- nan_free( col( "a" ) ) ,
598- lt_eq( stat( col( "a" ) , Stat :: Max ) , lit( 10 ) ) ,
599- ) ,
600- lt_eq( bounded_max( & col( "a" ) ) , lit( 10 ) ) ,
607+ Some ( and(
608+ nan_free( col( "a" ) ) ,
609+ lt_eq( stat( col( "a" ) , Stat :: Max ) , lit( 10 ) ) ,
601610 ) )
602611 ) ;
603612
604613 let expr = eq ( col ( "a" ) , col ( "b" ) ) ;
614+ assert_eq ! (
615+ falsify( & expr) ?,
616+ Some ( and(
617+ and( nan_free( col( "a" ) ) , nan_free( col( "b" ) ) ) ,
618+ or(
619+ gt( stat( col( "a" ) , Stat :: Min ) , stat( col( "b" ) , Stat :: Max ) ) ,
620+ gt( stat( col( "b" ) , Stat :: Min ) , stat( col( "a" ) , Stat :: Max ) ) ,
621+ ) ,
622+ ) )
623+ ) ;
624+
625+ let expr = eq ( col ( "s" ) , col ( "t" ) ) ;
605626 assert_eq ! (
606627 falsify( & expr) ?,
607628 Some ( or(
608629 and(
609- and( nan_free( col( "a " ) ) , nan_free( col( "b " ) ) ) ,
630+ and( nan_free( col( "s " ) ) , nan_free( col( "t " ) ) ) ,
610631 or(
611- gt( stat( col( "a " ) , Stat :: Min ) , stat( col( "b " ) , Stat :: Max ) ) ,
612- gt( stat( col( "b " ) , Stat :: Min ) , stat( col( "a " ) , Stat :: Max ) ) ,
632+ gt( stat( col( "s " ) , Stat :: Min ) , stat( col( "t " ) , Stat :: Max ) ) ,
633+ gt( stat( col( "t " ) , Stat :: Min ) , stat( col( "s " ) , Stat :: Max ) ) ,
613634 ) ,
614635 ) ,
615636 or(
616- gt( bounded_min( & col( "a " ) ) , bounded_max( & col( "b " ) ) ) ,
617- gt( bounded_min( & col( "b " ) ) , bounded_max( & col( "a " ) ) ) ,
637+ gt( bounded_min( & col( "s " ) ) , bounded_max( & col( "t " ) ) ) ,
638+ gt( bounded_min( & col( "t " ) ) , bounded_max( & col( "s " ) ) ) ,
618639 ) ,
619640 ) )
620641 ) ;
@@ -627,19 +648,13 @@ mod tests {
627648 assert_eq ! (
628649 falsify( & expr) ?,
629650 Some ( or(
630- or(
631- and(
632- nan_free( col( "a" ) ) ,
633- lt_eq( stat( col( "a" ) , Stat :: Max ) , lit( 10 ) ) ,
634- ) ,
635- lt_eq( bounded_max( & col( "a" ) ) , lit( 10 ) ) ,
651+ and(
652+ nan_free( col( "a" ) ) ,
653+ lt_eq( stat( col( "a" ) , Stat :: Max ) , lit( 10 ) ) ,
636654 ) ,
637- or(
638- and(
639- nan_free( col( "a" ) ) ,
640- gt_eq( stat( col( "a" ) , Stat :: Min ) , lit( 50 ) ) ,
641- ) ,
642- gt_eq( bounded_min( & col( "a" ) ) , lit( 50 ) ) ,
655+ and(
656+ nan_free( col( "a" ) ) ,
657+ gt_eq( stat( col( "a" ) , Stat :: Min ) , lit( 50 ) ) ,
643658 ) ,
644659 ) )
645660 ) ;
@@ -661,14 +676,8 @@ mod tests {
661676 assert_eq ! (
662677 falsify( & expr) ?,
663678 Some ( or(
664- or(
665- and( nan_free( col( "a" ) ) , gt( lit( 10 ) , stat( col( "a" ) , Stat :: Max ) ) ) ,
666- gt( lit( 10 ) , bounded_max( & col( "a" ) ) ) ,
667- ) ,
668- or(
669- and( nan_free( col( "a" ) ) , gt( stat( col( "a" ) , Stat :: Min ) , lit( 50 ) ) ) ,
670- gt( bounded_min( & col( "a" ) ) , lit( 50 ) ) ,
671- ) ,
679+ and( nan_free( col( "a" ) ) , gt( lit( 10 ) , stat( col( "a" ) , Stat :: Max ) ) ) ,
680+ and( nan_free( col( "a" ) ) , gt( stat( col( "a" ) , Stat :: Min ) , lit( 50 ) ) ) ,
672681 ) )
673682 ) ;
674683 Ok ( ( ) )
@@ -762,14 +771,8 @@ mod tests {
762771 assert_eq ! (
763772 falsify( & expr) ?,
764773 Some ( or(
765- or(
766- gt( cast( stat( col( "a" ) , Stat :: Min ) , dtype. clone( ) ) , lit( 42i64 ) ) ,
767- gt( lit( 42i64 ) , cast( stat( col( "a" ) , Stat :: Max ) , dtype. clone( ) ) ) ,
768- ) ,
769- or(
770- gt( cast( bounded_min( & col( "a" ) ) , dtype. clone( ) ) , lit( 42i64 ) ) ,
771- gt( lit( 42i64 ) , cast( bounded_max( & col( "a" ) ) , dtype) ) ,
772- ) ,
774+ gt( cast( stat( col( "a" ) , Stat :: Min ) , dtype. clone( ) ) , lit( 42i64 ) ) ,
775+ gt( lit( 42i64 ) , cast( stat( col( "a" ) , Stat :: Max ) , dtype) ) ,
773776 ) )
774777 ) ;
775778 Ok ( ( ) )
0 commit comments