Skip to content

请教关于convert_layout_acc_rowcol #14

@hxdtest

Description

@hxdtest

请教下关于convert_layout_acc_rowcol函数。
注释中介绍 (MMA=4, MMA_M, MMA_N) ,对于GEMM-I计算,每个thread在寄存器中存放 4(MMA=4)个计算结果。
需要把这个计算结果转换为行列形式,以便于按照行索引执行softmax相关操作。 但是为什么 logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) 可以完成上述操作,以及 make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); 代码的含义 ?如果是其他形式,需要怎么处理?

Image


// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// TODO: 搞清楚经过convert_layout_acc_rowcol后(nrow=(2, MMA_M), ncol=(2, MMA_N))的数学含义
// 形象的解释是把
//    T1.V0
//    T1.V1
//    T1.V0
//    T1.V1
// 变为
//    T1.V0 T1.V1
//    T1.V0 T1.V1
// 这样符合MMA tile的行列直觉
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
    // "int_tuple.hpp(74): error: conversion to inaccessible base class"
    // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
};

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions