Backward register #423
base: master
Conversation
Force-pushed from 9f79739 to 01bee17
save_invstd=None,
train=False,
eps=1e-05,
output_mask=None,
The last argument should be grad_input_mask.
The definition of native_batch_norm_backward in the ATen library is as follows, and we keep it the same:
native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
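For reference, a minimal sketch (calling PyTorch's own ATen op, not this PR's implementation) of the argument order that schema implies, with output_mask last selecting which of (grad_input, grad_weight, grad_bias) to compute; the tensor shapes are arbitrary:

```python
import torch

x = torch.randn(4, 3, 8, 8)
grad_out = torch.randn_like(x)
weight = torch.ones(3)
running_mean = torch.zeros(3)
running_var = torch.ones(3)
# Statistics a forward pass would have saved for the training case.
save_mean = x.mean(dim=(0, 2, 3))
save_invstd = (x.var(dim=(0, 2, 3), unbiased=False) + 1e-5).rsqrt()

grad_input, grad_weight, grad_bias = torch.ops.aten.native_batch_norm_backward(
    grad_out, x, weight, running_mean, running_var,
    save_mean, save_invstd,
    True,                # train
    1e-5,                # eps
    [True, True, True],  # output_mask: (grad_input, grad_weight, grad_bias)
)
```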
affine: tl.constexpr,
input_grad_mask: tl.constexpr,
weight_grad_mask: tl.constexpr,
bias_grad_mask: tl.constexpr,
The backward kernel may also need an is_train arg to distinguish between the train and non-train cases.
We can leave it for future work, though.
It's a bit complex; will fix it later QAQ
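To illustrate the suggestion, a hypothetical minimal kernel (not the one in this PR, all names made up) showing how a tl.constexpr IS_TRAIN flag could let the train and eval paths share one kernel by selecting the statistics source at compile time; a real batch-norm kernel would index statistics per channel rather than per element:

```python
import triton
import triton.language as tl


@triton.jit
def _bn_bwd_scale_kernel(
    grad_out_ptr, save_invstd_ptr, running_var_ptr, grad_in_ptr,
    n_elements, eps,
    IS_TRAIN: tl.constexpr, BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    g = tl.load(grad_out_ptr + offsets, mask=mask)
    if IS_TRAIN:
        # Training: reuse the inverse std saved by the forward pass.
        invstd = tl.load(save_invstd_ptr + offsets, mask=mask)
    else:
        # Eval: recompute the inverse std from the running variance.
        var = tl.load(running_var_ptr + offsets, mask=mask)
        invstd = 1.0 / tl.sqrt(var + eps)
    tl.store(grad_in_ptr + offsets, g * invstd, mask=mask)
```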
running_var=None,
save_mean=None,
save_invstd=None,
train=False,
The kernel should also be able to handle the train=True case.
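At the wrapper level, one possible shape for covering both cases, assuming the kernel only consumes a mean and an invstd tensor (the helper and its name are hypothetical, not from this PR):

```python
import torch


def _select_stats(train, save_mean, save_invstd, running_mean, running_var, eps):
    # Hypothetical helper: in training mode the backward pass uses the batch
    # statistics saved by the forward pass; in eval mode it falls back to the
    # running statistics and recomputes the inverse std from the variance.
    if train:
        return save_mean, save_invstd
    return running_mean, (running_var + eps).rsqrt()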
src/flag_gems/ops/dropout.py
Outdated
def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
Arg train is optional.
done
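For reference (stock PyTorch, not code from this PR), the eager API the wrapper mirrors already treats these arguments as optional, so a matching wrapper would give train a default as well; the tensor values below are arbitrary:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)
# torch.nn.functional.dropout defaults: p=0.5, training=True, so callers may
# omit both arguments entirely.
y = F.dropout(x)
y2 = F.dropout(x, p=0.2)
```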
logging.debug("GEMS NATIVE DROPOUT FORWARD")
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
input = input.contiguous()
Add a note that we'll remove contiguous enforcement in the future.
done
indices = indices.contiguous()
weight = weight.contiguous()
Note this in the TODOs for refactoring.
done
src/flag_gems/ops/groupnorm.py
Outdated
mean = mean.contiguous()
rstd = rstd.contiguous()
weight = None if weight is None else weight.contiguous()
group_size = C // group
cdiv?
fixed.
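The difference only shows up when C is not divisible by group; a quick illustration with made-up sizes:

```python
import triton

C, group = 70, 16
# Floor division drops the remainder, so the last partial group would be
# lost; ceiling division rounds up and covers every channel.
print(C // group)             # 4
print(triton.cdiv(C, group))  # 5
```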
src/flag_gems/ops/groupnorm.py
Outdated
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
BLOCK_HW_SIZE=triton.next_power_of_2(HW),
HxW,
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // group),
cdiv(C, group)?
ditto
src/flag_gems/ops/dropout.py
Outdated
def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
I realized we didn't handle train=False correctly in the previous version. Let's fix that.
done.
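For reference, the train=False behavior the fix needs to reproduce, checked against stock PyTorch (the tensor shape is arbitrary):

```python
import torch

x = torch.randn(4, 4)
# With train=False, native_dropout leaves the values untouched and returns a
# mask that keeps every element.
out, mask = torch.native_dropout(x, p=0.5, train=False)
assert torch.equal(out, x)
assert bool(mask.all())
```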
PR Category
Operator
Type of Change
New Feature
Description
- Register backward functions as ATen interfaces.
- Implement the threshold operator incidentally (reference semantics sketched below).
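For context on the threshold operator, a quick check of the reference semantics using stock PyTorch (values are arbitrary): the forward keeps input where it exceeds the threshold and substitutes value elsewhere, and the backward only passes gradient where input > threshold.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
# threshold(x, threshold, value): keep x where x > threshold, else use value.
y = F.threshold(x, 1.0, 0.0)
print(y)       # tensor([0., 0., 2.], ...)
y.sum().backward()
# Gradient flows only where x > threshold.
print(x.grad)  # tensor([0., 0., 1.])
```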
Issue
Progress
Performance