From a7cd803d029d71ab4d111fca43ce33ba55fe9841 Mon Sep 17 00:00:00 2001
From: Erik Desjardins <erikdesjardins@users.noreply.github.com>
Date: Sun, 10 Mar 2024 22:38:53 -0400
Subject: [PATCH] use ptradd for vtable indexing

Like field offsets, these are always constant.
---
 compiler/rustc_codegen_ssa/src/base.rs | 13 ++--
 compiler/rustc_codegen_ssa/src/meth.rs | 19 +++---
 tests/codegen/dst-offset.rs            |  4 +-
 tests/codegen/vtable-upcast.rs         | 85 ++++++++++++++++++++++++++
 4 files changed, 104 insertions(+), 17 deletions(-)
 create mode 100644 tests/codegen/vtable-upcast.rs

diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs
index 5cba14a5ddabc..c316d19e04119 100644
--- a/compiler/rustc_codegen_ssa/src/base.rs
+++ b/compiler/rustc_codegen_ssa/src/base.rs
@@ -165,14 +165,11 @@ pub fn unsized_info<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
                 cx.tcx().vtable_trait_upcasting_coercion_new_vptr_slot((source, target));
 
             if let Some(entry_idx) = vptr_entry_idx {
-                let ptr_ty = cx.type_ptr();
-                let ptr_align = cx.tcx().data_layout.pointer_align.abi;
-                let gep = bx.inbounds_gep(
-                    ptr_ty,
-                    old_info,
-                    &[bx.const_usize(u64::try_from(entry_idx).unwrap())],
-                );
-                let new_vptr = bx.load(ptr_ty, gep, ptr_align);
+                let ptr_size = bx.data_layout().pointer_size;
+                let ptr_align = bx.data_layout().pointer_align.abi;
+                let vtable_byte_offset = u64::try_from(entry_idx).unwrap() * ptr_size.bytes();
+                let gep = bx.inbounds_ptradd(old_info, bx.const_usize(vtable_byte_offset));
+                let new_vptr = bx.load(bx.type_ptr(), gep, ptr_align);
                 bx.nonnull_metadata(new_vptr);
                 // VTable loads are invariant.
                 bx.set_invariant_load(new_vptr);
diff --git a/compiler/rustc_codegen_ssa/src/meth.rs b/compiler/rustc_codegen_ssa/src/meth.rs
index 12146a54d3b92..4f7dc9968a13c 100644
--- a/compiler/rustc_codegen_ssa/src/meth.rs
+++ b/compiler/rustc_codegen_ssa/src/meth.rs
@@ -20,9 +20,13 @@ impl<'a, 'tcx> VirtualIndex {
         ty: Ty<'tcx>,
         fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
     ) -> Bx::Value {
-        // Load the data pointer from the object.
+        // Load the function pointer from the object.
         debug!("get_fn({llvtable:?}, {ty:?}, {self:?})");
+
         let llty = bx.fn_ptr_backend_type(fn_abi);
+        let ptr_size = bx.data_layout().pointer_size;
+        let ptr_align = bx.data_layout().pointer_align.abi;
+        let vtable_byte_offset = self.0 * ptr_size.bytes();
 
         if bx.cx().sess().opts.unstable_opts.virtual_function_elimination
             && bx.cx().sess().lto() == Lto::Fat
@@ -30,12 +34,10 @@ impl<'a, 'tcx> VirtualIndex {
             let typeid = bx
                 .typeid_metadata(typeid_for_trait_ref(bx.tcx(), expect_dyn_trait_in_self(ty)))
                 .unwrap();
-            let vtable_byte_offset = self.0 * bx.data_layout().pointer_size.bytes();
             let func = bx.type_checked_load(llvtable, vtable_byte_offset, typeid);
             func
         } else {
-            let ptr_align = bx.tcx().data_layout.pointer_align.abi;
-            let gep = bx.inbounds_gep(llty, llvtable, &[bx.const_usize(self.0)]);
+            let gep = bx.inbounds_ptradd(llvtable, bx.const_usize(vtable_byte_offset));
             let ptr = bx.load(llty, gep, ptr_align);
             bx.nonnull_metadata(ptr);
             // VTable loads are invariant.
@@ -53,9 +55,12 @@ impl<'a, 'tcx> VirtualIndex {
         debug!("get_int({:?}, {:?})", llvtable, self);
 
         let llty = bx.type_isize();
-        let usize_align = bx.tcx().data_layout.pointer_align.abi;
-        let gep = bx.inbounds_gep(llty, llvtable, &[bx.const_usize(self.0)]);
-        let ptr = bx.load(llty, gep, usize_align);
+        let ptr_size = bx.data_layout().pointer_size;
+        let ptr_align = bx.data_layout().pointer_align.abi;
+        let vtable_byte_offset = self.0 * ptr_size.bytes();
+
+        let gep = bx.inbounds_ptradd(llvtable, bx.const_usize(vtable_byte_offset));
+        let ptr = bx.load(llty, gep, ptr_align);
         // VTable loads are invariant.
         bx.set_invariant_load(ptr);
         ptr
diff --git a/tests/codegen/dst-offset.rs b/tests/codegen/dst-offset.rs
index f0157e5a10646..ce735baeb6a95 100644
--- a/tests/codegen/dst-offset.rs
+++ b/tests/codegen/dst-offset.rs
@@ -25,9 +25,9 @@ struct Dst<T: ?Sized> {
 pub fn dst_dyn_trait_offset(s: &Dst<dyn Drop>) -> &dyn Drop {
 // The alignment of dyn trait is unknown, so we compute the offset based on align from the vtable.
 
-// CHECK: [[SIZE_PTR:%[0-9]+]] = getelementptr inbounds {{.+}} [[VTABLE_PTR]]
+// CHECK: [[SIZE_PTR:%[0-9]+]] = getelementptr inbounds i8, ptr [[VTABLE_PTR]]
 // CHECK: load [[USIZE]], ptr [[SIZE_PTR]]
-// CHECK: [[ALIGN_PTR:%[0-9]+]] = getelementptr inbounds {{.+}} [[VTABLE_PTR]]
+// CHECK: [[ALIGN_PTR:%[0-9]+]] = getelementptr inbounds i8, ptr [[VTABLE_PTR]]
 // CHECK: load [[USIZE]], ptr [[ALIGN_PTR]]
 
 // CHECK: getelementptr inbounds i8, ptr [[DATA_PTR]]
diff --git a/tests/codegen/vtable-upcast.rs b/tests/codegen/vtable-upcast.rs
new file mode 100644
index 0000000000000..41a4be26cb44e
--- /dev/null
+++ b/tests/codegen/vtable-upcast.rs
@@ -0,0 +1,85 @@
+//! This file tests that we correctly generate GEP instructions for vtable upcasting.
+//@ compile-flags: -C no-prepopulate-passes -Copt-level=0
+
+#![crate_type = "lib"]
+#![feature(trait_upcasting)]
+
+pub trait Base {
+    fn base(&self);
+}
+
+pub trait A : Base {
+    fn a(&self);
+}
+
+pub trait B : Base {
+    fn b(&self);
+}
+
+pub trait Diamond : A + B {
+    fn diamond(&self);
+}
+
+// CHECK-LABEL: upcast_a_to_base
+#[no_mangle]
+pub fn upcast_a_to_base(x: &dyn A) -> &dyn Base {
+    // Requires no adjustment, since its vtable is extended from `Base`.
+
+    // CHECK: start:
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: ret
+    x as &dyn Base
+}
+
+// CHECK-LABEL: upcast_b_to_base
+#[no_mangle]
+pub fn upcast_b_to_base(x: &dyn B) -> &dyn Base {
+    // Requires no adjustment, since its vtable is extended from `Base`.
+
+    // CHECK: start:
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: ret
+    x as &dyn Base
+}
+
+// CHECK-LABEL: upcast_diamond_to_a
+#[no_mangle]
+pub fn upcast_diamond_to_a(x: &dyn Diamond) -> &dyn A {
+    // Requires no adjustment, since its vtable is extended from `A` (as the first supertrait).
+
+    // CHECK: start:
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: ret
+    x as &dyn A
+}
+
+// CHECK-LABEL: upcast_diamond_to_b
+// CHECK-SAME: (ptr align {{[0-9]+}} [[DATA_PTR:%.+]], ptr align {{[0-9]+}} [[VTABLE_PTR:%.+]])
+#[no_mangle]
+pub fn upcast_diamond_to_b(x: &dyn Diamond) -> &dyn B {
+    // Requires adjustment, since it's a non-first supertrait.
+
+    // CHECK: start:
+    // CHECK-NEXT: [[UPCAST_SLOT_PTR:%.+]] = getelementptr inbounds i8, ptr [[VTABLE_PTR]]
+    // CHECK-NEXT: [[UPCAST_VTABLE_PTR:%.+]] = load ptr, ptr [[UPCAST_SLOT_PTR]]
+    // CHECK-NEXT: [[FAT_PTR_1:%.+]] = insertvalue { ptr, ptr } poison, ptr [[DATA_PTR]], 0
+    // CHECK-NEXT: [[FAT_PTR_2:%.+]] = insertvalue { ptr, ptr } [[FAT_PTR_1]], ptr [[UPCAST_VTABLE_PTR]], 1
+    // CHECK-NEXT: ret { ptr, ptr } [[FAT_PTR_2]]
+    x as &dyn B
+}
+
+// CHECK-LABEL: upcast_diamond_to_base
+#[no_mangle]
+pub fn upcast_diamond_to_base(x: &dyn Diamond) -> &dyn Base {
+    // Requires no adjustment, since `Base` is the first supertrait of `A`,
+    // which is the first supertrait of `Diamond`.
+
+    // CHECK: start:
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: insertvalue
+    // CHECK-NEXT: ret
+    x as &dyn Base
+}