Support for 4bit or 8bit tensor #5

junjihashimoto opened this issue Jun 19, 2024 · 2 comments
@junjihashimoto (Collaborator)

Although it may be out of scope, it would be nice to have an example of computing with 4-bit and 8-bit tensors, to save memory bandwidth.

@austinvhuang (Contributor)

I think this will come, but there are a few prerequisites. Initially the core library doesn't include specific shaders/kernels; I want to build up to that gradually, only including them when they're clearly useful in multiple contexts, so we don't end up with a large surface area of kernels to continually support. This is relevant because there are no built-in 4-bit/8-bit types, so they would have to be closely coupled with the various dequant implementations that would be part of the library.
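
To make that coupling concrete, here is a minimal sketch (not library code) of a standalone int8 dequantization kernel in WGSL. The layout is an assumption made only for illustration: four signed 8-bit values packed per u32, plus one f32 scale per block of 32 values.

```wgsl
// Illustrative layout: four signed 8-bit values packed per u32, plus one
// f32 scale per block of 32 values (block size chosen only for the sketch).
@group(0) @binding(0) var<storage, read> qweights: array<u32>;
@group(0) @binding(1) var<storage, read> scales: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

// Extract byte `lane` from a packed word, sign-extend it, convert to f32.
fn unpack_i8(word: u32, lane: u32) -> f32 {
  let b = (word >> (lane * 8u)) & 0xffu;
  return f32(select(i32(b), i32(b) - 256, b > 127u));
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;                          // index of the dequantized element
  let q = unpack_i8(qweights[i / 4u], i % 4u);
  out[i] = q * scales[i / 32u];           // apply the per-block scale
}
```

A fused-dequant variant would skip the `out` buffer and feed `q * scale` directly into a matmul accumulator, which is why the packed tensor type and the kernels that consume it end up coupled.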

A rough outline of the sequence might look like:

  • Add an example that builds up a gemm implementation step-by-step (from simple to optimized), in the spirit of https://siboehm.com/articles/22/CUDA-MMM (a naive starting point is sketched after this list)
  • Add an example that builds up a transformer block computation
  • Port a full transformer model, maybe GPT-2 from llm.c or one of the gemma.cpp models with GPU support added (PaliGemma would be an interesting choice because I think vision lends itself to the throughput advantages of GPUs), or something else of similar complexity
  • By that point we'll have several iterations of more mature shaders and can decide which ones to move into a core kernel library
  • Once we have some reusable kernels plus an example model implementation, we can think about tensor types that are coupled to specific kernels (for dequant / fused dequant) and test them out in the context of a model.
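
As a point of reference for the first bullet, the "simple" end of that progression could be as small as the naive WGSL gemm below. The hard-coded 256×256×256 shape and row-major layout are assumptions made just to keep the sketch self-contained; the optimized variants (tiling, workgroup memory, vectorized loads) would follow the same path as the linked article.

```wgsl
// Naive gemm: C = A * B with row-major f32 matrices.
// M, K, N are hard-coded here only to keep the sketch self-contained.
const M: u32 = 256u;
const K: u32 = 256u;
const N: u32 = 256u;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let row = gid.y;
  let col = gid.x;
  if (row >= M || col >= N) {
    return;
  }
  // One output element per invocation: dot product of row of A and column of B.
  var acc: f32 = 0.0;
  for (var k: u32 = 0u; k < K; k = k + 1u) {
    acc = acc + A[row * K + k] * B[k * N + col];
  }
  C[row * N + col] = acc;
}
```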

@jgh- (Collaborator) commented Jul 18, 2024

As of Chrome 123 there is support for dot4 accumulation of int8 vec4s (https://developer.chrome.com/blog/new-in-webgpu-123), so 8-bit support definitely makes sense to the extent that other precisions are supported by this library.
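
A minimal sketch of how that builtin could be wired up, assuming the packed_4x8_integer_dot_product WGSL language feature is available; the buffer layout (one packed u32, i.e. four int8 values, per element of a and b) is just an illustration:

```wgsl
// Each u32 holds four packed signed 8-bit values.
requires packed_4x8_integer_dot_product;

@group(0) @binding(0) var<storage, read> a: array<u32>;
@group(0) @binding(1) var<storage, read> b: array<u32>;
@group(0) @binding(2) var<storage, read_write> out: array<i32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  // dot4I8Packed multiplies the four signed 8-bit lanes of each operand
  // pairwise and accumulates the products into a single i32 (DP4a-style).
  out[gid.x] = dot4I8Packed(a[gid.x], b[gid.x]);
}
```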
