Skip to content

Commit 8e86247

Browse files
committed
Seed and query ExprType in optimizer reconstruction
Seeds ExprType from SPIR-V type information before saturation so type-guarded rules fire correctly. After extraction, queries ExprType to detect when rewrites changed an expression's type domain and corrects the SPIR-V result type accordingly. Also adds a post-extraction type validation pass that drops instructions with mismatched result types. Updates tests to seed ExprType for BoolType when testing Gamma-to-LogAnd/LogOr conversions.
1 parent b86945d commit 8e86247

File tree

2 files changed

+209
-16
lines changed

2 files changed

+209
-16
lines changed

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 205 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
4444
}
4545
}
4646

47-
// Collect ALL SSA value IDs that can be referenced
47+
// Collect ALL IDs in the module (not just SSA values) so next_id doesn't collide.
48+
// This includes block labels, function defs, types, etc.
4849
let mut all_ssa_ids: HashSet<Word> = HashSet::new();
4950

5051
// Track which block each value is defined in
5152
let mut id_to_block: HashMap<Word, Word> = HashMap::new();
5253

53-
// Add module-level constants first
54+
// Add module-level constants and types first
5455
for inst in &module.types_global_values {
5556
if let Some(id) = inst.result_id {
5657
all_ssa_ids.insert(id);
@@ -106,6 +107,11 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
106107
}
107108

108109
for (func_idx, func) in module.functions.iter().enumerate() {
110+
// Function def/end IDs
111+
if let Some(id) = func.def.as_ref().and_then(|d| d.result_id) {
112+
all_ssa_ids.insert(id);
113+
}
114+
109115
// Function parameters
110116
for param in &func.parameters {
111117
if let Some(id) = param.result_id {
@@ -114,6 +120,11 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
114120
}
115121

116122
for (block_idx, block) in func.blocks.iter().enumerate() {
123+
// Block label IDs are part of the ID space
124+
if let Some(label_id) = block.label.as_ref().and_then(|l| l.result_id) {
125+
all_ssa_ids.insert(label_id);
126+
}
127+
117128
// Get block label for id_to_block tracking
118129
let block_label = block.label.as_ref().and_then(|l| l.result_id);
119130

@@ -248,6 +259,22 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
248259
}
249260
}
250261

262+
// Seed ExprType for all known IDs so type-guarded rules work correctly.
263+
// This sets the type domain (BoolType/IntType/FloatType) for each expression
264+
// based on its SPIR-V result type, enabling type-aware rewrite guards.
265+
for (&id, &type_id) in &ctx.id_to_type {
266+
if ctx.id_to_term.contains_key(&id) {
267+
let type_str = match type_classes.get(&type_id) {
268+
Some(TypeClass::Bool) => "(BoolType)",
269+
Some(TypeClass::Int) => "(IntType)",
270+
Some(TypeClass::Float) => "(FloatType)",
271+
_ => continue,
272+
};
273+
let cmd = format!("(set (ExprType id{}) {})", id, type_str);
274+
let _ = egraph.parse_and_run_program(None, &cmd);
275+
}
276+
}
277+
251278
// ==========================================================================
252279
// PRE: Represent branch value pairs as Gamma selections
253280
// ==========================================================================
@@ -747,6 +774,14 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
747774
.iter()
748775
.find(|inst| inst.class.opcode == Op::TypeBool)
749776
.and_then(|inst| inst.result_id);
777+
let float32_type = module
778+
.types_global_values
779+
.iter()
780+
.find(|inst| {
781+
inst.class.opcode == Op::TypeFloat
782+
&& inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(32))
783+
})
784+
.and_then(|inst| inst.result_id);
750785

751786
// Only extract from IDs that are both:
752787
// 1. True roots (operands of side effects) - these are the outputs we need
@@ -777,7 +812,43 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
777812
if !results.is_empty() {
778813
let result_str = format!("{}", results[0]);
779814
if let Some(term) = parse_extract_result(&result_str) {
780-
let result_type = ctx.id_to_type.get(&id).copied().unwrap_or(0);
815+
let mut result_type = ctx.id_to_type.get(&id).copied().unwrap_or(0);
816+
817+
// Query ExprType from the egraph to detect type domain changes.
818+
// If the extracted term changed type domain (e.g., int→bool via
819+
// Gamma simplification), correct the result_type to match.
820+
let expr_type_query = format!("(extract (ExprType id{}))", id);
821+
if let Ok(type_results) = egraph.parse_and_run_program(None, &expr_type_query) {
822+
if !type_results.is_empty() {
823+
let type_str = format!("{}", type_results[0]);
824+
let current_class = type_classes
825+
.get(&result_type)
826+
.copied()
827+
.unwrap_or(TypeClass::Other);
828+
let egraph_class = if type_str.contains("BoolType") {
829+
TypeClass::Bool
830+
} else if type_str.contains("IntType") {
831+
TypeClass::Int
832+
} else if type_str.contains("FloatType") {
833+
TypeClass::Float
834+
} else {
835+
TypeClass::Other
836+
};
837+
if egraph_class != TypeClass::Other && egraph_class != current_class {
838+
// Type domain changed - select correct SPIR-V type
839+
let corrected = match egraph_class {
840+
TypeClass::Bool => bool_type,
841+
TypeClass::Int => int32_type,
842+
TypeClass::Float => float32_type,
843+
TypeClass::Other => None,
844+
};
845+
if let Some(ct) = corrected {
846+
result_type = ct;
847+
ctx.id_to_type.insert(id, ct);
848+
}
849+
}
850+
}
851+
}
781852

782853
// Before parsing, ensure all inline constants in the term have IDs
783854
// If the ENTIRE term is just a constant (e.g., "(Const 84)"), use the
@@ -896,14 +967,21 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
896967
&ctx.id_to_type,
897968
&type_classes,
898969
bool_type,
970+
int32_type,
971+
float32_type,
899972
);
900973
if corrected_type != result_type {
901974
inst.result_type = Some(corrected_type);
902975
ctx.id_to_type.insert(id, corrected_type);
903976
}
904-
// Also collect IDs from the generated instruction
905-
collect_ids_from_instruction(&inst, &mut used_ids);
906-
optimized_instructions.insert(id, inst);
977+
// Safety: if the instruction still has invalid types, skip optimization
978+
if !instruction_has_valid_types(&inst, &ctx.id_to_type, &type_classes) {
979+
// Fall back to original instruction
980+
} else {
981+
// Also collect IDs from the generated instruction
982+
collect_ids_from_instruction(&inst, &mut used_ids);
983+
optimized_instructions.insert(id, inst);
984+
}
907985
} else {
908986
// If simple parsing fails, try to materialize nested expressions
909987
// This handles cases like (Mul (Const 4) (Add (Sym "id5") (Sym "id6")))
@@ -936,13 +1014,24 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
9361014
&ctx.id_to_type,
9371015
&type_classes,
9381016
bool_type,
1017+
int32_type,
1018+
float32_type,
9391019
);
9401020
if corrected_type != result_type {
9411021
inst.result_type = Some(corrected_type);
9421022
ctx.id_to_type.insert(id, corrected_type);
9431023
}
944-
collect_ids_from_instruction(&inst, &mut used_ids);
945-
optimized_instructions.insert(id, inst);
1024+
// Safety: if the instruction still has invalid types, skip
1025+
if !instruction_has_valid_types(
1026+
&inst,
1027+
&ctx.id_to_type,
1028+
&type_classes,
1029+
) {
1030+
// Fall through - don't apply this optimization
1031+
} else {
1032+
collect_ids_from_instruction(&inst, &mut used_ids);
1033+
optimized_instructions.insert(id, inst);
1034+
}
9461035
// Update id_map if the ID changed
9471036
if let Some(old) = old_id {
9481037
if old != id {
@@ -2185,6 +2274,8 @@ fn infer_result_type(
21852274
id_to_type: &HashMap<Word, Word>,
21862275
type_classes: &HashMap<Word, TypeClass>,
21872276
bool_type: Option<Word>,
2277+
int32_type: Option<Word>,
2278+
float32_type: Option<Word>,
21882279
) -> Word {
21892280
let op = inst.class.opcode;
21902281
let required = match required_result_type_class(op) {
@@ -2209,26 +2300,85 @@ fn infer_result_type(
22092300
return bt;
22102301
}
22112302
}
2212-
TypeClass::Int | TypeClass::Float => {
2213-
// For arithmetic ops, infer type from first operand
2303+
TypeClass::Int => {
2304+
// Try to infer from operands first
22142305
for operand in &inst.operands {
22152306
if let Some(operand_id) = operand.id_ref_any() {
22162307
if let Some(&operand_type) = id_to_type.get(&operand_id) {
2217-
if let Some(&operand_class) = type_classes.get(&operand_type) {
2218-
if operand_class == required {
2219-
return operand_type;
2220-
}
2308+
if type_classes.get(&operand_type) == Some(&TypeClass::Int) {
2309+
return operand_type;
22212310
}
22222311
}
22232312
}
22242313
}
2314+
// Fall back to module's int32 type
2315+
if let Some(it) = int32_type {
2316+
return it;
2317+
}
2318+
}
2319+
TypeClass::Float => {
2320+
// Try to infer from operands first
2321+
for operand in &inst.operands {
2322+
if let Some(operand_id) = operand.id_ref_any() {
2323+
if let Some(&operand_type) = id_to_type.get(&operand_id) {
2324+
if type_classes.get(&operand_type) == Some(&TypeClass::Float) {
2325+
return operand_type;
2326+
}
2327+
}
2328+
}
2329+
}
2330+
// Fall back to module's float32 type
2331+
if let Some(ft) = float32_type {
2332+
return ft;
2333+
}
22252334
}
22262335
TypeClass::Other => {}
22272336
}
22282337

22292338
original_result_type
22302339
}
22312340

2341+
/// Check if an instruction has valid types for its opcode.
2342+
/// Returns true if types are compatible, false if there's a mismatch.
2343+
fn instruction_has_valid_types(
2344+
inst: &Instruction,
2345+
id_to_type: &HashMap<Word, Word>,
2346+
type_classes: &HashMap<Word, TypeClass>,
2347+
) -> bool {
2348+
let op = inst.class.opcode;
2349+
2350+
// Check result type
2351+
if let (Some(required), Some(result_type)) = (required_result_type_class(op), inst.result_type)
2352+
{
2353+
let actual = type_classes
2354+
.get(&result_type)
2355+
.copied()
2356+
.unwrap_or(TypeClass::Other);
2357+
if actual != required && actual != TypeClass::Other {
2358+
return false;
2359+
}
2360+
}
2361+
2362+
// Check operand types for comparisons
2363+
if let Some(required_op_class) = required_operand_type_class(op) {
2364+
for operand in &inst.operands {
2365+
if let Some(operand_id) = operand.id_ref_any() {
2366+
if let Some(&operand_type) = id_to_type.get(&operand_id) {
2367+
let actual = type_classes
2368+
.get(&operand_type)
2369+
.copied()
2370+
.unwrap_or(TypeClass::Other);
2371+
if actual != required_op_class && actual != TypeClass::Other {
2372+
return false;
2373+
}
2374+
}
2375+
}
2376+
}
2377+
}
2378+
2379+
true
2380+
}
2381+
22322382
/// Topological sort of binding IDs based on term dependencies.
22332383
/// If term for idA contains a bare reference to idB (meaning B is also in id_to_term),
22342384
/// then B must be bound before A.
@@ -3130,13 +3280,52 @@ mod tests {
31303280
&id_to_type,
31313281
&type_classes,
31323282
Some(bool_type_id),
3283+
Some(int_type_id),
3284+
None,
31333285
);
31343286
assert_eq!(
31353287
corrected, int_type_id,
31363288
"IAdd with bool type should be corrected to int type"
31373289
);
31383290
}
31393291

3292+
#[test]
3293+
fn infer_result_type_falls_back_to_int32_type() {
3294+
let mut type_classes = HashMap::new();
3295+
let id_to_type = HashMap::new(); // empty - no operand types
3296+
3297+
let bool_type_id: Word = 1;
3298+
let int_type_id: Word = 2;
3299+
3300+
type_classes.insert(bool_type_id, TypeClass::Bool);
3301+
type_classes.insert(int_type_id, TypeClass::Int);
3302+
3303+
// IAdd instruction with bool result type and NO operand type info
3304+
let inst = Instruction::new(
3305+
Op::IAdd,
3306+
Some(bool_type_id),
3307+
Some(100),
3308+
vec![
3309+
rspirv::dr::Operand::IdRef(10),
3310+
rspirv::dr::Operand::IdRef(11),
3311+
],
3312+
);
3313+
3314+
let corrected = infer_result_type(
3315+
&inst,
3316+
bool_type_id,
3317+
&id_to_type,
3318+
&type_classes,
3319+
Some(bool_type_id),
3320+
Some(int_type_id),
3321+
None,
3322+
);
3323+
assert_eq!(
3324+
corrected, int_type_id,
3325+
"IAdd should fall back to int32_type when operands have no type info"
3326+
);
3327+
}
3328+
31403329
#[test]
31413330
fn infer_result_type_corrects_comparison_to_bool() {
31423331
let mut id_to_type = HashMap::new();
@@ -3167,6 +3356,8 @@ mod tests {
31673356
&id_to_type,
31683357
&type_classes,
31693358
Some(bool_type_id),
3359+
Some(int_type_id),
3360+
None,
31703361
);
31713362
assert_eq!(
31723363
corrected, bool_type_id,

rust/spirv-tools-opt/src/egglog_opt/tests.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7206,7 +7206,7 @@ fn test_boolconst_gamma_logand_type_safety() {
72067206

72077207
#[test]
72087208
fn test_boolconst_gamma_logand_allowed() {
7209-
// BoolConst(0) in false branch: SHOULD convert to LogAnd
7209+
// BoolConst(0) in false branch: SHOULD convert to LogAnd when x is bool-typed
72107210
let mut egraph = create_spirv_egraph().unwrap();
72117211

72127212
egraph
@@ -7215,6 +7215,7 @@ fn test_boolconst_gamma_logand_allowed() {
72157215
r#"
72167216
(let c (Sym "cond"))
72177217
(let x (Sym "x"))
7218+
(set (ExprType x) (BoolType))
72187219
(let root (Gamma c x (BoolConst 0)))
72197220
"#,
72207221
)
@@ -7226,7 +7227,7 @@ fn test_boolconst_gamma_logand_allowed() {
72267227
let check = egraph.parse_and_run_program(None, "(check (= root (LogAnd c x)))");
72277228
assert!(
72287229
check.is_ok(),
7229-
"Gamma(c, x, BoolConst(0)) should simplify to LogAnd(c, x)"
7230+
"Gamma(c, x, BoolConst(0)) should simplify to LogAnd(c, x) when x is BoolType"
72307231
);
72317232
}
72327233

@@ -7267,6 +7268,7 @@ fn test_boolconst_gamma_logor_allowed() {
72677268
r#"
72687269
(let c (Sym "cond"))
72697270
(let x (Sym "x"))
7271+
(set (ExprType x) (BoolType))
72707272
(let root (Gamma c (BoolConst 1) x))
72717273
"#,
72727274
)

0 commit comments

Comments
 (0)