Skip to content

Commit 70af80b

Browse files
committed
Fix IEqual/INotEqual signedness check and add cooperative type support
The SPIR-V spec for IEqual/INotEqual only requires matching component width and component count, allowing operands of different signedness. The previous check required exact type ID match, rejecting valid SPIR-V like comparing signed i32 with unsigned u32. Also add cooperative matrix (KHR/NV) and cooperative vector NV type support to arithmetic and logical comparison rules, matching the C++ validator's behavior.
1 parent 45f5b4f commit 70af80b

File tree

3 files changed

+114
-16
lines changed

3 files changed

+114
-16
lines changed

rust/spirv-tools-core/src/validation/rules/arithmetics.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ impl ValidationRule for FloatArithmeticRule {
6868
continue;
6969
};
7070

71-
// Result type must be float scalar, vector, or cooperative matrix
71+
// Result type must be float scalar, vector, cooperative matrix, or cooperative vector NV
7272
if !resolver.is_float_scalar_or_vector(result_type_id, ctx.definitions)
7373
&& !resolver.is_float_cooperative_matrix(result_type_id, ctx.definitions)
74+
&& !resolver.is_float_cooperative_vector_nv(result_type_id, ctx.definitions)
7475
{
7576
if let (Some(func), Some(block), Some(result_type)) = (
7677
function_id,
@@ -199,6 +200,12 @@ impl ValidationRule for IntArithmeticRule {
199200
if is_unsigned {
200201
if !resolver
201202
.is_unsigned_int_scalar_or_vector(result_type_id, ctx.definitions)
203+
&& !resolver
204+
.is_unsigned_int_cooperative_matrix(result_type_id, ctx.definitions)
205+
&& !resolver.is_unsigned_int_cooperative_vector_nv(
206+
result_type_id,
207+
ctx.definitions,
208+
)
202209
{
203210
if let (Some(func), Some(block), Some(result_type)) = (
204211
function_id,
@@ -215,7 +222,12 @@ impl ValidationRule for IntArithmeticRule {
215222
.into());
216223
}
217224
}
218-
} else if !resolver.is_int_scalar_or_vector(result_type_id, ctx.definitions) {
225+
} else if !resolver.is_int_scalar_or_vector(result_type_id, ctx.definitions)
226+
&& !resolver.is_int_cooperative_matrix(result_type_id, ctx.definitions)
227+
&& !resolver
228+
.is_unsigned_int_cooperative_matrix(result_type_id, ctx.definitions)
229+
&& !resolver.is_int_cooperative_vector_nv(result_type_id, ctx.definitions)
230+
{
219231
if let (Some(func), Some(block), Some(result_type)) = (
220232
function_id,
221233
block_id,

rust/spirv-tools-core/src/validation/rules/logicals.rs

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,10 @@ impl ValidationRule for FloatComparisonRule {
348348
let right_type = get_operand_type(1);
349349

350350
if let Some(left_tid) = left_type {
351-
if !resolver.is_float_scalar_or_vector(left_tid, ctx.definitions) {
351+
if !resolver.is_float_scalar_or_vector(left_tid, ctx.definitions)
352+
&& !resolver.is_float_cooperative_matrix(left_tid, ctx.definitions)
353+
&& !resolver.is_float_cooperative_vector_nv(left_tid, ctx.definitions)
354+
{
352355
if let (Some(func), Some(block), Some(result_type)) = (
353356
function_id,
354357
block_id,
@@ -368,8 +371,12 @@ impl ValidationRule for FloatComparisonRule {
368371
let left_dim = resolver.get_dimension(left_tid, ctx.definitions);
369372

370373
// Result must match operand dimension (scalar operand -> scalar result,
371-
// vector operand -> vector result)
372-
if left_dim != result_dim {
374+
// vector operand -> vector result). Skip dimension check for cooperative
375+
// matrix/vector types (they don't have a simple dimension).
376+
let is_cooperative = resolver
377+
.is_cooperative_matrix(left_tid, ctx.definitions)
378+
|| resolver.is_cooperative_vector_nv(left_tid, ctx.definitions);
379+
if !is_cooperative && left_dim != result_dim {
373380
if let (Some(func), Some(block), Some(result_type)) = (
374381
function_id,
375382
block_id,
@@ -614,8 +621,11 @@ impl ValidationRule for IntComparisonRule {
614621
let right_type = get_operand_type(1);
615622

616623
if let (Some(left_tid), Some(right_tid)) = (left_type, right_type) {
617-
// Left must be int
618-
if !resolver.is_int_scalar_or_vector(left_tid, ctx.definitions) {
624+
// Left must be int (scalar, vector, cooperative matrix, or cooperative vector NV)
625+
if !resolver.is_int_scalar_or_vector(left_tid, ctx.definitions)
626+
&& !resolver.is_int_cooperative_matrix(left_tid, ctx.definitions)
627+
&& !resolver.is_int_cooperative_vector_nv(left_tid, ctx.definitions)
628+
{
619629
if let (Some(func), Some(block), Some(result_type)) = (
620630
function_id,
621631
block_id,
@@ -632,8 +642,11 @@ impl ValidationRule for IntComparisonRule {
632642
}
633643
}
634644

635-
// Right must be int
636-
if !resolver.is_int_scalar_or_vector(right_tid, ctx.definitions) {
645+
// Right must be int (scalar, vector, cooperative matrix, or cooperative vector NV)
646+
if !resolver.is_int_scalar_or_vector(right_tid, ctx.definitions)
647+
&& !resolver.is_int_cooperative_matrix(right_tid, ctx.definitions)
648+
&& !resolver.is_int_cooperative_vector_nv(right_tid, ctx.definitions)
649+
{
637650
if let (Some(func), Some(block), Some(result_type)) = (
638651
function_id,
639652
block_id,
@@ -670,7 +683,7 @@ impl ValidationRule for IntComparisonRule {
670683
}
671684
}
672685

673-
// Check bit widths first (more specific error)
686+
// Check bit widths match
674687
let left_width = resolver.get_bit_width(left_tid, ctx.definitions);
675688
let right_width = resolver.get_bit_width(right_tid, ctx.definitions);
676689

@@ -690,8 +703,15 @@ impl ValidationRule for IntComparisonRule {
690703
}
691704
}
692705

693-
// Operand types must be identical (signedness matters even if width matches)
694-
if left_tid != right_tid {
706+
// For IEqual/INotEqual, the SPIR-V spec only requires
707+
// matching component width and component count (signedness
708+
// may differ). For all other integer comparisons
709+
// (UGreaterThan, SLessThan, etc.), operand types must be
710+
// identical.
711+
let is_equality_op =
712+
matches!(inst.class.opcode, Op::IEqual | Op::INotEqual);
713+
714+
if !is_equality_op && left_tid != right_tid {
695715
if let (Some(func), Some(block), Some(result_type)) = (
696716
function_id,
697717
block_id,

rust/spirv-tools-core/src/validation/tests/misc.rs

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,73 @@ fn compare_operands_must_match_each_other() {
598598
}
599599

600600
#[test]
601-
fn compare_operands_cannot_mix_signed_and_unsigned_ints() {
601+
fn iequal_allows_mixed_signedness() {
602+
// SPIR-V spec: IEqual/INotEqual only require matching component width,
603+
// not identical types. Signed and unsigned operands of the same width are allowed.
604+
use rspirv::{binary::Assemble, dr::Builder};
605+
let mut b = Builder::new();
606+
b.set_version(1, 6);
607+
b.capability(rspirv::spirv::Capability::Shader);
608+
b.capability(rspirv::spirv::Capability::Matrix);
609+
b.memory_model(
610+
rspirv::spirv::AddressingModel::Logical,
611+
rspirv::spirv::MemoryModel::GLSL450,
612+
);
613+
let void = b.type_void();
614+
let bool_ty = b.type_bool();
615+
let int = b.type_int(32, 1);
616+
let uint = b.type_int(32, 0);
617+
let fn_ty = b.type_function(void, std::iter::empty::<u32>());
618+
b.begin_function(void, None, rspirv::spirv::FunctionControl::NONE, fn_ty)
619+
.unwrap();
620+
b.begin_block(None).unwrap();
621+
let lhs = b.constant_bit32(int, 1);
622+
let rhs = b.constant_bit32(uint, 1);
623+
b.i_equal(bool_ty, None, lhs, rhs).unwrap();
624+
b.ret().unwrap();
625+
b.end_function().unwrap();
626+
let words = b.module().assemble();
627+
words
628+
.as_slice()
629+
.validate(TargetEnv::Universal1_6)
630+
.expect("IEqual should accept operands with different signedness but same width");
631+
}
632+
633+
#[test]
634+
fn inotequal_allows_mixed_signedness() {
635+
// SPIR-V spec: INotEqual only requires matching component width.
636+
use rspirv::{binary::Assemble, dr::Builder};
637+
let mut b = Builder::new();
638+
b.set_version(1, 6);
639+
b.capability(rspirv::spirv::Capability::Shader);
640+
b.capability(rspirv::spirv::Capability::Matrix);
641+
b.memory_model(
642+
rspirv::spirv::AddressingModel::Logical,
643+
rspirv::spirv::MemoryModel::GLSL450,
644+
);
645+
let void = b.type_void();
646+
let bool_ty = b.type_bool();
647+
let int = b.type_int(32, 1);
648+
let uint = b.type_int(32, 0);
649+
let fn_ty = b.type_function(void, std::iter::empty::<u32>());
650+
b.begin_function(void, None, rspirv::spirv::FunctionControl::NONE, fn_ty)
651+
.unwrap();
652+
b.begin_block(None).unwrap();
653+
let lhs = b.constant_bit32(int, 1);
654+
let rhs = b.constant_bit32(uint, 1);
655+
b.i_not_equal(bool_ty, None, lhs, rhs).unwrap();
656+
b.ret().unwrap();
657+
b.end_function().unwrap();
658+
let words = b.module().assemble();
659+
words
660+
.as_slice()
661+
.validate(TargetEnv::Universal1_6)
662+
.expect("INotEqual should accept operands with different signedness but same width");
663+
}
664+
665+
#[test]
666+
fn ugreater_than_rejects_mixed_signedness() {
667+
// UGreaterThan requires identical operand types (signedness matters).
602668
use rspirv::{binary::Assemble, dr::Builder};
603669
let mut b = Builder::new();
604670
b.set_version(1, 6);
@@ -619,20 +685,20 @@ fn compare_operands_cannot_mix_signed_and_unsigned_ints() {
619685
let header = b.begin_block(None).unwrap();
620686
let lhs = b.constant_bit32(int, 1);
621687
let rhs = b.constant_bit32(uint, 1);
622-
b.i_equal(bool_ty, None, lhs, rhs).unwrap();
688+
b.u_greater_than(bool_ty, None, lhs, rhs).unwrap();
623689
b.ret().unwrap();
624690
b.end_function().unwrap();
625691
let words = b.module().assemble();
626692
let err = words
627693
.as_slice()
628694
.validate(TargetEnv::Universal1_6)
629-
.expect_err("compare operands must have identical types, not just width");
695+
.expect_err("UGreaterThan requires identical operand types");
630696
assert_eq!(
631697
err,
632698
ValidationError::LogicalOperandTypeMismatch {
633699
function: Id::try_from(main).unwrap(),
634700
block: Id::try_from(header).unwrap(),
635-
opcode: rspirv::spirv::Op::IEqual,
701+
opcode: rspirv::spirv::Op::UGreaterThan,
636702
result_type: TypeId::try_from(bool_ty).unwrap(),
637703
}
638704
);

0 commit comments

Comments
 (0)