Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
get<0>(problem_shape_O).max_length = max(1, max_length_q);
// for variable sequence length, the batch is in units of row_stride
get<2,1>(dO) = get<0>(dO);
get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O));
get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O)));
// offset ptr by the amount we add back in later
ptr_O -= max_length_q * get<0>(dO);
}
} else {
get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O));
}

auto tma_store_o = make_tma_copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);

#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1

using ElementOut = typename CollectiveEpilogue::ElementOut;
Expand Down
3 changes: 3 additions & 0 deletions csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized {
problem_shape_qk = problem_shape;
}

get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));

auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk,
typename CollectiveMmaQK::Arguments {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized {
float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);

#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1

using ElementOut = typename CollectiveEpilogue::ElementOut;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
}

get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));

auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));

auto params_qk = CollectiveMmaQK::to_underlying_arguments(
Expand Down
5 changes: 5 additions & 0 deletions csrc/sm100/device/fmha.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class FMHA {
dim3 const block = Kernel::get_block_shape();
dim3 const grid = get_grid_shape(params);

// Grid has zero blocks in some dimension: nothing to launch, report success.
if(grid.x == 0 || grid.y == 0 || grid.z == 0) {
return Status::kSuccess;
}

// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;

Expand Down
2 changes: 1 addition & 1 deletion csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler {

return Params {
num_blocks,
{ size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) },
{ size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) },
hw_info
};
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/sm100/kernel/fmha_tile_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ struct PersistentTileScheduler {

return Params {
num_blocks,
{ num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
{ max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
hw_info
};
}
Expand Down
3 changes: 3 additions & 0 deletions tests/test_fmha_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window):


def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
close_tensor = torch.isclose(x.to(torch.float32), y.to(torch.float32), rtol=1e-5, atol=1e-5)
if close_tensor.all():
return
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
Expand Down