Skip to content

Commit 2d72fc5

Browse files
committed
Separate boolean constants from integer constants in egglog optimizer
Introduce BoolConst as a distinct egglog constructor to prevent type confusion where Gamma(c, x, Const(0)) was incorrectly simplified to LogAnd(c, x) even in integer contexts. This caused SPIR-V validation failures (wrong result types on IAdd, FOrdLessThan operand errors) because OpLogicalAnd requires boolean operands/results. Key changes: - Add BoolConst(i64) constructor to datatypes.egg - Encode OpConstantTrue/False as BoolConst in context.rs - Parse BoolConst back to OpConstantTrue/False in constants.rs - Synthesize boolean constants with proper OpTypeBool in mod.rs - Update Gamma→LogAnd/LogOr rules to match only BoolConst branches - Update all comparison/logical rules to produce BoolConst results - Add BoolConst+Const dual matching for interop in logical ops - Add 14 tests verifying type-safe boolean optimization
1 parent 70af80b commit 2d72fc5

File tree

11 files changed

+782
-214
lines changed

11 files changed

+782
-214
lines changed

rust/spirv-tools-opt/src/direct/context.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ impl EgglogContext {
103103
format!("(Const {})", value)
104104
}
105105
}
106-
Op::ConstantTrue => "(Const 1)".to_string(),
107-
Op::ConstantFalse => "(Const 0)".to_string(),
106+
Op::ConstantTrue => "(BoolConst 1)".to_string(),
107+
Op::ConstantFalse => "(BoolConst 0)".to_string(),
108108

109109
Op::IAdd => self.binary_op("Add", inst)?,
110110
Op::ISub => self.binary_op("Sub", inst)?,

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,10 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
702702
}
703703
}
704704
Op::ConstantTrue => {
705-
id_map.entry("const_1".to_string()).or_insert(id);
705+
id_map.entry("boolconst_1".to_string()).or_insert(id);
706706
}
707707
Op::ConstantFalse => {
708-
id_map.entry("const_0".to_string()).or_insert(id);
708+
id_map.entry("boolconst_0".to_string()).or_insert(id);
709709
}
710710
_ => {}
711711
}
@@ -727,6 +727,11 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
727727
&& inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(32))
728728
})
729729
.and_then(|inst| inst.result_id);
730+
let bool_type = module
731+
.types_global_values
732+
.iter()
733+
.find(|inst| inst.class.opcode == Op::TypeBool)
734+
.and_then(|inst| inst.result_id);
730735

