|
| 1 | +// SPDX-License-Identifier: MIT |
| 2 | +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. |
| 3 | + |
| 4 | +#include "ck/ck.hpp" |
| 5 | +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" |
| 6 | +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" |
| 7 | +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" |
| 8 | + |
| 9 | +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" |
| 10 | + |
| 11 | +namespace ck { |
| 12 | +namespace tensor_operation { |
| 13 | +namespace device { |
| 14 | +namespace instance { |
| 15 | + |
| 16 | +using F8 = f8_t; /* A and B element type of every instance listed below */ |
| 17 | +using F16 = half_t; /* not referenced in the visible part of this file */ |
| 18 | +using BF16 = bhalf_t; /* C and CShuffle element type of every instance listed below */ |
| 19 | +using F32 = float; /* accumulator type of every instance listed below */ |
| 20 | +using E8M0 = ck::e8m0_bexp_t; /* type of the A and B scale values in the instances below */ |
| 21 | + |
| 22 | +using Row = tensor_layout::gemm::RowMajor; /* layout used for A and C in the instances below */ |
| 23 | +using Col = tensor_layout::gemm::ColumnMajor; /* layout used for B in the instances below */ |
| 24 | + |
| 25 | +template <index_t... Is> /* S<...> is shorthand for Sequence<...>; the instances below use it for cluster-length and access-order arguments */ |
| 26 | +using S = Sequence<Is...>; |
| 27 | + |
| 28 | +using PassThrough = element_wise::PassThrough; /* passed as the A, B and C elementwise operation of every instance below */ |
| 29 | + |
| 30 | +static constexpr auto GemmDefault = GemmSpecialization::Default; /* GemmSpecialization shorthands; none is referenced in the visible part of this file - presumably for includers choosing GemmSpec (TODO confirm) */ |
| 31 | +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; |
| 32 | +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; |
| 33 | +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; |
| 34 | + |
| 35 | +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; /* scheduler shorthands; not referenced in the visible part of this file - presumably for includers choosing BlkGemmPipeSched (TODO confirm) */ |
| 36 | +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; |
| 37 | + |
| 38 | +static constexpr auto ScaleBlockSize = 32; /* forwarded as the "Scale Block Size" argument of every instance below; presumably the number of elements sharing one E8M0 scale - TODO confirm */ |
| 39 | + |
| 40 | +template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec> /* tuple of DeviceGemmMX_Xdl_CShuffleV3 instances: A Row/F8, B Col/F8, C Row/BF16, E8M0 scales; scheduler and GEMM specialization are supplied by the user of this alias */ |
| 41 | +using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< |
| 42 | +// clang-format off |
| 43 | + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| |
| 44 | + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| |
| 45 | + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version| |
| 46 | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| 47 | +#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) /* when neither macro is defined the tuple is empty */ |
| 48 | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, /* BlockSize 128; M/N/K per block 128/16/128; XDL 16x16; Xdl per wave 4x1 */ |
| 49 | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, /* BlockSize 256; M/N/K per block 128/128/256; XDL 32x32; Xdl per wave 2x2 */ |
| 50 | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, /* BlockSize 256; M/N/K per block 256/128/64; XDL 32x32; Xdl per wave 4x2 */ |
| 51 | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, /* BlockSize 256; M/N/K per block 128/128/128; XDL 32x32; Xdl per wave 2x2 */ |
| 52 | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> /* BlockSize 64; M/N/K per block 16/16/512; XDL 16x16; Xdl per wave 1x1; last active entry, so no trailing comma */ |
| 53 | + |
| 54 | +// TODO: the instance below is disabled until it has been verified; re-enabling it also requires a comma after the entry above |
| 55 | + //DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> |
| 56 | +#endif /* __gfx950__ || CK_USE_NATIVE_MX_SUPPORT */ |
| 57 | + // clang-format on |
| 58 | + >; |
| 59 | + |
| 60 | +} // namespace instance |
| 61 | +} // namespace device |
| 62 | +} // namespace tensor_operation |
| 63 | +} // namespace ck |
0 commit comments