-
Notifications
You must be signed in to change notification settings - Fork 51
Open
Description
请教下关于convert_layout_acc_rowcol函数。
注释中介绍 (MMA=4, MMA_M, MMA_N) ,对于GEMM-I计算,每个thread在寄存器中存放 4(MMA=4)个计算结果。
需要把这个计算结果转换为行列形式,以便于按照行索引执行softmax相关操作。 但是为什么 logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) 可以完成上述操作,以及 make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); 代码的含义 ?如果是其他形式,需要怎么处理?
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// TODO: 搞清楚经过convert_layout_acc_rowcol后(nrow=(2, MMA_M), ncol=(2, MMA_N))的数学含义
// 形象的解释是把
// T1.V0
// T1.V1
// T1.V0
// T1.V1
// 变为
// T1.V0 T1.V1
// T1.V0 T1.V1
// 这样符合MMA tile的行列直觉
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
};
Metadata
Metadata
Assignees
Labels
No labels
