Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 153 additions & 84 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::symbols::Symbols;
use itertools::Itertools;
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
use rustc_abi::ExternAbi as Abi;
use rustc_abi::{AbiAlign, ExternAbi as Abi};
use rustc_abi::{
Align, BackendRepr, FieldIdx, FieldsShape, HasDataLayout as _, LayoutData, Primitive,
ReprFlags, ReprOptions, Scalar, Size, TagEncoding, VariantIdx, Variants,
Align, BackendRepr, FieldIdx, FieldsShape, LayoutData, Primitive, ReprFlags, ReprOptions,
Scalar, Size, TagEncoding, VariantIdx, Variants,
};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::ErrorGuaranteed;
use rustc_hashes::Hash64;
use rustc_index::Idx;
use rustc_middle::query::Providers;
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
use rustc_middle::ty::layout::{FnAbiOf, LayoutError, LayoutOf, TyAndLayout};
use rustc_middle::ty::{
self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
TyKind, UintTy,
self, AdtDef, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, GenericArgs, IntTy,
PolyFnSig, Ty, TyCtxt, TyKind, TypingEnv, UintTy,
};
use rustc_middle::ty::{GenericArgsRef, ScalarInt};
use rustc_middle::{bug, span_bug};
Expand Down Expand Up @@ -164,84 +165,19 @@ pub(crate) fn provide(providers: &mut Providers) {
}
}

providers.layout_of = |tcx, key| {
providers.layout_of = layout_of;

fn layout_of<'tcx>(
tcx: TyCtxt<'tcx>,
key: ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>,
) -> Result<TyAndLayout<'tcx>, &'tcx LayoutError<'tcx>> {
// HACK(eddyb) to special-case any types at all, they must be normalized,
// but when normalization would be needed, `layout_of`'s default provider
// recurses (supposedly for caching reasons), i.e. its calls `layout_of`
// w/ the normalized type in input, which once again reaches this hook,
// without ever needing any explicit normalization here.
let ty = key.value;

// HACK(eddyb) bypassing upstream `#[repr(simd)]` changes (see also
// the later comment above `check_well_formed`, for more details).
let reimplement_old_style_repr_simd = match ty.kind() {
ty::Adt(def, args) if def.repr().simd() && !def.repr().packed() && def.is_struct() => {
Some(def.non_enum_variant()).and_then(|v| {
let (count, e_ty) = v
.fields
.iter()
.map(|f| f.ty(tcx, args))
.dedup_with_count()
.exactly_one()
.ok()?;
let e_len = u64::try_from(count).ok().filter(|&e_len| e_len > 1)?;
Some((def, e_ty, e_len))
})
}
_ => None,
};

// HACK(eddyb) tweaked copy of the old upstream logic for `#[repr(simd)]`:
// https://github.com/rust-lang/rust/blob/1.86.0/compiler/rustc_ty_utils/src/layout.rs#L464-L590
if let Some((adt_def, e_ty, e_len)) = reimplement_old_style_repr_simd {
let cx = rustc_middle::ty::layout::LayoutCx::new(
tcx,
key.typing_env.with_post_analysis_normalized(tcx),
);
let dl = cx.data_layout();

// Compute the ABI of the element type:
let e_ly = cx.layout_of(e_ty)?;
let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
// This error isn't caught in typeck, e.g., if
// the element type of the vector is generic.
tcx.dcx().span_fatal(
tcx.def_span(adt_def.did()),
format!(
"SIMD type `{ty}` with a non-primitive-scalar \
(integer/float/pointer) element type `{}`",
e_ly.ty
),
);
};

// Compute the size and alignment of the vector:
let size = e_ly.size.checked_mul(e_len, dl).unwrap();
let align = dl.llvmlike_vector_align(size);
let size = size.align_to(align.abi);

let layout = tcx.mk_layout(LayoutData {
variants: Variants::Single {
index: rustc_abi::FIRST_VARIANT,
},
fields: FieldsShape::Array {
stride: e_ly.size,
count: e_len,
},
backend_repr: BackendRepr::SimdVector {
element: e_repr,
count: e_len,
},
largest_niche: e_ly.largest_niche,
uninhabited: false,
size,
align,
max_repr_align: None,
unadjusted_abi_align: align.abi,
randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
});

return Ok(TyAndLayout { ty, layout });
if let Some(layout) = layout_of_spirv_attr_special(tcx, key)? {
return Ok(layout);
}

