Skip to content
30 changes: 30 additions & 0 deletions src/flag_gems/ops/index_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,21 @@ def index_put(inp, indices, values, accumulate=False):
logger.debug("GEMS INDEX PUT")

indices = list(indices)
if len(indices) == 1 and indices[0].dtype == torch.bool:
mask = indices[0]

if mask.device != inp.device:
mask = mask.to(inp.device)

indices = list(torch.where(mask))

K = indices[0].numel()

if values.numel() == 1:
values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
elif values.numel() == K:
values = values.reshape((K,))

indices = [
index.to(inp.device) if index.device != inp.device else index
for index in indices
Expand All @@ -275,6 +290,21 @@ def index_put_(inp, indices, values, accumulate=False):
logger.debug("GEMS INDEX PUT_")

indices = list(indices)
if len(indices) == 1 and indices[0].dtype == torch.bool:
mask = indices[0]

if mask.device != inp.device:
mask = mask.to(inp.device)

indices = list(torch.where(mask))

K = indices[0].numel()

if values.numel() == 1:
values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
elif values.numel() == K:
values = values.reshape((K,))

indices = [
index.to(inp.device) if index.device != inp.device else index
for index in indices
Expand Down
Loading