Skip to content

Commit f14f6a1

Browse files
committed
Thread scope dtype through zoned pruning
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 16f09c1 commit f14f6a1

3 files changed

Lines changed: 57 additions & 48 deletions

File tree

vortex-array/src/stats/rewrite/builtins.rs

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,18 @@ impl StatsRewriteRule for BoundedBinaryStatsRewrite {
136136
fn falsify(
137137
&self,
138138
expr: &Expression,
139-
_ctx: &StatsRewriteCtx<'_>,
139+
ctx: &StatsRewriteCtx<'_>,
140140
) -> VortexResult<Option<Expression>> {
141141
let operator = expr.as_::<Binary>();
142142
let lhs = expr.child(0);
143143
let rhs = expr.child(1);
144144

145+
if !supports_bounded_rewrite(&ctx.return_dtype(lhs)?)
146+
|| !supports_bounded_rewrite(&ctx.return_dtype(rhs)?)
147+
{
148+
return Ok(None);
149+
}
150+
145151
Ok(match operator {
146152
Operator::Eq => {
147153
let left = gt(bounded_min(lhs), bounded_max(rhs));
@@ -432,6 +438,10 @@ fn has_nans(dtype: &DType) -> bool {
432438
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
433439
}
434440

441+
fn supports_bounded_rewrite(dtype: &DType) -> bool {
442+
matches!(dtype, DType::Binary(_) | DType::Utf8(_))
443+
}
444+
435445
fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
436446
if let Some(literal) = literal_stat(expr, stat) {
437447
return Some(literal);
@@ -570,6 +580,8 @@ mod tests {
570580
StructFields::from_iter([
571581
("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
572582
("b", DType::Primitive(PType::I32, Nullability::NonNullable)),
583+
("s", DType::Utf8(Nullability::NonNullable)),
584+
("t", DType::Utf8(Nullability::NonNullable)),
573585
]),
574586
Nullability::NonNullable,
575587
)
@@ -592,29 +604,38 @@ mod tests {
592604
let expr = gt(col("a"), lit(10));
593605
assert_eq!(
594606
falsify(&expr)?,
595-
Some(or(
596-
and(
597-
nan_free(col("a")),
598-
lt_eq(stat(col("a"), Stat::Max), lit(10)),
599-
),
600-
lt_eq(bounded_max(&col("a")), lit(10)),
607+
Some(and(
608+
nan_free(col("a")),
609+
lt_eq(stat(col("a"), Stat::Max), lit(10)),
601610
))
602611
);
603612

604613
let expr = eq(col("a"), col("b"));
614+
assert_eq!(
615+
falsify(&expr)?,
616+
Some(and(
617+
and(nan_free(col("a")), nan_free(col("b"))),
618+
or(
619+
gt(stat(col("a"), Stat::Min), stat(col("b"), Stat::Max)),
620+
gt(stat(col("b"), Stat::Min), stat(col("a"), Stat::Max)),
621+
),
622+
))
623+
);
624+
625+
let expr = eq(col("s"), col("t"));
605626
assert_eq!(
606627
falsify(&expr)?,
607628
Some(or(
608629
and(
609-
and(nan_free(col("a")), nan_free(col("b"))),
630+
and(nan_free(col("s")), nan_free(col("t"))),
610631
or(
611-
gt(stat(col("a"), Stat::Min), stat(col("b"), Stat::Max)),
612-
gt(stat(col("b"), Stat::Min), stat(col("a"), Stat::Max)),
632+
gt(stat(col("s"), Stat::Min), stat(col("t"), Stat::Max)),
633+
gt(stat(col("t"), Stat::Min), stat(col("s"), Stat::Max)),
613634
),
614635
),
615636
or(
616-
gt(bounded_min(&col("a")), bounded_max(&col("b"))),
617-
gt(bounded_min(&col("b")), bounded_max(&col("a"))),
637+
gt(bounded_min(&col("s")), bounded_max(&col("t"))),
638+
gt(bounded_min(&col("t")), bounded_max(&col("s"))),
618639
),
619640
))
620641
);
@@ -627,19 +648,13 @@ mod tests {
627648
assert_eq!(
628649
falsify(&expr)?,
629650
Some(or(
630-
or(
631-
and(
632-
nan_free(col("a")),
633-
lt_eq(stat(col("a"), Stat::Max), lit(10)),
634-
),
635-
lt_eq(bounded_max(&col("a")), lit(10)),
651+
and(
652+
nan_free(col("a")),
653+
lt_eq(stat(col("a"), Stat::Max), lit(10)),
636654
),
637-
or(
638-
and(
639-
nan_free(col("a")),
640-
gt_eq(stat(col("a"), Stat::Min), lit(50)),
641-
),
642-
gt_eq(bounded_min(&col("a")), lit(50)),
655+
and(
656+
nan_free(col("a")),
657+
gt_eq(stat(col("a"), Stat::Min), lit(50)),
643658
),
644659
))
645660
);
@@ -661,14 +676,8 @@ mod tests {
661676
assert_eq!(
662677
falsify(&expr)?,
663678
Some(or(
664-
or(
665-
and(nan_free(col("a")), gt(lit(10), stat(col("a"), Stat::Max))),
666-
gt(lit(10), bounded_max(&col("a"))),
667-
),
668-
or(
669-
and(nan_free(col("a")), gt(stat(col("a"), Stat::Min), lit(50))),
670-
gt(bounded_min(&col("a")), lit(50)),
671-
),
679+
and(nan_free(col("a")), gt(lit(10), stat(col("a"), Stat::Max))),
680+
and(nan_free(col("a")), gt(stat(col("a"), Stat::Min), lit(50))),
672681
))
673682
);
674683
Ok(())
@@ -762,14 +771,8 @@ mod tests {
762771
assert_eq!(
763772
falsify(&expr)?,
764773
Some(or(
765-
or(
766-
gt(cast(stat(col("a"), Stat::Min), dtype.clone()), lit(42i64)),
767-
gt(lit(42i64), cast(stat(col("a"), Stat::Max), dtype.clone())),
768-
),
769-
or(
770-
gt(cast(bounded_min(&col("a")), dtype.clone()), lit(42i64)),
771-
gt(lit(42i64), cast(bounded_max(&col("a")), dtype)),
772-
),
774+
gt(cast(stat(col("a"), Stat::Min), dtype.clone()), lit(42i64)),
775+
gt(lit(42i64), cast(stat(col("a"), Stat::Max), dtype)),
773776
))
774777
);
775778
Ok(())

vortex-layout/src/layouts/zoned/pruning.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl PruningState {
116116
self.pruning_predicates
117117
.entry(expr.clone())
118118
.or_default()
119-
.get_or_init(move || match expr.falsify(&self.session) {
119+
.get_or_init(move || match expr.falsify(&self.dtype, &self.session) {
120120
Ok(predicate) => predicate,
121121
Err(error) => {
122122
trace!(%expr, %error, "failed to construct stats rewrite predicate");

vortex-layout/src/layouts/zoned/zone_map.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,10 @@ mod tests {
269269
use vortex_array::arrays::PrimitiveArray;
270270
use vortex_array::arrays::StructArray;
271271
use vortex_array::assert_arrays_eq;
272+
use vortex_array::dtype::DType;
272273
use vortex_array::dtype::FieldNames;
273274
use vortex_array::dtype::PType;
275+
use vortex_array::expr::Expression;
274276
use vortex_array::expr::gt;
275277
use vortex_array::expr::gt_eq;
276278
use vortex_array::expr::is_not_null;
@@ -288,6 +290,10 @@ mod tests {
288290
use crate::layouts::zoned::zone_map::ZoneMap;
289291
use crate::test::SESSION;
290292

293+
fn falsify(expr: &Expression, dtype: DType) -> Expression {
294+
expr.falsify(&dtype, &SESSION).unwrap().unwrap()
295+
}
296+
291297
#[test]
292298
fn test_zone_map_prunes() {
293299
// Construct a zone map with 3 zones:
@@ -331,7 +337,7 @@ mod tests {
331337
// A >= 6
332338
// => A.max < 6
333339
let expr = gt_eq(root(), lit(6i32));
334-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
340+
let pruning_expr = falsify(&expr, PType::I32.into());
335341
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
336342
assert_arrays_eq!(
337343
mask.into_array(),
@@ -341,7 +347,7 @@ mod tests {
341347
// A > 5
342348
// => A.max <= 5
343349
let expr = gt(root(), lit(5i32));
344-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
350+
let pruning_expr = falsify(&expr, PType::I32.into());
345351
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
346352
assert_arrays_eq!(
347353
mask.into_array(),
@@ -351,7 +357,7 @@ mod tests {
351357
// A < 2
352358
// => A.min >= 2
353359
let expr = lt(root(), lit(2i32));
354-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
360+
let pruning_expr = falsify(&expr, PType::I32.into());
355361
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
356362
assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true]));
357363
}
@@ -372,7 +378,7 @@ mod tests {
372378
.unwrap();
373379

374380
let expr = is_not_null(root());
375-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
381+
let pruning_expr = falsify(&expr, PType::U64.into());
376382

377383
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
378384
assert_arrays_eq!(
@@ -458,7 +464,7 @@ mod tests {
458464
);
459465

460466
let expr = gt(root(), lit(5u64));
461-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
467+
let pruning_expr = falsify(&expr, PType::U64.into());
462468
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
463469
assert_arrays_eq!(
464470
mask.into_array(),
@@ -488,7 +494,7 @@ mod tests {
488494
.unwrap();
489495

490496
let expr = gt(root(), lit(5.0f32));
491-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
497+
let pruning_expr = falsify(&expr, PType::F32.into());
492498
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
493499
assert_arrays_eq!(
494500
mask.into_array(),
@@ -541,7 +547,7 @@ mod tests {
541547
.unwrap();
542548

543549
let expr = is_not_null(root());
544-
let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap();
550+
let pruning_expr = falsify(&expr, PType::U64.into());
545551

546552
// All three zones have length 4 (total rows = 12).
547553
let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();

0 commit comments

Comments
 (0)