
Example of Gather/Scatter from CrossEntropyLoss #2556

@kevinstephano

Description

🚀 The feature, motivation and pitch

The task is to fuse log_softmax + gather. Naoya said it depends on his resize function work. The motivation is that the output tensor of log_softmax can be roughly [64, 128, 32768], which in float32 is ~1 GiB; re-reading that tensor is expensive compared to "gathering" it down to [64, 128, 1], a trivially sized tensor. Some people are working on index_select, gather, and scatter, but those ops have so far only been allowed to fuse as the first operation of a fusion. The gather, in this instance, would be at the end of the fusion.
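
For concreteness, the sizes involved (plain arithmetic, using the shapes quoted above):

full_bytes = 64 * 128 * 32768 * 4   # log_softmax output [64, 128, 32768] in float32: exactly 1 GiB
gathered_bytes = 64 * 128 * 1 * 4   # gathered result [64, 128, 1]: 32 KiB
print(full_bytes / 2**30, "GiB vs", gathered_bytes / 2**10, "KiB")  # 1.0 GiB vs 32.0 KiB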

CrossEntropyLoss forward includes a log_softmax followed by a gather operation.

It is notably used in NLP networks like BERT, as seen here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1139-L1143
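
As a reference for the graphs below, here is a minimal eager-mode sketch of that decomposition, including the default ignore_index=-100 handling that shows up as the ne/where nodes in the traced graph (variable names are mine, not from the graph):

import torch
import torch.nn.functional as F

x = torch.randn(512, 32768)                  # logits
t = torch.randint(0, 32768, (512,))          # targets

log_probs = F.log_softmax(x, dim=1)          # [512, 32768], the large tensor
valid = t != -100                            # default ignore_index mask
safe_t = torch.where(valid, t, torch.zeros_like(t))
picked = log_probs.gather(1, safe_t.unsqueeze(1)).squeeze(1)  # [512]
loss = torch.where(valid, -picked, torch.zeros_like(picked)).sum() / valid.sum()

assert torch.allclose(loss, F.cross_entropy(x, t))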

Code example:

import torch

class MyMod(torch.nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        return self.loss_fn(inputs, targets)


cp_model = torch.compile(MyMod())

inputs = [
    torch.randn(512, 32768, device='cuda', requires_grad=True),  # logits
    torch.randint(0, 32768, (512,), device='cuda'),              # targets
]

# Run a few iterations so torch.compile traces and compiles the graphs.
for _ in range(5):
    out = cp_model(*inputs)
    out.backward()

How to view the graphs?

$ AOT_FX_GRAPHS=1 python test.py 
====== Forward graph 0 ======
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[512, 32768], primals_2: i64[512]):
        # File: /workspace/test.py:9, code: return self.loss_fn(inputs, targets)
        amax: f32[512, 1] = torch.ops.aten.amax.default(primals_1, [1], True)
        sub: f32[512, 32768] = torch.ops.aten.sub.Tensor(primals_1, amax);  primals_1 = amax = None
        exp: f32[512, 32768] = torch.ops.aten.exp.default(sub)
        sum_1: f32[512, 1] = torch.ops.aten.sum.dim_IntList(exp, [1], True);  exp = None
        log: f32[512, 1] = torch.ops.aten.log.default(sum_1);  sum_1 = None
        sub_1: f32[512, 32768] = torch.ops.aten.sub.Tensor(sub, log);  sub = log = None
        ne: b8[512] = torch.ops.aten.ne.Scalar(primals_2, -100)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        where: i64[512] = torch.ops.aten.where.self(ne, primals_2, scalar_tensor);  scalar_tensor = None
        unsqueeze: i64[512, 1] = torch.ops.aten.unsqueeze.default(where, 1);  where = None
        gather: f32[512, 1] = torch.ops.aten.gather.default(sub_1, 1, unsqueeze);  unsqueeze = None
        squeeze: f32[512] = torch.ops.aten.squeeze.dim(gather, 1);  gather = None
        neg: f32[512] = torch.ops.aten.neg.default(squeeze);  squeeze = None
        scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        where_1: f32[512] = torch.ops.aten.where.self(ne, neg, scalar_tensor_1);  neg = scalar_tensor_1 = None
        sum_2: i64[] = torch.ops.aten.sum.default(ne);  ne = None
        convert_element_type: f32[] = torch.ops.prims.convert_element_type.default(sum_2, torch.float32);  sum_2 = None
        sum_3: f32[] = torch.ops.aten.sum.default(where_1);  where_1 = None
        div: f32[] = torch.ops.aten.div.Tensor(sum_3, convert_element_type);  sum_3 = None
        return [div, primals_2, sub_1, convert_element_type]
        
====== Backward graph 0 ======
class GraphModule(torch.nn.Module):
    def forward(self, primals_2: i64[512], sub_1: f32[512, 32768], convert_element_type: f32[], tangents_1: f32[]):
        # File: /workspace/test.py:9, code: return self.loss_fn(inputs, targets)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        div_1: f32[] = torch.ops.aten.div.Tensor(tangents_1, convert_element_type);  tangents_1 = convert_element_type = None
        unsqueeze_1: i64[512, 1] = torch.ops.aten.unsqueeze.default(primals_2, 1);  primals_2 = None
        ne_3: b8[512, 1] = torch.ops.aten.ne.Scalar(unsqueeze_1, -100)
        where_2: i64[512, 1] = torch.ops.aten.where.self(ne_3, unsqueeze_1, scalar_tensor);  unsqueeze_1 = scalar_tensor = None
        full_like: f32[512, 32768] = torch.ops.aten.full_like.default(sub_1, 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False, memory_format = torch.preserve_format)
        scatter: f32[512, 32768] = torch.ops.aten.scatter.value(full_like, 1, where_2, -1.0);  full_like = where_2 = None
        where_3: f32[512, 1] = torch.ops.aten.where.self(ne_3, div_1, scalar_tensor_1);  ne_3 = div_1 = scalar_tensor_1 = None
        mul: f32[512, 32768] = torch.ops.aten.mul.Tensor(scatter, where_3);  scatter = where_3 = None
        exp_1: f32[512, 32768] = torch.ops.aten.exp.default(sub_1);  sub_1 = None
        sum_4: f32[512, 1] = torch.ops.aten.sum.dim_IntList(mul, [1], True)
        mul_1: f32[512, 32768] = torch.ops.aten.mul.Tensor(exp_1, sum_4);  exp_1 = sum_4 = None
        sub_2: f32[512, 32768] = torch.ops.aten.sub.Tensor(mul, mul_1);  mul = mul_1 = None
        return [sub_2, None]
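
As a side note, the backward graph above works out to the textbook gradient of nll_loss composed with log_softmax: (softmax(x) - one_hot(t)) scaled by grad_out / num_valid_rows, with ignored rows zeroed. A small eager check (CPU for brevity; variable names are mine):

import torch
import torch.nn.functional as F

x = torch.randn(512, 32768, requires_grad=True)
t = torch.randint(0, 32768, (512,))

F.cross_entropy(x, t).backward()

valid = (t != -100).unsqueeze(1)                     # ignore_index mask
onehot = F.one_hot(t.clamp(min=0), num_classes=32768).float()
expected = torch.where(valid,
                       (x.detach().softmax(dim=1) - onehot) / valid.sum(),
                       torch.zeros_like(x))
assert torch.allclose(x.grad, expected, atol=1e-6)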

The excerpt below, from the forward graph printed by torch.compile, is the log_softmax + gather pattern in question.

        amax: f32[512, 1] = torch.ops.aten.amax.default(primals_1, [1], True)
        sub: f32[512, 32768] = torch.ops.aten.sub.Tensor(primals_1, amax);  primals_1 = amax = None
        exp: f32[512, 32768] = torch.ops.aten.exp.default(sub)
        sum_1: f32[512, 1] = torch.ops.aten.sum.dim_IntList(exp, [1], True);  exp = None
        log: f32[512, 1] = torch.ops.aten.log.default(sum_1);  sum_1 = None
        sub_1: f32[512, 32768] = torch.ops.aten.sub.Tensor(sub, log);  sub = log = None
        ne: b8[512] = torch.ops.aten.ne.Scalar(primals_2, -100)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        where: i64[512] = torch.ops.aten.where.self(ne, primals_2, scalar_tensor);  scalar_tensor = None
        unsqueeze: i64[512, 1] = torch.ops.aten.unsqueeze.default(where, 1);  where = None
        gather: f32[512, 1] = torch.ops.aten.gather.default(sub_1, 1, unsqueeze);  unsqueeze = None
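
The memory argument from the pitch is visible here: every node up to the gather reads or writes a [512, 32768] tensor, while the gather output is only [512, 1]. Since the gather commutes with the trailing elementwise subtraction, a fused kernel would never need to materialize the full log_softmax output. A minimal sketch of that algebra in plain PyTorch (not nvFuser code; in eager mode the [512, 32768] intermediates are of course still materialized):

import torch

def logsoftmax_then_gather(logits, index):
    # Reference: materializes the full [N, C] log_softmax output.
    return torch.log_softmax(logits, dim=1).gather(1, index)

def gather_first(logits, index):
    # Same math per row: log p[i, t] = logits[i, t] - amax[i] - lse[i],
    # so the gather can read the original logits, and a fused kernel
    # only needs two [N, 1] row reductions plus a [N, 1] output.
    amax = logits.amax(dim=1, keepdim=True)                     # [N, 1]
    lse = (logits - amax).exp().sum(dim=1, keepdim=True).log()  # [N, 1]
    return logits.gather(1, index) - amax - lse                 # [N, 1]

logits = torch.randn(512, 32768)
index = torch.randint(0, 32768, (512, 1))
assert torch.allclose(gather_first(logits, index),
                      logsoftmax_then_gather(logits, index), atol=1e-5)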
