Skip to content

Commit ab6bc5c

Browse files
committed
spv: minimal OpConstantFunctionPointerINTEL support.
1 parent 0706bce commit ab6bc5c

File tree

8 files changed

+141
-18
lines changed

8 files changed

+141
-18
lines changed

src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,13 @@ pub struct ConstDef {
590590

591591
#[derive(Clone, PartialEq, Eq, Hash)]
592592
pub enum ConstKind {
593+
// FIXME(eddyb) maybe merge these? however, their connection is somewhat
594+
// tenuous (being one of the LLVM-isms SPIR-V inherited, among other things),
595+
// there's still the need to rename "global variable" post-`Var`-refactor,
596+
// and last but not least, `PtrToFunc` needs `SPV_INTEL_function_pointers`,
597+
// an OpenCL-only extension Intel came up with for their own SPIR-V tooling.
593598
PtrToGlobalVar(GlobalVar),
599+
PtrToFunc(Func),
594600

595601
// HACK(eddyb) this is a fallback case that should become increasingly rare
596602
// (especially wrt recursive consts), `Rc` means it can't bloat `ConstDef`.
@@ -679,7 +685,7 @@ pub struct FuncDecl {
679685
pub def: DeclDef<FuncDefBody>,
680686
}
681687

682-
#[derive(Copy, Clone)]
688+
#[derive(Copy, Clone, PartialEq, Eq)]
683689
pub struct FuncParam {
684690
pub attrs: AttrSet,
685691

src/print/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3332,6 +3332,9 @@ impl Print for ConstDef {
33323332
&ConstKind::PtrToGlobalVar(gv) => {
33333333
pretty::Fragment::new(["&".into(), gv.print(printer)])
33343334
}
3335+
&ConstKind::PtrToFunc(func) => {
3336+
pretty::Fragment::new(["&".into(), func.print(printer)])
3337+
}
33353338
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
33363339
let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
33373340
pretty::Fragment::new([
@@ -4098,7 +4101,7 @@ impl Print for FuncAt<'_, DataInst> {
40984101
}
40994102
}
41004103
}
4101-
ConstKind::PtrToGlobalVar(_) => {}
4104+
ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) => {}
41024105
}
41034106
}
41044107
None

