Skip to content

Commit 5d937f5

Browse files
committed
NVPTX: Add f16 SIMD intrinsics
1 parent a3beb09 commit 5d937f5

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

crates/core_arch/src/nvptx/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
1414
use crate::ffi::c_void;
1515

16+
mod packed;
17+
18+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
19+
pub use packed::*;
20+
1621
#[allow(improper_ctypes)]
1722
extern "C" {
1823
#[link_name = "llvm.nvvm.barrier0"]

crates/core_arch/src/nvptx/packed.rs

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//! NVPTX Packed data types (SIMD)
2+
//!
3+
//! Packed Data Types is what PTX calls SIMD types. See [PTX ISA (Packed Data Types)](https://docs.nvidia.com/cuda/parallel-thread-execution/#packed-data-types) for a full reference.
4+
5+
// Note: #[assert_instr] tests are not actually being run on nvptx due to being a `no_std` target incapable of running tests. Something like FileCheck would be appropriate for verifying the correct instruction is used.
6+
7+
use crate::intrinsics::simd::*;
8+
9+
// Declarations of the LLVM vector intrinsics that back `f16x2_min`/`f16x2_max`.
// `llvm.minimum.*`/`llvm.maximum.*` are the NaN-propagating IEEE-754 `minimum`/
// `maximum` operations (as opposed to `llvm.minnum`/`llvm.maxnum`), matching the
// PTX `min.NaN.f16x2` / `max.NaN.f16x2` instructions referenced below.
#[allow(improper_ctypes)] // f16x2 is not FFI-safe by rustc's rules, but this is a compiler intrinsic link, not a real C ABI call
extern "C" {
    #[link_name = "llvm.minimum.v2f16"]
    fn llvm_f16x2_min(a: f16x2, b: f16x2) -> f16x2;
    #[link_name = "llvm.maximum.v2f16"]
    fn llvm_f16x2_max(a: f16x2, b: f16x2) -> f16x2;
}
16+
17+
types! {
    /// PTX-specific 32-bit wide floating point (f16 x 2) vector type
    // Defined via stdarch's `types!` macro, which generates the repr(simd)
    // vector type plus its standard trait impls; two `f16` lanes, 32 bits total.
    #[unstable(feature = "stdarch_nvptx", issue = "111199")]
    pub struct f16x2(f16, f16);
}
23+
24+
/// Add two values
25+
///
26+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add>
27+
#[inline]
28+
#[cfg_attr(test, assert_instr(add.rn.f16x22))]
29+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
30+
pub unsafe fn f16x2_add(a: f16x2, b: f16x2) -> f16x2 {
31+
simd_add(a, b)
32+
}
33+
34+
/// Subtract two values
///
/// Lane-wise subtraction of two `f16x2` vectors; expected to lower to the PTX
/// `sub.rn.f16x2` instruction (round-to-nearest-even).
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-sub>
#[inline]
#[cfg_attr(test, assert_instr(sub.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_sub(a: f16x2, b: f16x2) -> f16x2 {
    // Generic SIMD subtract intrinsic; codegen to the single PTX instruction
    // above is what assert_instr documents (not currently run on nvptx — see
    // the note at the top of this file).
    simd_sub(a, b)
}
43+
44+
/// Multiply two values
///
/// Lane-wise multiplication of two `f16x2` vectors; expected to lower to the
/// PTX `mul.rn.f16x2` instruction (round-to-nearest-even).
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-mul>
#[inline]
#[cfg_attr(test, assert_instr(mul.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_mul(a: f16x2, b: f16x2) -> f16x2 {
    // Generic SIMD multiply intrinsic; single-instruction lowering is the
    // documented intent (assert_instr not executed on nvptx — see file note).
    simd_mul(a, b)
}
53+
54+
/// Fused multiply-add
///
/// Lane-wise `a * b + c` with a single rounding step; expected to lower to the
/// PTX `fma.rn.f16x2` instruction (round-to-nearest-even).
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-fma>
#[inline]
#[cfg_attr(test, assert_instr(fma.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_fma(a: f16x2, b: f16x2, c: f16x2) -> f16x2 {
    // simd_fma guarantees fused (single-rounding) semantics, unlike
    // simd_mul + simd_add.
    simd_fma(a, b, c)
}
63+
64+
/// Arithmetic negate
///
/// Lane-wise negation of an `f16x2` vector; expected to lower to the PTX
/// `neg.f16x2` instruction.
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-neg>
#[inline]
#[cfg_attr(test, assert_instr(neg.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_neg(a: f16x2) -> f16x2 {
    // Generic SIMD negate intrinsic (assert_instr not executed on nvptx —
    // see the note at the top of this file).
    simd_neg(a)
}
73+
74+
/// Find the minimum of two values
///
/// Lane-wise minimum; expected to lower to the PTX `min.NaN.f16x2` instruction.
/// Uses the NaN-propagating `llvm.minimum.v2f16` intrinsic (a NaN in either
/// lane yields NaN for that lane), matching the `.NaN` variant of the PTX
/// instruction — hence the direct LLVM call instead of a generic simd_* helper.
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-min>
#[inline]
#[cfg_attr(test, assert_instr(min.NaN.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_min(a: f16x2, b: f16x2) -> f16x2 {
    llvm_f16x2_min(a, b)
}
83+
84+
/// Find the maximum of two values
///
/// Lane-wise maximum; expected to lower to the PTX `max.NaN.f16x2` instruction.
/// Uses the NaN-propagating `llvm.maximum.v2f16` intrinsic (a NaN in either
/// lane yields NaN for that lane), matching the `.NaN` variant of the PTX
/// instruction — hence the direct LLVM call instead of a generic simd_* helper.
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-max>
#[inline]
#[cfg_attr(test, assert_instr(max.NaN.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_max(a: f16x2, b: f16x2) -> f16x2 {
    llvm_f16x2_max(a, b)
}

0 commit comments

Comments
 (0)