diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index 86e2b6bd8..8be2cd4df 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -124,6 +124,15 @@ def argmax(inp, dim=None, keepdim=False, *, dtype=None): M = math.prod(shape[:dim]) K = inp.numel() // M // N + # The special case of `M=1` + if M == 1: + shape_list = list(shape) + shape_list[dim] = 1 + out_index = torch.zeros(shape_list, dtype=torch.int64, device=inp.device) + if not keepdim: + out_index = torch.squeeze(out_index, dim) + return out_index + inp = inp.contiguous() shape_list = list(shape) @@ -133,7 +142,7 @@ def argmax(inp, dim=None, keepdim=False, *, dtype=None): out_index = torch.squeeze(out_index, dim) grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(M, min(meta["BLOCK_M"], M)), K, ) with torch_device_fn.device(inp.device):