let TyAndLayout { ty, mut layout } =
Expand All @@ -268,7 +204,136 @@ pub(crate) fn provide(providers: &mut Providers) {
}

Ok(TyAndLayout { ty, layout })
};
}

fn layout_of_spirv_attr_special<'tcx>(
tcx: TyCtxt<'tcx>,
key: ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>,
) -> Result<Option<TyAndLayout<'tcx>>, &'tcx LayoutError<'tcx>> {
let ty::PseudoCanonicalInput {
typing_env,
value: ty,
} = key;

match ty.kind() {
ty::Adt(def, args) => {
let def: &AdtDef<'tcx> = def;
let args: &'tcx GenericArgs<'tcx> = args;
let attrs = AggregatedSpirvAttributes::parse(
tcx,
&Symbols::get(),
tcx.get_all_attrs(def.did()),
);

// add spirv-attr special layouts here
if let Some(layout) =
layout_of_spirv_vector(tcx, typing_env, ty, def, args, &attrs)?
{
return Ok(Some(layout));
}
}
_ => {}
}
Ok(None)
}

fn layout_of_spirv_vector<'tcx>(
tcx: TyCtxt<'tcx>,
typing_env: TypingEnv<'tcx>,
ty: Ty<'tcx>,
def: &AdtDef<'tcx>,
args: &'tcx GenericArgs<'tcx>,
attrs: &AggregatedSpirvAttributes,
) -> Result<Option<TyAndLayout<'tcx>>, &'tcx LayoutError<'tcx>> {
let layout_err = |msg| {
&*tcx.arena.alloc(LayoutError::ReferencesError(
tcx.dcx().span_err(tcx.def_span(def.did()), msg),
))
};

let has_spirv_vector_attr = attrs
.intrinsic_type
.as_ref()
.map_or(false, |attr| matches!(attr.value, IntrinsicType::Vector));
let has_repr_simd = def.repr().simd() && !def.repr().packed();
if !has_spirv_vector_attr && !has_repr_simd {
return Ok(None);
}

let elements = def
.non_enum_variant()
.fields
.iter()
.map(|f| f.ty(tcx, args))
.dedup_with_count()
.exactly_one()
.ok()
.and_then(|(count, e_ty)| {
u64::try_from(count)
.ok()
.filter(|&e_len| e_len >= 2)
.map(|e_len| (e_len, e_ty))
});
let (e_len, e_ty) = match elements {
None => {
return if has_repr_simd {
// core SIMD struct, not glam vector, don't do anything special
Ok(None)
} else {
Err(layout_err(format!(
"spirv vector type `{ty}` must have at least 2 elements of a single element"
)))
};
}
Some(len) => len,
};
if !def.is_struct() {
return Err(layout_err(format!(
"spirv vector type `{ty}` must be a struct"
)));
}

let lcx = ty::layout::LayoutCx::new(tcx, typing_env.with_post_analysis_normalized(tcx));

// Compute the ABI of the element type:
let e_ly: TyAndLayout<'_> = lcx.layout_of(e_ty)?;
let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
// This error isn't caught in typeck, e.g., if
// the element type of the vector is generic.
return Err(layout_err(format!(
"spirv vector type `{ty}` must have a non-primitive-scalar (integer/float/pointer) element type, got `{}`",
e_ly.ty
)));
};

// Compute the size and alignment of the vector:
let size = e_ly.size.checked_mul(e_len, &lcx).unwrap();
let align = def.repr().align.unwrap_or(e_ly.align.abi);
let size = size.align_to(align);

let layout = tcx.mk_layout(LayoutData {
variants: Variants::Single {
index: rustc_abi::FIRST_VARIANT,
},
fields: FieldsShape::Array {
stride: e_ly.size,
count: e_len,
},
backend_repr: BackendRepr::SimdVector {
element: e_repr,
count: e_len,
},
largest_niche: e_ly.largest_niche,
uninhabited: false,
size,
align: AbiAlign::new(align),
max_repr_align: None,
unadjusted_abi_align: align,
randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
});

Ok(Some(TyAndLayout { ty, layout }))
}

