Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions xlsynth-g8r/src/gatify/prep_for_gatify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,127 @@ fn rewrite_encode_one_hot_to_ext_prio_encode(f: &mut ir::Fn) -> usize {
rewrites
}

/// Rewrites the shape:
///
/// ```text
/// bit_slice(
/// and(bit_slice(x, start=0, width=low_w), sign_ext(not(sle(x, 0)), low_w)),
/// start=low_w-1,
/// width=1)
/// ```
///
/// into:
///
/// ```text
/// and(bit_slice(x, start=low_w-1, width=1), not(bit_slice(x, start=x_w-1, width=1)))
/// ```
///
/// Rationale: `not(sle(x, 0))` is `x > 0`, i.e. `(!sign(x)) & (x != 0)`.
/// When already ANDed with a non-sign data bit `x[k]` (where `k < x_w-1`),
/// the `x != 0` predicate is redundant because `x[k] == 1` implies non-zero.
/// So the mask collapses to just `!sign(x)` for that projected bit.
fn rewrite_projected_positive_mask_to_sign_test(f: &mut ir::Fn) -> usize {
let mut rewrites: usize = 0;
let original_len = f.nodes.len();

for node_index in 0..original_len {
let payload = f.nodes[node_index].payload.clone();
let NodePayload::BitSlice {
arg: and_nr,
start,
width,
} = payload
else {
continue;
};
if width != 1 {
continue;
}

let NodePayload::Nary(NaryOp::And, and_operands) = f.nodes[and_nr.index].payload.clone()
else {
continue;
};
if and_operands.len() != 2 {
continue;
}

let first = and_operands[0];
let second = and_operands[1];
let (low_bits, sign_ext) = {
let first_payload = f.nodes[first.index].payload.clone();
let second_payload = f.nodes[second.index].payload.clone();
match (first_payload, second_payload) {
(
NodePayload::BitSlice {
arg,
start: low_start,
width: low_width,
},
NodePayload::SignExt {
arg: pred,
new_bit_count,
},
) if low_start == 0 && low_width == new_bit_count => (arg, (pred, new_bit_count)),
(
NodePayload::SignExt {
arg: pred,
new_bit_count,
},
NodePayload::BitSlice {
arg,
start: low_start,
width: low_width,
},
) if low_start == 0 && low_width == new_bit_count => (arg, (pred, new_bit_count)),
_ => continue,
}
};

let (pred_nr, low_w) = sign_ext;
if start != low_w.saturating_sub(1) || low_w == 0 {
continue;
}

let x_w = f.nodes[low_bits.index].ty.bit_count();
if x_w != low_w.saturating_add(1) {
continue;
}

let NodePayload::Unop(Unop::Not, sle_nr) = f.nodes[pred_nr.index].payload.clone() else {
continue;
};
let NodePayload::Binop(Binop::Sle, lhs, rhs) = f.nodes[sle_nr.index].payload.clone() else {
continue;
};
if lhs != low_bits {
continue;
}
if is_ubits_literal_0_or_1_of_width(f, rhs, x_w) != Some(false) {
continue;
}

// Keep node ordering valid by rewriting the existing `sle` node into a
// sign-bit extract. The following existing `not` and `sign_ext` nodes
// then become `!sign(x)` and its replication, respectively.
ir_utils::replace_node_payload(
f,
sle_nr,
NodePayload::BitSlice {
Comment on lines +293 to +296

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Require unique use before mutating sle into sign bit

This rewrite replaces sle_nr in place with a sign-bit bit_slice, but there is no use-count guard that the matched sle/not chain is only used by this cone. If sle(x,0) is also consumed elsewhere, those users now see sign(x) instead of x <= 0, causing incorrect behavior. Add single-use checks (like other rewrites in this file) before mutating shared nodes.

Useful? React with 👍 / 👎.

arg: lhs,
start: x_w - 1,
width: 1,
},
Some(Type::Bits(1)),
)
.expect("prep_for_gatify: projected positive-mask rewrite failed");

rewrites += 1;
}

rewrites
}

fn nil_out_node(f: &mut ir::Fn, node_ref: ir::NodeRef) {
let node = &mut f.nodes[node_ref.index];
node.payload = NodePayload::Nil;
Expand Down Expand Up @@ -581,6 +702,7 @@ pub fn prep_for_gatify(
if options.enable_rewrite_prio_encode {
let _rewrites = rewrite_encode_one_hot_to_ext_prio_encode(&mut cloned);
}
let _rewrites = rewrite_projected_positive_mask_to_sign_test(&mut cloned);
mark_dead_nodes_as_nil(&mut cloned);
cloned
}
Expand Down Expand Up @@ -717,4 +839,50 @@ top fn cone(p0: bits[9] id=1, p1: bits[9] id=2) -> bits[1] {
optimized_text
);
}

#[test]
fn projected_positive_mask_rewrite_collapses_to_sign_test_and_data_bit() {
let ir_text = "package sample

fn cone(x: bits[13] id=1) -> bits[1] {
z: bits[13] = literal(value=0, id=2)
le0: bits[1] = sle(x, z, id=3)
gt0: bits[1] = not(le0, id=4)
low: bits[12] = bit_slice(x, start=0, width=12, id=5)
m: bits[12] = sign_ext(gt0, new_bit_count=12, id=6)
gated: bits[12] = and(low, m, id=7)
ret out: bits[1] = bit_slice(gated, start=11, width=1, id=8)
}
";
let mut parser = Parser::new(ir_text);
let pkg = parser.parse_and_validate_package().unwrap();
let f = pkg.get_top_fn().unwrap();

let optimized = prep_for_gatify(f, None, PrepForGatifyOptions::default());
let optimized_text = optimized.to_string();
assert!(
optimized_text.contains("and(") && optimized_text.contains("start=12, width=1"),
"expected rewrite to project sign-bit guard; got:\n{}",
optimized_text
);
assert!(
!optimized_text.contains("sle(") || optimized_text.lines().all(|l| l.contains("Nil")),
"expected SLE cone to be removed from the live logic; got:\n{}",
optimized_text
);

// Equivalence sanity check over all representable values for bits[13].
for x in 0u64..(1u64 << 13) {
let args = [IrValue::make_ubits(13, x).unwrap()];
let got_orig = match eval_fn(f, &args) {
FnEvalResult::Success(s) => s.value.to_bool().unwrap(),
FnEvalResult::Failure(f) => panic!("unexpected eval failure: {:?}", f),
};
let got_opt = match eval_fn(&optimized, &args) {
FnEvalResult::Success(s) => s.value.to_bool().unwrap(),
FnEvalResult::Failure(f) => panic!("unexpected eval failure: {:?}", f),
};
assert_eq!(got_orig, got_opt, "mismatch at x={}", x);
}
}
}
Loading