-
Notifications
You must be signed in to change notification settings - Fork 188
[CK_TILE] Add universal gemm mem skip A/B LDS pipelines #2056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
4a2f735
b3a1e16
d26270d
ff3aa68
b4dc448
acf61f9
608fd3a
4baaa1b
e0e2955
a22186f
2b5300a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,8 +37,8 @@ struct GemmPipelineAgBgCrImplBase | |
const SrcBlockTile& src_block_tile, | ||
const ElementFunction& element_func) const | ||
{ | ||
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); | ||
store_tile(lds_tile_window, block_tile_tmp); | ||
const auto block_tile = tile_elementwise_in(element_func, src_block_tile); | ||
store_tile(lds_tile_window, block_tile); | ||
} | ||
|
||
template <typename DstBlockTile, typename SrcTileWindow> | ||
|
@@ -56,8 +56,11 @@ struct GemmPipelineAgBgCrImplBase | |
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); | ||
|
||
// TODO: LDS alignment should come from Policy! | ||
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple( | ||
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); | ||
constexpr index_t a_lds_block_space_size_aligned = | ||
Problem::SkipALds | ||
? 0 | ||
: integer_least_multiple( | ||
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); | ||
|
||
// B tile in LDS | ||
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>( | ||
|
@@ -68,8 +71,8 @@ struct GemmPipelineAgBgCrImplBase | |
return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); | ||
} | ||
|
||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr> | ||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, | ||
template <typename ADramBlockWindow, typename ALdsTensorView, typename ALdsLoadTileDistr> | ||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindow& a_dram_block_window, | ||
const ALdsTensorView& a_lds_block_view, | ||
const ALdsLoadTileDistr&) const | ||
{ | ||
|
@@ -79,11 +82,22 @@ struct GemmPipelineAgBgCrImplBase | |
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>; | ||
|
||
// A DRAM tile window for load | ||
auto a_copy_dram_window = | ||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), | ||
make_tuple(YPerTile{}, XPerTile{}), | ||
a_dram_block_window_tmp.get_window_origin(), | ||
Policy::template MakeADramTileDistribution<Problem>()); | ||
decltype(auto) a_copy_dram_window = [&] { | ||
if constexpr(Problem::SkipALds == false) | ||
{ | ||
return make_tile_window(a_dram_block_window.get_bottom_tensor_view(), | ||
make_tuple(YPerTile{}, XPerTile{}), | ||
a_dram_block_window.get_window_origin(), | ||
Policy::template MakeADramTileDistribution<Problem>()); | ||
} | ||
else | ||
{ | ||
return make_tile_window(a_dram_block_window.get_bottom_tensor_view(), | ||
make_tuple(YPerTile{}, XPerTile{}), | ||
a_dram_block_window.get_window_origin(), | ||
ALdsLoadTileDistr{}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might not necessarily be the optimal solution. I imagine that regardless of using LDS we should read global memory in with an optimal vectorized access pattern. Then if it's needed you would just adapt the data layout in registers. |
||
} | ||
}(); | ||
|
||
// A LDS tile window for store | ||
auto a_copy_lds_window = make_tile_window( | ||
|
@@ -100,8 +114,8 @@ struct GemmPipelineAgBgCrImplBase | |
std::move(a_lds_gemm_window)); | ||
} | ||
|
||
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr> | ||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, | ||
template <typename BDramBlockWindow, typename BLdsTensorView, typename BLdsLoadTileDistr> | ||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindow& b_dram_block_window, | ||
const BLdsTensorView& b_lds_block_view, | ||
const BLdsLoadTileDistr&) const | ||
{ | ||
|
@@ -110,11 +124,22 @@ struct GemmPipelineAgBgCrImplBase | |
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>; | ||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>; | ||
|
||
auto b_copy_dram_window = | ||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), | ||
make_tuple(YPerTile{}, XPerTile{}), | ||
b_dram_block_window_tmp.get_window_origin(), | ||
Policy::template MakeBDramTileDistribution<Problem>()); | ||
decltype(auto) b_copy_dram_window = [&] { | ||
if constexpr(Problem::SkipBLds == false) | ||
{ | ||
return make_tile_window(b_dram_block_window.get_bottom_tensor_view(), | ||
make_tuple(YPerTile{}, XPerTile{}), | ||
b_dram_block_window.get_window_origin(), | ||
Policy::template MakeBDramTileDistribution<Problem>()); | ||
} | ||
else | ||
{ | ||
return make_tile_window(b_dram_block_window.get_bottom_tensor_view(), | ||
make_tuple(YPerTile{}, XPerTile{}), | ||
b_dram_block_window.get_window_origin(), | ||
BLdsLoadTileDistr{}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above for |
||
} | ||
}(); | ||
|
||
// TODO: Do we really need those two tile windows??? | ||
// They're exactly same... | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this supposed to be "to"? Because this sounds like the LDS is being skipped to go straight to universal GEMM, which doesn't sound quite right.
Is it maybe supposed to be "for" or "in" or "with"?
As in support's been added for skipping LDS when using universal GEMM?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be "for" or "in".