src/qptr/lower.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ impl<'a> LowerFromSpvPtrs<'a> {
161161
[spv::Imm::Short(_, sc)] => sc,
162162
_ => unreachable!(),
163163
};
164+
165+
// HACK(eddyb) keep function pointers separate, perhaps eventually
166+
// adding an `OpTypeUntypedPointerKHR CodeSectionINTEL` equivalent
167+
// to SPIR-T itself (after `SPV_KHR_untyped_pointers` support).
168+
if sc == self.wk.CodeSectionINTEL {
169+
return None;
170+
}
171+
164172
let pointee = match type_and_const_inputs[..] {
165173
[TypeOrConst::Type(elem_type)] => elem_type,
166174
_ => unreachable!(),

src/spv/lift.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> {
144144
}
145145
let ct_def = &self.cx[ct];
146146
match ct_def.kind {
147-
ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => {
147+
ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => {
148148
self.visit_const_def(ct_def);
149149
self.globals.insert(global);
150150
}
@@ -1032,7 +1032,9 @@ impl LazyInst<'_, '_> {
10321032
};
10331033
(gv_decl.attrs, import)
10341034
}
1035-
ConstKind::SpvInst { .. } => (ct_def.attrs, None),
1035+
ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => {
1036+
(ct_def.attrs, None)
1037+
}
10361038

10371039
// Not inserted into `globals` while visiting.
10381040
ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(),
@@ -1153,6 +1155,13 @@ impl LazyInst<'_, '_> {
11531155
}
11541156
}
11551157

1158+
&ConstKind::PtrToFunc(func) => spv::InstWithIds {
1159+
without_ids: wk.OpConstantFunctionPointerINTEL.into(),
1160+
result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]),
1161+
result_id,
1162+
ids: [ids.funcs[&func].func_id].into_iter().collect(),
1163+
},
1164+
11561165
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
11571166
let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
11581167
spv::InstWithIds {

src/spv/lower.rs

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ enum IdDef {
2424

2525
Func(Func),
2626

27+
// HACK(eddyb) despite `FuncBody` deferring ID resolution to allow forward
28+
// references *between* functions, function pointer *constants* need a `Func`
29+
// long before any `OpFunction`s, so they're pre-defined as dummy imports.
30+
FuncForwardRef(Func),
31+
2732
SpvExtInstImport(InternedStr),
2833
SpvDebugString(InternedStr),
2934
}
@@ -36,7 +41,7 @@ impl IdDef {
3641
IdDef::Type(_) => "a type".into(),
3742
IdDef::Const(_) => "a constant".into(),
3843

39-
IdDef::Func(_) => "a function".into(),
44+
IdDef::Func(_) | IdDef::FuncForwardRef(_) => "a function".into(),
4045

4146
IdDef::SpvExtInstImport(name) => {
4247
format!("`OpExtInstImport {:?}`", &cx[name])
@@ -113,6 +118,37 @@ impl Module {
113118
// HACK(eddyb) used to quickly check whether an `OpVariable` is global.
114119
let storage_class_function_imm = spv::Imm::Short(wk.StorageClass, wk.Function);
115120

121+
// HACK(eddyb) used as the `FuncDecl` for an `IdDef::FuncForwardRef`.
122+
let dummy_decl_for_func_forward_ref = FuncDecl {
123+
attrs: {
124+
let mut attrs = AttrSet::default();
125+
attrs.push_diag(
126+
&cx,
127+
Diag::err(["function ID used as forward reference but never defined".into()]),
128+
);
129+
attrs
130+
},
131+
// FIXME(eddyb) this gets simpler w/ disaggregation.
132+
ret_type: cx.intern(TypeKind::SpvInst {
133+
spv_inst: wk.OpTypeVoid.into(),
134+
type_and_const_inputs: [].into_iter().collect(),
135+
}),
136+
params: [].into_iter().collect(),
137+
def: DeclDef::Imported(Import::LinkName(cx.intern(""))),
138+
};
139+
// HACK(eddyb) no `PartialEq` on `FuncDecl`.
140+
let assert_is_dummy_decl_for_func_forward_ref = |decl: &FuncDecl| {
141+
let [expected, found] = [&dummy_decl_for_func_forward_ref, decl].map(
142+
|FuncDecl { attrs, ret_type, params, def }| {
143+
let DeclDef::Imported(import) = def else {
144+
unreachable!();
145+
};
146+
(attrs, ret_type, params, import)
147+
},
148+
);
149+
assert!(expected == found);
150+
};
151+
116152
let mut module = {
117153
let [magic, version, generator_magic, id_bound, reserved_inst_schema] = parser.header;
118154

@@ -582,6 +618,38 @@ impl Module {
582618
});
583619
id_defs.insert(id, IdDef::Type(ty));
584620

621+
Seq::TypeConstOrGlobalVar
622+
} else if opcode == wk.OpConstantFunctionPointerINTEL {
623+
use std::collections::hash_map::Entry;
624+
625+
let id = inst.result_id.unwrap();
626+
627+
let func_id = inst.ids[0];
628+
let func = match id_defs.entry(func_id) {
629+
Entry::Occupied(entry) => match entry.get() {
630+
&IdDef::FuncForwardRef(func) => Ok(func),
631+
id_def => Err(id_def.descr(&cx)),
632+
},
633+
Entry::Vacant(entry) => {
634+
let func =
635+
module.funcs.define(&cx, dummy_decl_for_func_forward_ref.clone());
636+
entry.insert(IdDef::FuncForwardRef(func));
637+
Ok(func)
638+
}
639+
}
640+
.map_err(|descr| {
641+
invalid(&format!(
642+
"unsupported use of {descr} as the `OpConstantFunctionPointerINTEL` operand"
643+
))
644+
})?;
645+
646+
let ct = cx.intern(ConstDef {
647+
attrs: mem::take(&mut attrs),
648+
ty: result_type.unwrap(),
649+
kind: ConstKind::PtrToFunc(func),
650+
});
651+
id_defs.insert(id, IdDef::Const(ct));
652+
585653
Seq::TypeConstOrGlobalVar
586654
} else if inst_category == spec::InstructionCategory::Const || opcode == wk.OpUndef {
587655
let id = inst.result_id.unwrap();
@@ -754,19 +822,40 @@ impl Module {
754822
})
755823
}
756824
};
825+
let decl = FuncDecl {
826+
attrs: mem::take(&mut attrs),
827+
ret_type: func_ret_type,
828+
params: func_type_param_types
829+
.map(|ty| FuncParam { attrs: AttrSet::default(), ty })
830+
.collect(),
831+
def,
832+
};
757833

758-
let func = module.funcs.define(
759-
&cx,
760-
FuncDecl {
761-
attrs: mem::take(&mut attrs),
762-
ret_type: func_ret_type,
763-
params: func_type_param_types
764-
.map(|ty| FuncParam { attrs: AttrSet::default(), ty })
765-
.collect(),
766-
def,
767-
},
768-
);
769-
id_defs.insert(func_id, IdDef::Func(func));
834+
let func = {
835+
use std::collections::hash_map::Entry;
836+
837+
match id_defs.entry(func_id) {
838+
Entry::Occupied(mut entry) => match entry.get() {
839+
&IdDef::FuncForwardRef(func) => {
840+
let decl_slot = &mut module.funcs[func];
841+
assert_is_dummy_decl_for_func_forward_ref(decl_slot);
842+
*decl_slot = decl;
843+
844+
entry.insert(IdDef::Func(func));
845+
Ok(func)
846+
}
847+
id_def => Err(id_def.descr(&cx)),
848+
},
849+
Entry::Vacant(entry) => {
850+
let func = module.funcs.define(&cx, decl);
851+
entry.insert(IdDef::Func(func));
852+
Ok(func)
853+
}
854+
}
855+
.map_err(|descr| {
856+
invalid(&format!("invalid redefinition of {descr} as a new function"))
857+
})?
858+
};
770859

771860
current_func_body = Some(FuncBody { func_id, func, insts: vec![] });
772861

@@ -1170,7 +1259,7 @@ impl Module {
11701259
"unsupported use of {} outside `OpExtInst`",
11711260
id_def.descr(&cx),
11721261
))),
1173-
None => local_id_defs
1262+
None | Some(IdDef::FuncForwardRef(_)) => local_id_defs
11741263
.get(&id)
11751264
.copied()
11761265
.ok_or_else(|| invalid(&format!("undefined ID %{id}",))),

