-
Notifications
You must be signed in to change notification settings - Fork 51
Open
Description
BS, HEAD, SEQLEN, DIM = 2, 32, 4096, 63
# if SEQLEN or DIM is not a power of 2, the assertion below fails
def get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16, device="cuda"):
    """Build the q, k, v tensors for an attention benchmark/repro.

    Each tensor has shape (BS, HEAD, SEQLEN, DIM), is filled in-place from a
    normal distribution (mean 0.0, std 0.5), and has gradients enabled.

    Args:
        BS: batch size.
        HEAD: number of attention heads.
        SEQLEN: sequence length.
        DIM: per-head embedding dimension.
        dtype: tensor dtype (default float16, as the kernels under test expect).
        device: target device (default "cuda", preserving the original behavior;
            pass "cpu" to run without a GPU).

    Returns:
        Tuple (q, k, v) of three independently initialized tensors.
    """
    def _rand_tensor():
        # One shared constructor instead of three copy-pasted expressions.
        return (
            torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device=device)
            .normal_(mean=0.0, std=0.5)
            .requires_grad_()
        )

    return _rand_tensor(), _rand_tensor(), _rand_tensor()
def main():
    """Compare the CUTLASS flash-attention kernel against a reference attention.

    NOTE(review): relies on module-level names not shown in this snippet —
    `self_attention`, `flash_attention_v2_cutlass`, `is_causal`, `sm_scale`
    (and `block_sparse_attention` / `base_blockmask` for the alternate path).
    Confirm they are defined in the full file before running.
    """
    q, k, v = get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16)
    baseline = self_attention(q, k, v, causal=is_causal, sm_scale=sm_scale)
    # Alternate kernel under test (kept for reference):
    # out, _ = block_sparse_attention(q, k, v, base_blockmask, is_causal, sm_scale)
    out, _ = flash_attention_v2_cutlass(q, k, v, is_causal, sm_scale)
    # rtol=0: compare by absolute tolerance only (fp16 outputs). Report the
    # actual max error on failure so mismatches are diagnosable.
    max_err = (baseline - out).abs().max().item()
    assert torch.allclose(baseline, out, rtol=0, atol=1e-2), (
        f"flash output diverges from baseline: max abs error {max_err}"
    )
Metadata
Metadata
Assignees
Labels
No labels