731736
// Only extract from IDs that are both:
732737
// 1. True roots (operands of side effects) - these are the outputs we need
@@ -763,24 +768,45 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
763768
// If the ENTIRE term is just a constant (e.g., "(Const 84)"), use the
764769
// current instruction's ID for that constant instead of synthesizing.
765770
// This enables proper DCE - the instruction becomes the constant.
766-
let is_root_const =
767-
term.trim().starts_with("(Const ") || term.trim().starts_with("(Const64 ");
768-
769-
for (is_64, value) in find_inline_constants(&term) {
770-
let key = if is_64 {
771-
format!("const64_{}", value)
772-
} else {
773-
format!("const_{}", value)
771+
use parse::InlineConstKind;
772+
let is_root_const = term.trim().starts_with("(Const ")
773+
|| term.trim().starts_with("(Const64 ")
774+
|| term.trim().starts_with("(BoolConst ");
775+
776+
for (kind, value) in find_inline_constants(&term) {
777+
let key = match kind {
778+
InlineConstKind::Int64 => format!("const64_{}", value),
779+
InlineConstKind::Int32 => format!("const_{}", value),
780+
InlineConstKind::Bool => format!("boolconst_{}", value),
774781
};
775782
if !id_map.contains_key(&key) {
776783
// If this root folds to a constant, use its ID for the constant
777784
// Don't synthesize a new constant - the instruction becomes it
778785
if is_root_const {
779786
id_map.insert(key, id);
780787
// Don't synthesize - will be added via folded_to_constant later
788+
} else if kind == InlineConstKind::Bool {
789+
// Synthesize a boolean constant
790+
if let Some(ty) = bool_type {
791+
let const_id = next_id;
792+
next_id += 1;
793+
let opcode = if value == 0 {
794+
Op::ConstantFalse
795+
} else {
796+
Op::ConstantTrue
797+
};
798+
synthesized_constants.push(Instruction::new(
799+
opcode,
800+
Some(ty),
801+
Some(const_id),
802+
vec![],
803+
));
804+
id_map.insert(key, const_id);
805+
id_map.insert(format!("id{}", const_id), const_id);
806+
}
781807
} else {
782-
// Create a new constant for use as an operand
783-
let const_type = if is_64 {
808+
// Create a new integer constant for use as an operand
809+
let const_type = if kind == InlineConstKind::Int64 {
784810
// Try to find 64-bit int type
785811
module
786812
.types_global_values
@@ -798,7 +824,7 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
798824
if let Some(ty) = const_type {
799825
let const_id = next_id;
800826
next_id += 1;
801-
let operand = if is_64 {
827+
let operand = if kind == InlineConstKind::Int64 {
802828
rspirv::dr::Operand::LiteralBit64(value as u64)
803829
} else {
804830
rspirv::dr::Operand::LiteralBit32(value as u32)

rust/spirv-tools-opt/src/direct/parse/constants.rs

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,29 @@ pub fn try_parse_constant(
6464
}
6565
}
6666

67+
// Parse (BoolConst N) - boolean constants
68+
if let Some(rest) = term.strip_prefix("(BoolConst ") {
69+
if let Some(num_str) = rest.strip_suffix(')') {
70+
if let Ok(value) = num_str.trim().parse::<i64>() {
71+
if value == 0 {
72+
return Some(Instruction::new(
73+
Op::ConstantFalse,
74+
Some(result_type),
75+
Some(result_id),
76+
vec![],
77+
));
78+
} else {
79+
return Some(Instruction::new(
80+
Op::ConstantTrue,
81+
Some(result_type),
82+
Some(result_id),
83+
vec![],
84+
));
85+
}
86+
}
87+
}
88+
}
89+
6790
// Parse (Sym "idN")
6891
if let Some(rest) = term.strip_prefix("(Sym \"") {
6992
if let Some(sym_name) = rest.strip_suffix("\")") {
@@ -81,34 +104,45 @@ pub fn try_parse_constant(
81104
None
82105
}
83106

84-
/// Find all (Const N) and (Const64 N) subterms in an extracted term.
85-
/// Returns a list of (is_64bit, value) tuples.
86-
pub fn find_inline_constants(term: &str) -> Vec<(bool, i64)> {
107+
/// Inline constant kind.
108+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
109+
pub enum InlineConstKind {
110+
/// 32-bit integer constant
111+
Int32,
112+
/// 64-bit integer constant
113+
Int64,
114+
/// Boolean constant
115+
Bool,
116+
}
117+
118+
/// Find all (Const N), (Const64 N), and (BoolConst N) subterms in an extracted term.
119+
/// Returns a list of (kind, value) tuples.
120+
pub fn find_inline_constants(term: &str) -> Vec<(InlineConstKind, i64)> {
87121
let mut constants = Vec::new();
88122
let mut i = 0;
89123
let chars: Vec<char> = term.chars().collect();
90124

91125
while i < chars.len() {
92-
// Look for "(Const " or "(Const64 "
93-
if i + 7 <= chars.len() {
94-
let slice: String = chars[i..i + 7].iter().collect();
95-
if slice == "(Const " {
96-
// Find the closing paren
97-
let start = i + 7;
126+
// Look for "(BoolConst " (must check before "(Const " since it's longer)
127+
if i + 11 <= chars.len() {
128+
let slice: String = chars[i..i + 11].iter().collect();
129+
if slice == "(BoolConst " {
130+
let start = i + 11;
98131
let mut end = start;
99132
while end < chars.len() && chars[end] != ')' {
100133
end += 1;
101134
}
102135
if end < chars.len() {
103136
let num_str: String = chars[start..end].iter().collect();
104137
if let Ok(value) = num_str.trim().parse::<i64>() {
105-
constants.push((false, value));
138+
constants.push((InlineConstKind::Bool, value));
106139
}
107140
}
108141
i = end;
109142
continue;
110143
}
111144
}
145+
// Look for "(Const64 " (must check before "(Const " since it's longer)
112146
if i + 9 <= chars.len() {
113147
let slice: String = chars[i..i + 9].iter().collect();
114148
if slice == "(Const64 " {
@@ -120,7 +154,27 @@ pub fn find_inline_constants(term: &str) -> Vec<(bool, i64)> {
120154
if end < chars.len() {
121155
let num_str: String = chars[start..end].iter().collect();
122156
if let Ok(value) = num_str.trim().parse::<i64>() {
123-
constants.push((true, value));
157+
constants.push((InlineConstKind::Int64, value));
158+
}
159+
}
160+
i = end;
161+
continue;
162+
}
163+
}
164+
// Look for "(Const "
165+
if i + 7 <= chars.len() {
166+
let slice: String = chars[i..i + 7].iter().collect();
167+
if slice == "(Const " {
168+
// Find the closing paren
169+
let start = i + 7;
170+
let mut end = start;
171+
while end < chars.len() && chars[end] != ')' {
172+
end += 1;
173+
}
174+
if end < chars.len() {
175+
let num_str: String = chars[start..end].iter().collect();
176+
if let Ok(value) = num_str.trim().parse::<i64>() {
177+
constants.push((InlineConstKind::Int32, value));
124178
}
125179
}
126180
i = end;

rust/spirv-tools-opt/src/direct/parse/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use rspirv::spirv::Word;
3030
use std::collections::HashMap;
3131

3232
// Re-export public items
33-
pub use constants::find_inline_constants;
33+
pub use constants::{find_inline_constants, InlineConstKind};
3434
pub use extract::parse_extract_result;
3535

3636
/// Convert egglog term back to instruction.

0 commit comments

Comments
 (0)