src/spv/spec.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def_well_known! {
137137
OpConstantTrue,
138138
OpConstant,
139139
OpUndef,
140+
OpConstantFunctionPointerINTEL,
140141

141142
OpVariable,
142143

@@ -201,6 +202,8 @@ def_well_known! {
201202
HitAttributeKHR,
202203
RayPayloadKHR,
203204
CallableDataKHR,
205+
206+
CodeSectionINTEL,
204207
],
205208
decoration: u32 = [
206209
LinkageAttributes,

src/transform.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,10 @@ impl InnerTransform for ConstDef {
468468
gv -> transformer.transform_global_var_use(*gv),
469469
} => ConstKind::PtrToGlobalVar(gv)),
470470

471+
ConstKind::PtrToFunc(func) => transform!({
472+
func -> transformer.transform_func_use(*func),
473+
} => ConstKind::PtrToFunc(func)),
474+
471475
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
472476
let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
473477
Transformed::map_iter(

src/visit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ impl InnerVisit for ConstDef {
341341
visitor.visit_type_use(*ty);
342342
match kind {
343343
&ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv),
344+
&ConstKind::PtrToFunc(func) => visitor.visit_func_use(func),
344345
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
345346
let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
346347
for &ct in const_inputs {

0 commit comments

Comments
 (0)