// HACK(eddyb) work around https://github.com/rust-lang/rust/pull/129403
// banning "struct-style" `#[repr(simd)]` (in favor of "array-newtype-style"),
Expand Down Expand Up @@ -318,7 +383,7 @@ pub(crate) fn provide(providers: &mut Providers) {
let valid_non_array_simd_struct = trivial_struct.is_some_and(|adt_def| {
let ReprOptions {
int: None,
align: None,
align: _,
pack: None,
flags: ReprFlags::IS_SIMD,
field_shuffle_seed: _,
Expand Down Expand Up @@ -540,7 +605,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
span = cx.tcx.def_span(adt.did());
}

let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs_unchecked(adt.did()));
let attrs =
AggregatedSpirvAttributes::parse(cx.tcx, &cx.sym, cx.tcx.get_all_attrs(adt.did()));

if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
&& let Ok(spirv_type) =
Expand Down Expand Up @@ -771,9 +837,9 @@ fn dig_scalar_pointee<'tcx>(
match pointee {
Some(old_pointee) if old_pointee != new_pointee => {
cx.tcx.dcx().fatal(format!(
"dig_scalar_pointee: unsupported Pointer with different \
"dig_scalar_pointee: unsupported Pointer with different \
pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
));
));
}
_ => pointee = Some(new_pointee),
}
Expand Down Expand Up @@ -1258,5 +1324,8 @@ fn trans_intrinsic_type<'tcx>(
}
.def(span, cx))
}
IntrinsicType::Vector => {
todo!()
}
}
}
30 changes: 17 additions & 13 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
//!
//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.

use crate::codegen_cx::CodegenCx;
use crate::symbols::Symbols;
use crate::symbols::{Symbols, parse_attrs_for_checking};
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
use rustc_hir as hir;
use rustc_hir::def_id::LocalModDefId;
Expand Down Expand Up @@ -66,6 +65,7 @@ pub enum IntrinsicType {
RuntimeArray,
TypedBuffer,
Matrix,
Vector,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -147,16 +147,20 @@ impl AggregatedSpirvAttributes {
///
/// Any errors for malformed/duplicate attributes will have been reported
/// prior to codegen, by the `attr` check pass.
pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
pub fn parse<'tcx>(
tcx: TyCtxt<'tcx>,
sym: &Symbols,
attrs: impl Iterator<Item = &'tcx Attribute>,
) -> Self {
let mut aggregated_attrs = Self::default();

// NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
// to see an attribute error, it will cause an ICE instead.
for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) {
for parse_attr_result in parse_attrs_for_checking(sym, attrs) {
let (span, parsed_attr) = match parse_attr_result {
Ok(span_and_parsed_attr) => span_and_parsed_attr,
Err((span, msg)) => {
cx.tcx.dcx().span_delayed_bug(span, msg);
tcx.dcx().span_delayed_bug(span, msg);
continue;
}
};
Expand All @@ -166,8 +170,7 @@ impl AggregatedSpirvAttributes {
prev_span: _,
category,
}) => {
cx.tcx
.dcx()
tcx.dcx()
.span_delayed_bug(span, format!("multiple {category} attributes"));
}
}
Expand Down Expand Up @@ -278,10 +281,8 @@ impl CheckSpirvAttrVisitor<'_> {
fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
let mut aggregated_attrs = AggregatedSpirvAttributes::default();

let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);

let attrs = self.tcx.hir_attrs(hir_id);
for parse_attr_result in parse_attrs(attrs) {
for parse_attr_result in parse_attrs_for_checking(&self.sym, attrs.into_iter()) {
let (span, parsed_attr) = match parse_attr_result {
Ok(span_and_parsed_attr) => span_and_parsed_attr,
Err((span, msg)) => {
Expand Down Expand Up @@ -326,9 +327,12 @@ impl CheckSpirvAttrVisitor<'_> {
| SpirvAttribute::SpecConstant(_) => match target {
Target::Param => {
let parent_hir_id = self.tcx.parent_hir_id(hir_id);
let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
.filter_map(|r| r.ok())
.any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
let parent_is_entry_point = parse_attrs_for_checking(
&self.sym,
self.tcx.hir_attrs(parent_hir_id).iter(),
)
.filter_map(|r| r.ok())
.any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
if !parent_is_entry_point {
self.tcx.dcx().span_err(
span,
Expand Down
3 changes: 2 additions & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ impl<'tcx> CodegenCx<'tcx> {
self.set_linkage(fn_id, symbol_name.to_owned(), linkage);
}

let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.get_attrs_unchecked(def_id));
let attrs =
AggregatedSpirvAttributes::parse(self.tcx, &self.sym, self.tcx.get_all_attrs(def_id));
if let Some(entry) = attrs.entry.map(|attr| attr.value) {
// HACK(eddyb) early insert to let `shader_entry_stub` call this
// very function via `get_fn_addr`.
Expand Down
Loading
Loading