@@ -44,13 +44,14 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
4444 }
4545 }
4646
47- // Collect ALL SSA value IDs that can be referenced
47+ // Collect ALL IDs in the module (not just SSA values) so next_id doesn't collide.
48+ // This includes block labels, function defs, types, etc.
4849 let mut all_ssa_ids: HashSet < Word > = HashSet :: new ( ) ;
4950
5051 // Track which block each value is defined in
5152 let mut id_to_block: HashMap < Word , Word > = HashMap :: new ( ) ;
5253
53- // Add module-level constants first
54+ // Add module-level constants and types first
5455 for inst in & module. types_global_values {
5556 if let Some ( id) = inst. result_id {
5657 all_ssa_ids. insert ( id) ;
@@ -106,6 +107,11 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
106107 }
107108
108109 for ( func_idx, func) in module. functions . iter ( ) . enumerate ( ) {
110+ // Function def/end IDs
111+ if let Some ( id) = func. def . as_ref ( ) . and_then ( |d| d. result_id ) {
112+ all_ssa_ids. insert ( id) ;
113+ }
114+
109115 // Function parameters
110116 for param in & func. parameters {
111117 if let Some ( id) = param. result_id {
@@ -114,6 +120,11 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
114120 }
115121
116122 for ( block_idx, block) in func. blocks . iter ( ) . enumerate ( ) {
123+ // Block label IDs are part of the ID space
124+ if let Some ( label_id) = block. label . as_ref ( ) . and_then ( |l| l. result_id ) {
125+ all_ssa_ids. insert ( label_id) ;
126+ }
127+
117128 // Get block label for id_to_block tracking
118129 let block_label = block. label . as_ref ( ) . and_then ( |l| l. result_id ) ;
119130
@@ -248,6 +259,22 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
248259 }
249260 }
250261
262+ // Seed ExprType for all known IDs so type-guarded rules work correctly.
263+ // This sets the type domain (BoolType/IntType/FloatType) for each expression
264+ // based on its SPIR-V result type, enabling type-aware rewrite guards.
265+ for ( & id, & type_id) in & ctx. id_to_type {
266+ if ctx. id_to_term . contains_key ( & id) {
267+ let type_str = match type_classes. get ( & type_id) {
268+ Some ( TypeClass :: Bool ) => "(BoolType)" ,
269+ Some ( TypeClass :: Int ) => "(IntType)" ,
270+ Some ( TypeClass :: Float ) => "(FloatType)" ,
271+ _ => continue ,
272+ } ;
273+ let cmd = format ! ( "(set (ExprType id{}) {})" , id, type_str) ;
274+ let _ = egraph. parse_and_run_program ( None , & cmd) ;
275+ }
276+ }
277+
251278 // ==========================================================================
252279 // PRE: Represent branch value pairs as Gamma selections
253280 // ==========================================================================
@@ -747,6 +774,14 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
747774 . iter ( )
748775 . find ( |inst| inst. class . opcode == Op :: TypeBool )
749776 . and_then ( |inst| inst. result_id ) ;
777+ let float32_type = module
778+ . types_global_values
779+ . iter ( )
780+ . find ( |inst| {
781+ inst. class . opcode == Op :: TypeFloat
782+ && inst. operands . first ( ) == Some ( & rspirv:: dr:: Operand :: LiteralBit32 ( 32 ) )
783+ } )
784+ . and_then ( |inst| inst. result_id ) ;
750785
751786 // Only extract from IDs that are both:
752787 // 1. True roots (operands of side effects) - these are the outputs we need
@@ -777,7 +812,43 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
777812 if !results. is_empty ( ) {
778813 let result_str = format ! ( "{}" , results[ 0 ] ) ;
779814 if let Some ( term) = parse_extract_result ( & result_str) {
780- let result_type = ctx. id_to_type . get ( & id) . copied ( ) . unwrap_or ( 0 ) ;
815+ let mut result_type = ctx. id_to_type . get ( & id) . copied ( ) . unwrap_or ( 0 ) ;
816+
817+ // Query ExprType from the egraph to detect type domain changes.
818+ // If the extracted term changed type domain (e.g., int→bool via
819+ // Gamma simplification), correct the result_type to match.
820+ let expr_type_query = format ! ( "(extract (ExprType id{}))" , id) ;
821+ if let Ok ( type_results) = egraph. parse_and_run_program ( None , & expr_type_query) {
822+ if !type_results. is_empty ( ) {
823+ let type_str = format ! ( "{}" , type_results[ 0 ] ) ;
824+ let current_class = type_classes
825+ . get ( & result_type)
826+ . copied ( )
827+ . unwrap_or ( TypeClass :: Other ) ;
828+ let egraph_class = if type_str. contains ( "BoolType" ) {
829+ TypeClass :: Bool
830+ } else if type_str. contains ( "IntType" ) {
831+ TypeClass :: Int
832+ } else if type_str. contains ( "FloatType" ) {
833+ TypeClass :: Float
834+ } else {
835+ TypeClass :: Other
836+ } ;
837+ if egraph_class != TypeClass :: Other && egraph_class != current_class {
838+ // Type domain changed - select correct SPIR-V type
839+ let corrected = match egraph_class {
840+ TypeClass :: Bool => bool_type,
841+ TypeClass :: Int => int32_type,
842+ TypeClass :: Float => float32_type,
843+ TypeClass :: Other => None ,
844+ } ;
845+ if let Some ( ct) = corrected {
846+ result_type = ct;
847+ ctx. id_to_type . insert ( id, ct) ;
848+ }
849+ }
850+ }
851+ }
781852
782853 // Before parsing, ensure all inline constants in the term have IDs
783854 // If the ENTIRE term is just a constant (e.g., "(Const 84)"), use the
@@ -896,14 +967,21 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
896967 & ctx. id_to_type ,
897968 & type_classes,
898969 bool_type,
970+ int32_type,
971+ float32_type,
899972 ) ;
900973 if corrected_type != result_type {
901974 inst. result_type = Some ( corrected_type) ;
902975 ctx. id_to_type . insert ( id, corrected_type) ;
903976 }
904- // Also collect IDs from the generated instruction
905- collect_ids_from_instruction ( & inst, & mut used_ids) ;
906- optimized_instructions. insert ( id, inst) ;
977+ // Safety: if the instruction still has invalid types, skip optimization
978+ if !instruction_has_valid_types ( & inst, & ctx. id_to_type , & type_classes) {
979+ // Fall back to original instruction
980+ } else {
981+ // Also collect IDs from the generated instruction
982+ collect_ids_from_instruction ( & inst, & mut used_ids) ;
983+ optimized_instructions. insert ( id, inst) ;
984+ }
907985 } else {
908986 // If simple parsing fails, try to materialize nested expressions
909987 // This handles cases like (Mul (Const 4) (Add (Sym "id5") (Sym "id6")))
@@ -936,13 +1014,24 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
9361014 & ctx. id_to_type ,
9371015 & type_classes,
9381016 bool_type,
1017+ int32_type,
1018+ float32_type,
9391019 ) ;
9401020 if corrected_type != result_type {
9411021 inst. result_type = Some ( corrected_type) ;
9421022 ctx. id_to_type . insert ( id, corrected_type) ;
9431023 }
944- collect_ids_from_instruction ( & inst, & mut used_ids) ;
945- optimized_instructions. insert ( id, inst) ;
1024+ // Safety: if the instruction still has invalid types, skip
1025+ if !instruction_has_valid_types (
1026+ & inst,
1027+ & ctx. id_to_type ,
1028+ & type_classes,
1029+ ) {
1030+ // Fall through - don't apply this optimization
1031+ } else {
1032+ collect_ids_from_instruction ( & inst, & mut used_ids) ;
1033+ optimized_instructions. insert ( id, inst) ;
1034+ }
9461035 // Update id_map if the ID changed
9471036 if let Some ( old) = old_id {
9481037 if old != id {
@@ -2185,6 +2274,8 @@ fn infer_result_type(
21852274 id_to_type : & HashMap < Word , Word > ,
21862275 type_classes : & HashMap < Word , TypeClass > ,
21872276 bool_type : Option < Word > ,
2277+ int32_type : Option < Word > ,
2278+ float32_type : Option < Word > ,
21882279) -> Word {
21892280 let op = inst. class . opcode ;
21902281 let required = match required_result_type_class ( op) {
@@ -2209,26 +2300,85 @@ fn infer_result_type(
22092300 return bt;
22102301 }
22112302 }
2212- TypeClass :: Int | TypeClass :: Float => {
2213- // For arithmetic ops, infer type from first operand
2303+ TypeClass :: Int => {
2304+ // Try to infer from operands first
22142305 for operand in & inst. operands {
22152306 if let Some ( operand_id) = operand. id_ref_any ( ) {
22162307 if let Some ( & operand_type) = id_to_type. get ( & operand_id) {
2217- if let Some ( & operand_class) = type_classes. get ( & operand_type) {
2218- if operand_class == required {
2219- return operand_type;
2220- }
2308+ if type_classes. get ( & operand_type) == Some ( & TypeClass :: Int ) {
2309+ return operand_type;
22212310 }
22222311 }
22232312 }
22242313 }
2314+ // Fall back to module's int32 type
2315+ if let Some ( it) = int32_type {
2316+ return it;
2317+ }
2318+ }
2319+ TypeClass :: Float => {
2320+ // Try to infer from operands first
2321+ for operand in & inst. operands {
2322+ if let Some ( operand_id) = operand. id_ref_any ( ) {
2323+ if let Some ( & operand_type) = id_to_type. get ( & operand_id) {
2324+ if type_classes. get ( & operand_type) == Some ( & TypeClass :: Float ) {
2325+ return operand_type;
2326+ }
2327+ }
2328+ }
2329+ }
2330+ // Fall back to module's float32 type
2331+ if let Some ( ft) = float32_type {
2332+ return ft;
2333+ }
22252334 }
22262335 TypeClass :: Other => { }
22272336 }
22282337
22292338 original_result_type
22302339}
22312340
2341+ /// Check if an instruction has valid types for its opcode.
2342+ /// Returns true if types are compatible, false if there's a mismatch.
2343+ fn instruction_has_valid_types (
2344+ inst : & Instruction ,
2345+ id_to_type : & HashMap < Word , Word > ,
2346+ type_classes : & HashMap < Word , TypeClass > ,
2347+ ) -> bool {
2348+ let op = inst. class . opcode ;
2349+
2350+ // Check result type
2351+ if let ( Some ( required) , Some ( result_type) ) = ( required_result_type_class ( op) , inst. result_type )
2352+ {
2353+ let actual = type_classes
2354+ . get ( & result_type)
2355+ . copied ( )
2356+ . unwrap_or ( TypeClass :: Other ) ;
2357+ if actual != required && actual != TypeClass :: Other {
2358+ return false ;
2359+ }
2360+ }
2361+
2362+ // Check operand types for comparisons
2363+ if let Some ( required_op_class) = required_operand_type_class ( op) {
2364+ for operand in & inst. operands {
2365+ if let Some ( operand_id) = operand. id_ref_any ( ) {
2366+ if let Some ( & operand_type) = id_to_type. get ( & operand_id) {
2367+ let actual = type_classes
2368+ . get ( & operand_type)
2369+ . copied ( )
2370+ . unwrap_or ( TypeClass :: Other ) ;
2371+ if actual != required_op_class && actual != TypeClass :: Other {
2372+ return false ;
2373+ }
2374+ }
2375+ }
2376+ }
2377+ }
2378+
2379+ true
2380+ }
2381+
22322382/// Topological sort of binding IDs based on term dependencies.
22332383/// If term for idA contains a bare reference to idB (meaning B is also in id_to_term),
22342384/// then B must be bound before A.
@@ -3130,13 +3280,52 @@ mod tests {
31303280 & id_to_type,
31313281 & type_classes,
31323282 Some ( bool_type_id) ,
3283+ Some ( int_type_id) ,
3284+ None ,
31333285 ) ;
31343286 assert_eq ! (
31353287 corrected, int_type_id,
31363288 "IAdd with bool type should be corrected to int type"
31373289 ) ;
31383290 }
31393291
3292+ #[ test]
3293+ fn infer_result_type_falls_back_to_int32_type ( ) {
3294+ let mut type_classes = HashMap :: new ( ) ;
3295+ let id_to_type = HashMap :: new ( ) ; // empty - no operand types
3296+
3297+ let bool_type_id: Word = 1 ;
3298+ let int_type_id: Word = 2 ;
3299+
3300+ type_classes. insert ( bool_type_id, TypeClass :: Bool ) ;
3301+ type_classes. insert ( int_type_id, TypeClass :: Int ) ;
3302+
3303+ // IAdd instruction with bool result type and NO operand type info
3304+ let inst = Instruction :: new (
3305+ Op :: IAdd ,
3306+ Some ( bool_type_id) ,
3307+ Some ( 100 ) ,
3308+ vec ! [
3309+ rspirv:: dr:: Operand :: IdRef ( 10 ) ,
3310+ rspirv:: dr:: Operand :: IdRef ( 11 ) ,
3311+ ] ,
3312+ ) ;
3313+
3314+ let corrected = infer_result_type (
3315+ & inst,
3316+ bool_type_id,
3317+ & id_to_type,
3318+ & type_classes,
3319+ Some ( bool_type_id) ,
3320+ Some ( int_type_id) ,
3321+ None ,
3322+ ) ;
3323+ assert_eq ! (
3324+ corrected, int_type_id,
3325+ "IAdd should fall back to int32_type when operands have no type info"
3326+ ) ;
3327+ }
3328+
31403329 #[ test]
31413330 fn infer_result_type_corrects_comparison_to_bool ( ) {
31423331 let mut id_to_type = HashMap :: new ( ) ;
@@ -3167,6 +3356,8 @@ mod tests {
31673356 & id_to_type,
31683357 & type_classes,
31693358 Some ( bool_type_id) ,
3359+ Some ( int_type_id) ,
3360+ None ,
31703361 ) ;
31713362 assert_eq ! (
31723363 corrected, bool_type_id,
0 commit comments