diff --git a/xlsynth-g8r/src/gatify/prep_for_gatify.rs b/xlsynth-g8r/src/gatify/prep_for_gatify.rs index 5b054a8e..c220d84d 100644 --- a/xlsynth-g8r/src/gatify/prep_for_gatify.rs +++ b/xlsynth-g8r/src/gatify/prep_for_gatify.rs @@ -187,6 +187,127 @@ fn rewrite_encode_one_hot_to_ext_prio_encode(f: &mut ir::Fn) -> usize { rewrites } +/// Rewrites the shape: +/// +/// ```text +/// bit_slice( +/// and(bit_slice(x, start=0, width=low_w), sign_ext(not(sle(x, 0)), low_w)), +/// start=low_w-1, +/// width=1) +/// ``` +/// +/// into: +/// +/// ```text +/// and(bit_slice(x, start=low_w-1, width=1), not(bit_slice(x, start=x_w-1, width=1))) +/// ``` +/// +/// Rationale: `not(sle(x, 0))` is `x > 0`, i.e. `(!sign(x)) & (x != 0)`. +/// When already ANDed with a non-sign data bit `x[k]` (where `k < x_w-1`), +/// the `x != 0` predicate is redundant because `x[k] == 1` implies non-zero. +/// So the mask collapses to just `!sign(x)` for that projected bit. +fn rewrite_projected_positive_mask_to_sign_test(f: &mut ir::Fn) -> usize { + let mut rewrites: usize = 0; + let original_len = f.nodes.len(); + + for node_index in 0..original_len { + let payload = f.nodes[node_index].payload.clone(); + let NodePayload::BitSlice { + arg: and_nr, + start, + width, + } = payload + else { + continue; + }; + if width != 1 { + continue; + } + + let NodePayload::Nary(NaryOp::And, and_operands) = f.nodes[and_nr.index].payload.clone() + else { + continue; + }; + if and_operands.len() != 2 { + continue; + } + + let first = and_operands[0]; + let second = and_operands[1]; + let (low_bits, sign_ext) = { + let first_payload = f.nodes[first.index].payload.clone(); + let second_payload = f.nodes[second.index].payload.clone(); + match (first_payload, second_payload) { + ( + NodePayload::BitSlice { + arg, + start: low_start, + width: low_width, + }, + NodePayload::SignExt { + arg: pred, + new_bit_count, + }, + ) if low_start == 0 && low_width == new_bit_count => (arg, (pred, new_bit_count)), + ( + NodePayload::SignExt { + arg: pred, + new_bit_count, + }, + NodePayload::BitSlice { + arg, + start: low_start, + width: low_width, + }, + ) if low_start == 0 && low_width == new_bit_count => (arg, (pred, new_bit_count)), + _ => continue, + } + }; + + let (pred_nr, low_w) = sign_ext; + if start != low_w.saturating_sub(1) || low_w == 0 { + continue; + } + + let x_w = f.nodes[low_bits.index].ty.bit_count(); + if x_w != low_w.saturating_add(1) { + continue; + } + + let NodePayload::Unop(Unop::Not, sle_nr) = f.nodes[pred_nr.index].payload.clone() else { + continue; + }; + let NodePayload::Binop(Binop::Sle, lhs, rhs) = f.nodes[sle_nr.index].payload.clone() else { + continue; + }; + if lhs != low_bits { + continue; + } + if is_ubits_literal_0_or_1_of_width(f, rhs, x_w) != Some(false) { + continue; + } + + // Keep node ordering valid by rewriting the existing `sle` node into a + // sign-bit extract. The following existing `not` and `sign_ext` nodes + // then become `!sign(x)` and its replication, respectively. + ir_utils::replace_node_payload( + f, + sle_nr, + NodePayload::BitSlice { + arg: lhs, + start: x_w - 1, + width: 1, + }, + Some(Type::Bits(1)), + ) + .expect("prep_for_gatify: projected positive-mask rewrite failed"); + + rewrites += 1; + } + + rewrites +} + fn nil_out_node(f: &mut ir::Fn, node_ref: ir::NodeRef) { let node = &mut f.nodes[node_ref.index]; node.payload = NodePayload::Nil; @@ -581,6 +702,7 @@ pub fn prep_for_gatify( if options.enable_rewrite_prio_encode { let _rewrites = rewrite_encode_one_hot_to_ext_prio_encode(&mut cloned); } + let _rewrites = rewrite_projected_positive_mask_to_sign_test(&mut cloned); mark_dead_nodes_as_nil(&mut cloned); cloned } @@ -717,4 +839,50 @@ top fn cone(p0: bits[9] id=1, p1: bits[9] id=2) -> bits[1] { optimized_text ); } + + #[test] + fn projected_positive_mask_rewrite_collapses_to_sign_test_and_data_bit() { + let ir_text = "package sample + +fn cone(x: bits[13] id=1) -> bits[1] { + z: bits[13] = literal(value=0, id=2) + le0: bits[1] = sle(x, z, id=3) + gt0: bits[1] = not(le0, id=4) + low: bits[12] = bit_slice(x, start=0, width=12, id=5) + m: bits[12] = sign_ext(gt0, new_bit_count=12, id=6) + gated: bits[12] = and(low, m, id=7) + ret out: bits[1] = bit_slice(gated, start=11, width=1, id=8) +} +"; + let mut parser = Parser::new(ir_text); + let pkg = parser.parse_and_validate_package().unwrap(); + let f = pkg.get_top_fn().unwrap(); + + let optimized = prep_for_gatify(f, None, PrepForGatifyOptions::default()); + let optimized_text = optimized.to_string(); + assert!( + optimized_text.contains("and(") && optimized_text.contains("start=12, width=1"), + "expected rewrite to project sign-bit guard; got:\n{}", + optimized_text + ); + assert!( + !optimized_text.contains("sle(") || optimized_text.lines().all(|l| l.contains("Nil")), + "expected SLE cone to be removed from the live logic; got:\n{}", + optimized_text + ); + + // Equivalence sanity check over all representable values for bits[13]. + for x in 0u64..(1u64 << 13) { + let args = [IrValue::make_ubits(13, x).unwrap()]; + let got_orig = match eval_fn(f, &args) { + FnEvalResult::Success(s) => s.value.to_bool().unwrap(), + FnEvalResult::Failure(f) => panic!("unexpected eval failure: {:?}", f), + }; + let got_opt = match eval_fn(&optimized, &args) { + FnEvalResult::Success(s) => s.value.to_bool().unwrap(), + FnEvalResult::Failure(f) => panic!("unexpected eval failure: {:?}", f), + }; + assert_eq!(got_orig, got_opt, "mismatch at x={}", x); + } + } }