Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,27 @@ CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
const noexcept
{
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
// writing final results to a given macro tile in C.
// In the case of non-atomic reduction or data-parallel (DP) only, there will always be 1
// workgroup writing final results to a given macro tile in C.
int num_wgs_per_tile = 1;

// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
// If we have DP and SK tiles, this is DP+2TSK which guarantees at most 2 workgroups per
// tile. We only need to check that dp_tiles is greater than zero since we know we have SK
// workgroups.
if(dp_tiles_ > 0)
{
num_wgs_per_tile = 2;
}
else
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
}
}

return std::max(num_wgs_per_tile, 1);
Expand Down
2 changes: 1 addition & 1 deletion test/ck_tile/gemm_streamk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})

# Currently test_ck_tile_streamk_smoke is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")

include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};

EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 1);
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
}

TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqualValue)
Expand Down