🚀 The feature, motivation and pitch
The task is to fuse `log_softmax` + `gather`. Naoya said it depends on his resize function work. The motivation is that the output tensor of `log_softmax` can be roughly [64, 128, 32768], which in float32 is ~1 GB. Re-reading that tensor is expensive compared to "gathering" it down to [64, 128, 1], which is a trivially sized tensor. Some people are working on `index_select`, `gather`, and `scatter`, but those ops have so far only been allowed to fuse as the first operation of a fusion. The `gather`, in this instance, would be at the end of the fusion.
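A quick back-of-the-envelope check of those sizes (the [64, 128, 32768] shape is the example from the pitch above, not from the repro below):

```python
# float32 = 4 bytes per element
full_output = 64 * 128 * 32768 * 4  # bytes re-read without fusion
gathered    = 64 * 128 * 1 * 4      # bytes actually needed after the gather
print(full_output / 2**30, "GiB vs", gathered / 2**10, "KiB")  # 1.0 GiB vs 32.0 KiB
```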
`CrossEntropyLoss` forward includes a `log_softmax` followed by a `gather` operation. It is notably used in NLP networks like BERT, as seen here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1139-L1143
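For reference, a minimal eager-mode sketch of that decomposition (ignoring the `ignore_index=-100` masking that the full loss also performs):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(512, 32768)
targets = torch.randint(0, 32768, (512,))

# log_softmax, then gather one value per row, then the mean reduction.
log_probs = F.log_softmax(logits, dim=1)                       # [512, 32768]
picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # [512]
loss = -picked.mean()

assert torch.allclose(loss, F.cross_entropy(logits, targets), atol=1e-6)
```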
Code example:
```python
import torch

class MyMod(torch.nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        return self.loss_fn(inputs, targets)

cp_model = torch.compile(MyMod())

inputs = [
    torch.randn(512, 32768, device='cuda', requires_grad=True),  # [512, 32768] logits
    torch.randint(0, 32768, (512,), device='cuda'),              # [512] class targets
]

for _ in range(5):
    out = cp_model(*inputs)
    out.backward()
```
How to view the graphs?

```
$ AOT_FX_GRAPHS=1 python test.py

====== Forward graph 0 ======
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[512, 32768], primals_2: i64[512]):
        # File: /workspace/test.py:9, code: return self.loss_fn(inputs, targets)
        amax: f32[512, 1] = torch.ops.aten.amax.default(primals_1, [1], True)
        sub: f32[512, 32768] = torch.ops.aten.sub.Tensor(primals_1, amax); primals_1 = amax = None
        exp: f32[512, 32768] = torch.ops.aten.exp.default(sub)
        sum_1: f32[512, 1] = torch.ops.aten.sum.dim_IntList(exp, [1], True); exp = None
        log: f32[512, 1] = torch.ops.aten.log.default(sum_1); sum_1 = None
        sub_1: f32[512, 32768] = torch.ops.aten.sub.Tensor(sub, log); sub = log = None
        ne: b8[512] = torch.ops.aten.ne.Scalar(primals_2, -100)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        where: i64[512] = torch.ops.aten.where.self(ne, primals_2, scalar_tensor); scalar_tensor = None
        unsqueeze: i64[512, 1] = torch.ops.aten.unsqueeze.default(where, 1); where = None
        gather: f32[512, 1] = torch.ops.aten.gather.default(sub_1, 1, unsqueeze); unsqueeze = None
        squeeze: f32[512] = torch.ops.aten.squeeze.dim(gather, 1); gather = None
        neg: f32[512] = torch.ops.aten.neg.default(squeeze); squeeze = None
        scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        where_1: f32[512] = torch.ops.aten.where.self(ne, neg, scalar_tensor_1); neg = scalar_tensor_1 = None
        sum_2: i64[] = torch.ops.aten.sum.default(ne); ne = None
        convert_element_type: f32[] = torch.ops.prims.convert_element_type.default(sum_2, torch.float32); sum_2 = None
        sum_3: f32[] = torch.ops.aten.sum.default(where_1); where_1 = None
        div: f32[] = torch.ops.aten.div.Tensor(sum_3, convert_element_type); sum_3 = None
        return [div, primals_2, sub_1, convert_element_type]
====== Backward graph 0 ======
class GraphModule(torch.nn.Module):
    def forward(self, primals_2: i64[512], sub_1: f32[512, 32768], convert_element_type: f32[], tangents_1: f32[]):
        # File: /workspace/test.py:9, code: return self.loss_fn(inputs, targets)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        div_1: f32[] = torch.ops.aten.div.Tensor(tangents_1, convert_element_type); tangents_1 = convert_element_type = None
        unsqueeze_1: i64[512, 1] = torch.ops.aten.unsqueeze.default(primals_2, 1); primals_2 = None
        ne_3: b8[512, 1] = torch.ops.aten.ne.Scalar(unsqueeze_1, -100)
        where_2: i64[512, 1] = torch.ops.aten.where.self(ne_3, unsqueeze_1, scalar_tensor); unsqueeze_1 = scalar_tensor = None
        full_like: f32[512, 32768] = torch.ops.aten.full_like.default(sub_1, 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False, memory_format = torch.preserve_format)
        scatter: f32[512, 32768] = torch.ops.aten.scatter.value(full_like, 1, where_2, -1.0); full_like = where_2 = None
        where_3: f32[512, 1] = torch.ops.aten.where.self(ne_3, div_1, scalar_tensor_1); ne_3 = div_1 = scalar_tensor_1 = None
        mul: f32[512, 32768] = torch.ops.aten.mul.Tensor(scatter, where_3); scatter = where_3 = None
        exp_1: f32[512, 32768] = torch.ops.aten.exp.default(sub_1); sub_1 = None
        sum_4: f32[512, 1] = torch.ops.aten.sum.dim_IntList(mul, [1], True)
        mul_1: f32[512, 32768] = torch.ops.aten.mul.Tensor(exp_1, sum_4); exp_1 = sum_4 = None
        sub_2: f32[512, 32768] = torch.ops.aten.sub.Tensor(mul, mul_1); mul = mul_1 = None
        return [sub_2, None]
```
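Note that the forward graph returns `sub_1` (the full [512, 32768] `log_softmax` output) as a saved tensor, and the backward re-reads it (e.g. in `exp_1`), so the large intermediate is traversed again there as well. For non-ignored rows the backward above reduces to the usual cross-entropy gradient; a small sanity check of that identity (my own sketch, not part of the repro):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16, requires_grad=True)
t = torch.randint(0, 16, (8,))

F.cross_entropy(x, t).backward()

# d(mean cross-entropy)/dx = (softmax(x) - one_hot(t)) / batch_size
expected = (torch.softmax(x, dim=1) - F.one_hot(t, num_classes=16).float()) / x.shape[0]
assert torch.allclose(x.grad, expected, atol=1e-6)
```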
This section in particular is the `log_softmax` + `gather`; it is taken from the `torch.compile` graph printed above.
```
amax: f32[512, 1] = torch.ops.aten.amax.default(primals_1, [1], True)
sub: f32[512, 32768] = torch.ops.aten.sub.Tensor(primals_1, amax); primals_1 = amax = None
exp: f32[512, 32768] = torch.ops.aten.exp.default(sub)
sum_1: f32[512, 1] = torch.ops.aten.sum.dim_IntList(exp, [1], True); exp = None
log: f32[512, 1] = torch.ops.aten.log.default(sum_1); sum_1 = None
sub_1: f32[512, 32768] = torch.ops.aten.sub.Tensor(sub, log); sub = log = None
ne: b8[512] = torch.ops.aten.ne.Scalar(primals_2, -100)
scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
where: i64[512] = torch.ops.aten.where.self(ne, primals_2, scalar_tensor); scalar_tensor = None
unsqueeze: i64[512, 1] = torch.ops.aten.unsqueeze.default(where, 1); where = None
gather: f32[512, 1] = torch.ops.aten.gather.default(sub_1, 1, unsqueeze); unsqueeze = None
```
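For readability, here is a rough eager-mode reading of that excerpt (my own mapping, for illustration only; `primals_1` is the [512, 32768] logits tensor and `primals_2` the integer targets):

```python
import torch

# Hypothetical helper mirroring the graph excerpt above.
def log_softmax_then_gather(primals_1, primals_2, ignore_index=-100):
    # amax / sub / exp / sum_1 / log / sub_1: numerically stable log_softmax.
    log_probs = torch.log_softmax(primals_1, dim=1)  # [512, 32768]
    # ne / where / unsqueeze: replace ignored targets with a valid index (0).
    safe_targets = torch.where(primals_2 != ignore_index,
                               primals_2,
                               torch.zeros_like(primals_2))
    # gather: pick one log-probability per row -> [512, 1].
    return log_probs.gather(1, safe_targets.unsqueeze(1))
```

Fusing the `gather` at the end of this chain would avoid a separate kernel that has to re-read the [512, 32768] intermediate (`sub_1`) just to pull out [512, 1] values.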