wip MoE refactor #2600

Open · wants to merge 1 commit into main

Conversation

HDCharles (Contributor):

Summary:

Now that the PyTorch grouped_mm kernels no longer require padding, this refactors the MoE implementation to use them instead of the previous approach (see the sketch after the TODO list).

DONE
- implement MoE with grouped_mm [x]
- add handling for generic module swap to AOQuantizable (MoEMapping) [x]
- refactor MoEQuantConfig to swap generic modules [x]

TODO
- add dispatch from grouped_mm to the linear decomposition of the quantized kernel
- compare the old linear decomposition vs the new linear decomposition vs grouped_mm for eager, compile, and autotuned compile
- compare the old linear decomposition vs the new linear decomposition for quantized kernels
- add scaled_group_gemm and fbgemm kernels (probably in a new PR)
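
For reference, a minimal sketch of what the grouped_mm-based expert computation could look like. This is not the PR's actual code: the function name, weight layout, and activation are assumptions, and the exact torch._grouped_mm signature and constraints can differ between PyTorch versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch, not the PR's code: names, weight layout, and activation
# are assumptions, and torch._grouped_mm's exact signature/constraints can
# differ between PyTorch versions.
def moe_experts_grouped_mm(
    x_sorted: torch.Tensor,  # (total_tokens, hidden), tokens pre-sorted by expert, bf16
    counts: torch.Tensor,    # (num_experts,), number of tokens routed to each expert
    w_up: torch.Tensor,      # (num_experts, hidden, intermediate), bf16
    w_down: torch.Tensor,    # (num_experts, intermediate, hidden), bf16
) -> torch.Tensor:
    # grouped_mm takes cumulative end offsets per expert, so no per-group padding
    offs = torch.cumsum(counts, dim=0, dtype=torch.int32)
    h = torch._grouped_mm(x_sorted, w_up, offs=offs)  # per-expert up projection
    h = F.silu(h)                                     # placeholder activation
    return torch._grouped_mm(h, w_down, offs=offs)    # per-expert down projection
```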

ISSUE:
The autotuned grouped_mm kernels don't give correct output, yet the same code works in eager and with compile using reduce-overhead. Why?

See the new_run.log output: the first two runs are fine, but line 144 is nonsense.

Test Plan:

sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot bot commented Jul 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2600

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 11 New Failures, 1 Cancelled Job

As of commit d41d7b9 with merge base 0e00df3:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jul 25, 2025.
HDCharles requested a review from alexsamardzic on Jul 25, 2025 at 02:30.
alexsamardzic (Collaborator):

This one replaces #2325, right?

I'm struggling to run the run.sh script (i.e., the generate.py script); I keep getting "CUDA out of memory" errors on an H100. Are you using PyTorch built from source, and if not, which PyTorch package version do you have installed?

Would you mind finding the mm_grouped.py file in your PyTorch installation, changing the can_use_triton_kernel() function there to just return False, and then re-trying? This will force the eager (non-Triton) version of the grouped MM kernel to be used even for max-autotune; I suspect the garbage output may come not from the grouped MM Triton kernel itself but from max-autotuning the whole layer, and this would test that.
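
If editing the installed file is awkward, a monkey-patch along these lines should have the same effect; the module path below is an assumption based on the file name mentioned above and may differ between PyTorch versions. Apply it before any torch.compile run so the kernel choice isn't already cached.

```python
# Assumed module path for the mm_grouped.py mentioned above; adjust if it
# lives elsewhere in your PyTorch build. Apply before compiling.
import torch._inductor.kernel.mm_grouped as mm_grouped

mm_grouped.can_use_triton_kernel = lambda *args, **kwargs: False  # force the eager grouped MM fallback
```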

As a side note, it seems that MoEFeedForwardAOQuantizable should be imported for this and this.

HDCharles (Contributor, Author) commented Jul 26, 2025:

> This one replaces #2325, right?
>
> I'm struggling to run the run.sh script (i.e., the generate.py script); I keep getting "CUDA out of memory" errors on an H100. Are you using PyTorch built from source, and if not, which PyTorch package version do you have installed?
>
> Would you mind finding the mm_grouped.py file in your PyTorch installation, changing the can_use_triton_kernel() function there to just return False, and then re-trying? This will force the eager (non-Triton) version of the grouped MM kernel to be used even for max-autotune; I suspect the garbage output may come not from the grouped MM Triton kernel itself but from max-autotuning the whole layer, and this would test that.
>
> As a side note, it seems that MoEFeedForwardAOQuantizable should be imported for this and this.

Can you run it with batch_size 1?

I'll try the fix.

Yeah, I haven't done the quantization dispatch stuff yet.

"""Configuration for applying quantization to MoE
Args:
`base_config`: normal AO Config
class DummyModule(torch.nn.Module):

Contributor:

I think a better solution is to make torchao APIs work on parameters. The current workaround is fine for prototype, but we'd want more proper support for non-prototype.
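
For readers following along, a minimal, hypothetical sketch of the wrapper pattern under discussion: wrap a bare expert-weight parameter in a throwaway module so a module-targeted config can be applied to it, then read the converted weight back. The PR's actual DummyModule and the quantize_ flow shown in the comments here are assumptions and may differ from the real code.

```python
import torch
from torch import nn

class DummyModule(nn.Module):
    # Hypothetical sketch of the workaround, not the PR's actual class:
    # expose a bare parameter as `.weight` so module-targeted quantization
    # configs can be applied to it.
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)

# usage sketch (the flow is assumed, not taken from the PR):
#   dummy = DummyModule(expert_weight)
#   quantize_(dummy, base_config, filter_fn=lambda m, fqn: isinstance(m, DummyModule))
#   expert_weight = dummy.weight
```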

@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     # T'(e) tokens for expert e


-class MOEFeedForwardAOQuantizable(nn.Module):
+class MoEFeedForwardAOQuantizable(nn.Module):

Contributor:

It seems unlikely that people are going to swap their MoE module to AO's version. Can we just target torch._grouped_mm calls directly without requiring a module swap?

Collaborator:

What would it mean to "target" it specifically? If the model is compiled, the compiled version of this operator will be used anyway; I'm not sure what else torchao could do about it...
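
One illustrative reading of "targeting" the op directly (purely a sketch, not something proposed in the PR): intercept torch._grouped_mm at the torch-function level so no module swap is needed; a real version would substitute a quantized grouped GEMM inside the branch instead of falling through.

```python
import torch
from torch.overrides import TorchFunctionMode

class GroupedMMInterceptor(TorchFunctionMode):
    # Illustrative only: catch torch._grouped_mm calls so a quantized grouped
    # GEMM could be substituted without swapping the user's MoE module.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._grouped_mm:
            # hook point: inspect the activation / stacked expert weights and
            # call a quantized kernel here; the sketch just falls through
            return func(*args, **kwargs)
        return func(*args, **kwargs)

# usage sketch:
#   with GroupedMMInterceptor():
#       out = moe_layer(tokens)
```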

alexsamardzic (Collaborator):

> Can you run it with batch_size 1?

Nope, with both batch_size 1 and 8, it runs out of memory.
