
[BUG] Result error when seq length or hidden dim is not power of 2 #12

@zhou9402

Description

If SEQLEN or DIM is not a power of 2, the final assertion fails. Repro:

```python
import torch

# self_attention and flash_attention_v2_cutlass are this repo's reference
# and CUTLASS kernels, imported as in the existing tests.

BS, HEAD, SEQLEN, DIM = 2, 32, 4096, 63  # DIM = 63 is not a power of 2

def get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16):
    q = torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
    k = torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
    v = torch.empty((BS, HEAD, SEQLEN, DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
    return q, k, v

def main():
    is_causal = False
    sm_scale = 1.0 / (DIM ** 0.5)

    q, k, v = get_tensors(BS, HEAD, SEQLEN, DIM, dtype=torch.float16)
    baseline = self_attention(q, k, v, causal=is_causal, sm_scale=sm_scale)

    # out, _ = block_sparse_attention(q, k, v, base_blockmask, is_causal, sm_scale)
    out, _ = flash_attention_v2_cutlass(q, k, v, is_causal, sm_scale)

    # Fails whenever SEQLEN or DIM is not a power of 2.
    assert torch.allclose(baseline, out, rtol=0, atol=1e-2)

if __name__ == "__main__":
    main()
```
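
Until the kernel handles arbitrary shapes, one possible workaround is to zero-pad the head dim up to the next power of 2 and slice the output back. This is a minimal sketch, not part of the repo: the helper names are hypothetical, and it assumes the `flash_attention_v2_cutlass(q, k, v, is_causal, sm_scale)` signature from the repro above. Zero-padding q and k leaves the q·kᵀ scores unchanged, and zero columns in v only produce zero output columns, so slicing recovers the exact result (keep `sm_scale` based on the original DIM).

```python
import torch
import torch.nn.functional as F

def next_pow2(n: int) -> int:
    # Smallest power of two >= n.
    return 1 << (n - 1).bit_length()

def attention_with_padded_dim(q, k, v, is_causal, sm_scale):
    # Hypothetical wrapper: pad the last (head) dim to a power of 2,
    # run the kernel, then slice the output back to the original dim.
    dim = q.shape[-1]
    padded = next_pow2(dim)
    if padded != dim:
        pad = (0, padded - dim)  # right-pad the last dimension with zeros
        q, k, v = (F.pad(t, pad) for t in (q, k, v))
    out, lse = flash_attention_v2_cutlass(q, k, v, is_causal, sm_scale)
    return out[..., :dim], lse
```

The same idea does not fix a non-power-of-2 SEQLEN by itself, since padded key rows would receive nonzero softmax weight; that case needs masking support in the kernel.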
