|
1 | 1 | // SPDX-License-Identifier: MIT |
2 | | -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
3 | 3 |
|
4 | 4 | #include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" |
5 | 5 | #include "ck/host/stringutils.hpp" |
@@ -76,28 +76,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( |
76 | 76 | // Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| |
77 | 77 | // | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| |
78 | 78 | // | | | | | | | | | | | Wave| Wave| Wave| | |
79 | | - { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, |
80 | | - { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, |
81 | | - { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, |
82 | | - { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, |
83 | | - { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, |
84 | | - { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, |
85 | | - { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, |
86 | | - { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, |
87 | | - { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, |
88 | | - { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, |
89 | | - { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, |
90 | | - { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, |
| 79 | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1}, |
| 80 | + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1}, |
| 81 | + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, |
| 82 | + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, |
| 83 | + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1}, |
| 84 | + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1}, |
| 85 | + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, |
| 86 | + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, |
| 87 | + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, |
| 88 | + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, |
| 89 | + { 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, |
| 90 | + { 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, |
91 | 91 | // Padded fallback kernel |
92 | | - { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, |
93 | | - { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, |
| 92 | + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, |
| 93 | + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1}, |
94 | 94 | // Irregular k |
95 | | - { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, |
96 | | - { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, |
97 | | - { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, |
98 | | - { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, |
99 | | - { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, |
100 | | - { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, |
| 95 | + { 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1}, |
| 96 | + { 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1}, |
| 97 | + { 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1}, |
| 98 | + { 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1}, |
| 99 | + { 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1}, |
| 100 | + { 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1}, |
101 | 101 | // clang-format on |
102 | 102 | }; |
103 | 103 |
|
@@ -200,28 +200,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( |
200 | 200 | // _MBlock_MWaveMPerXdl| ScalarPerVector |
201 | 201 | // _NBlock_NWaveNPerXdl| _NWaveNPerXdl |
202 | 202 | // | |
203 | | - { S<1, 32, 1, 8>, 8}, |
204 | | - { S<1, 32, 1, 8>, 8}, |
205 | | - { S<1, 32, 1, 8>, 8}, |
206 | | - { S<1, 32, 1, 8>, 8}, |
207 | | - { S<1, 32, 1, 8>, 8}, |
208 | | - { S<1, 32, 1, 8>, 8}, |
209 | | - { S<1, 32, 1, 8>, 8}, |
210 | | - { S<1, 32, 1, 8>, 8}, |
211 | | - { S<1, 16, 1,16>, 8}, |
212 | | - { S<1, 32, 1, 8>, 8}, |
213 | | - { S<1, 16, 1,16>, 8}, |
214 | | - { S<1, 32, 1, 8>, 8}, |
| 203 | + { S<1, 32, 1, 8>, 4}, |
| 204 | + { S<1, 32, 1, 8>, 4}, |
| 205 | + { S<1, 32, 1, 8>, 4}, |
| 206 | + { S<1, 32, 1, 8>, 4}, |
| 207 | + { S<1, 32, 1, 8>, 4}, |
| 208 | + { S<1, 32, 1, 8>, 4}, |
| 209 | + { S<1, 32, 1, 8>, 4}, |
| 210 | + { S<1, 32, 1, 8>, 4}, |
| 211 | + { S<1, 16, 1,16>, 4}, |
| 212 | + { S<1, 32, 1, 8>, 4}, |
| 213 | + { S<1, 16, 1,16>, 4}, |
| 214 | + { S<1, 32, 1, 8>, 4}, |
215 | 215 | // Padded fallback kernel |
216 | | - { S<1, 32, 1, 8>, 8}, |
217 | | - { S<1, 32, 1, 8>, 8}, |
| 216 | + { S<1, 32, 1, 8>, 4}, |
| 217 | + { S<1, 32, 1, 8>, 4}, |
218 | 218 | // Irregular k |
219 | | - { S<1, 32, 1, 8>, 8}, |
220 | | - { S<1, 32, 1, 8>, 8}, |
221 | | - { S<1, 32, 1, 8>, 8}, |
222 | | - { S<1, 32, 1, 8>, 8}, |
223 | | - { S<1, 32, 1, 8>, 8}, |
224 | | - { S<1, 32, 1, 8>, 8}, |
| 219 | + { S<1, 32, 1, 8>, 4}, |
| 220 | + { S<1, 32, 1, 8>, 4}, |
| 221 | + { S<1, 32, 1, 8>, 4}, |
| 222 | + { S<1, 32, 1, 8>, 4}, |
| 223 | + { S<1, 32, 1, 8>, 4}, |
| 224 | + { S<1, 32, 1, 8>, 4}, |
225 | 225 | // clang-format on |
226 | 226 | }; |
227 | 227 |
|
|
0 commit comments