Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward register #423

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open

Backward register #423

wants to merge 22 commits into from

Conversation

StrongSpoon
Copy link
Collaborator

@StrongSpoon StrongSpoon commented Jan 16, 2025

PR Category

Operator

Type of Change

New Feature

Description

register backward functions as aten interfaces
implement threshold operator incidentally

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

@StrongSpoon StrongSpoon force-pushed the bwd branch 2 times, most recently from 9f79739 to 01bee17 Compare February 6, 2025 09:26
@StrongSpoon StrongSpoon marked this pull request as ready for review February 11, 2025 02:04
save_invstd=None,
train=False,
eps=1e-05,
output_mask=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last argument should be grad_input_mask.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the definition of native_batch_norm_backward in aten lib is like below and we keep it the same:
native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)

affine: tl.constexpr,
input_grad_mask: tl.constexpr,
weight_grad_mask: tl.constexpr,
bias_grad_mask: tl.constexpr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backward kernel may need is_train arg also, to distinguish between train and non-train cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can leave it for future work tho.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a bit complex. fix it later QAQ

running_var=None,
save_mean=None,
save_invstd=None,
train=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel should be able to handle train=True case.


def native_dropout(x, p=0.5, train=True):
return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arg train is optional.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

logging.debug("GEMS NATIVE DROPOUT FORWARD")
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
input = input.contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note that we'll remove contiguous enforcement in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines +119 to +120
indices = indices.contiguous()
weight = weight.contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor this in TODOs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

mean = mean.contiguous()
rstd = rstd.contiguous()
weight = None if weight is None else weight.contiguous()
group_size = C // group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cdiv?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
BLOCK_HW_SIZE=triton.next_power_of_2(HW),
HxW,
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // group),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cdiv(C, group)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto


def native_dropout(x, p=0.5, train=True):
return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized we didn't handle we train=False correctly in the previous version. Let's fix that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants