
Add the ops of AoT #70

Merged: 7 commits merged into dev on Nov 18, 2024
Conversation

@junjihashimoto (Collaborator) commented Oct 21, 2024

This PR implements the forward pass of GPT-2 with AoT (ahead-of-time compilation).

The backward pass has been disabled, as it needs to be fixed to not use atomicAdd.

The PR sat in draft for a while because of a memory error. I added -fsanitize=address to the build to make such errors detectable, and the memory error has been fixed.

junjihashimoto changed the base branch from main to dev on October 21, 2024.
@junjihashimoto (Collaborator, Author) commented:

After creating the forward kernels, running the model produces the following error; the backward kernels still need to be updated.

[kernels (feature/aot)]
$ time make
✓ OpenMP found
if [ ! -f gpt2_124M.bin ]; then ./llm.c/dev/download_starter_pack.sh ; \
          ln -s ./llm.c/gpt2_124M.bin ; \
          ln -s ./llm.c/gpt2_124M_debug_state.bin ; \
					ln -s ./llm.c/gpt2_tokenizer.bin ; \
	fi
. /Users/junji.hashimoto/git/gpu.cpp/experimental/kernels/../../source && ./build/gpt2_webgpu_aot
Creating GPU context
Building GPT-2 model from checkpoint 'gpt2_124M.bin'
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
padded_vocab_size: 50304
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124475904
Model build complete
train dataset num_batches: 1192
val dataset num_batches: 128
Starting training
Step 0
num_activations: 73347840
Allocating 279.80 MB for activations
Creating Kernels
[error] Device uncaptured error: Error while parsing WGSL: :24:13 error: no matching call to 'atomicAdd(ptr<storage, f32, read_write>, f32)'

1 candidate function:
 • 'atomicAdd(ptr<S, atomic<T>, read_write>  ✗ , T  ✗ ) -> T' where:
      ✗  'T' is 'i32' or 'u32'
      ✓  'S' is 'workgroup' or 'storage'

            atomicAdd(&dwte[ix * C + i], d);
            ^^^^^^^^^


 - While validating [ShaderModuleDescriptor "kernel"]
 - While calling [Device].CreateShaderModule([ShaderModuleDescriptor "kernel"]).

libc++abi: terminating due to uncaught exception of type std::runtime_error: Device uncaptured exception.
/bin/sh: line 1: 57347 Abort trap: 6           ./build/gpt2_webgpu_aot
make: *** [run-native] Error 134

real	0m10.054s
user	0m9.719s
sys	0m0.253s
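
For context on the error above: WGSL's built-in atomicAdd only accepts i32/u32 operands, so a float gradient accumulation has to be emulated. One common workaround (a sketch only, not the kernel code in this PR) is to bind the gradient buffer as array<atomic<u32>> and retry a compare-and-swap on the raw bits, embedded as a WGSL string on the C++ side:

// Sketch: CAS-based float atomicAdd for a backward kernel. Assumes the
// gradient buffer dwte is bound as array<atomic<u32>> and reinterpreted as f32.
static const char *kAtomicAddF32Sketch = R"(
@group(0) @binding(0) var<storage, read_write> dwte: array<atomic<u32>>;

fn atomic_add_f32(index: u32, value: f32) {
  loop {
    let old_bits = atomicLoad(&dwte[index]);
    let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
    // Retry until no other invocation touched the slot between load and exchange.
    if (atomicCompareExchangeWeak(&dwte[index], old_bits, new_bits).exchanged) {
      break;
    }
  }
}
)";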

junjihashimoto marked this pull request as ready for review on November 16, 2024.
#define NUM_ACTIVATION_TENSORS 23
typedef struct {
    Tensor encoded; // (B, T, C)
    std::vector<Tensor> ln1; // (L, B, T, C)
@austinvhuang (Contributor) commented on Nov 16, 2024:

Same as above, could we use std::array instead with NUM_ACTIVATION_TENSORS statically allocated?

@junjihashimoto (Collaborator, Author) replied:

It may be confusing, but NUM_ACTIVATION_TENSORS is the number of member variables like encoded and ln1, not the length of each vector.
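
A condensed illustration of that distinction (hypothetical placeholder code, not the full struct from the PR): the macro counts struct members, while each std::vector member holds one tensor per layer and is sized at runtime.

#include <vector>

struct Tensor {};  // placeholder standing in for gpu.cpp's Tensor

#define NUM_ACTIVATION_TENSORS 23  // number of members below, not a vector length

typedef struct {
    Tensor encoded;           // (B, T, C)    - a single tensor
    std::vector<Tensor> ln1;  // (L, B, T, C) - one tensor per layer, resized at runtime
    // ... remaining members, 23 activation tensors in total
} ActivationTensors;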

typedef struct {
    Tensor wte; // (V, C)
    Tensor wpe; // (maxT, C)
    std::vector<Tensor> ln1w; // (L, C)
@austinvhuang (Contributor) commented on Nov 16, 2024:

Can we use std::array here instead of vector?

@junjihashimoto (Collaborator, Author) replied:

This is possible only if we fix the model configuration.
The size depends on the number of layers, and since GPT-2 has variants with different numbers of layers, the size cannot be determined at compile time.
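
A rough sketch of the constraint (hypothetical code, not from the PR): std::array needs the layer count as a compile-time constant, whereas the checkpoint header only provides it at runtime.

#include <vector>

struct Tensor {};  // placeholder standing in for gpu.cpp's Tensor

// std::array<Tensor, 12> ln1w;  // only works if the layer count L is fixed at compile time

struct ParameterTensors {
    std::vector<Tensor> ln1w;  // (L, C), sized once the checkpoint is read
};

void allocate_layers(ParameterTensors &params, int num_layers) {
    params.ln1w.resize(num_layers);  // 12, 24, 36, or 48 depending on the GPT-2 variant
}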

//printf("inputs[0] = %d\n", inputs[0]);
// encoder_forward(ctx, acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]
{
std::promise<void> promise;
@austinvhuang (Contributor) commented:

We'll want to think about a good way to wrap the async state so we don't have to make it explicit at every step, but that doesn't have to be addressed in this PR; we can explore it as a follow-up.
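
One possible shape for that follow-up (a sketch only; dispatchAndWait is a hypothetical helper, and the dispatchKernel/wait calls assume gpu.cpp's existing API):

#include <future>
#include "gpu.hpp"  // gpu.cpp header providing Context, Kernel, dispatchKernel, wait (path assumed)

// Hypothetical helper: hides the per-op promise/future boilerplate so each
// forward step reduces to a single call.
inline void dispatchAndWait(gpu::Context &ctx, gpu::Kernel &kernel) {
    std::promise<void> promise;
    std::future<void> future = promise.get_future();
    gpu::dispatchKernel(ctx, kernel, promise);  // submit the precompiled (AoT) kernel
    gpu::wait(ctx, future);                     // block until the GPU work completes
}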

@@ -99,6 +99,10 @@ build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp
mkdir -p build
$(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp

build/gpt2_webgpu_aot: llm.c gpt2_124M.bin llm.c gpt2_webgpu_aot.cpp ops_aot.cpp
@austinvhuang (Contributor) commented:

Eventually this will probably become the main gpt2_webgpu implementation, though this is fine for this PR.

@austinvhuang (Contributor) left a review:

LGTM - see the comments below about whether we can make the std::vector members statically allocated std::arrays.

This is a great step! With the resource allocation overhead addressed ahead of time, we can then iterate on optimizing kernel performance.

Besides that, we'll probably want to iterate a bit on some quality-of-life things, like packaging ops to wrap promises/futures so they don't need to be written out every time, but for now this is a good start.

@junjihashimoto (Collaborator, Author) commented:

@austinvhuang Thank you for your review!
To use std::array, the number of layers needs to be fixed at compile time.
There seem to be four layer counts: 12, 24, 36, and 48. If we limit support to these four, we can use std::array.

@@ -47,7 +47,6 @@ typedef struct {

// the parameters of the model
#define NUM_PARAMETER_TENSORS 16
#define NUM_PARAMETER_LAYERS 12
@junjihashimoto (Collaborator, Author) commented:

Following review, NUM_PARAMETER_LAYERS has been replaced with num_layers.

@austinvhuang (Contributor) commented:

> @austinvhuang Thank you for your review! To use std::arrays, the number of layers needs to be fixed. There seem to be four types of layer numbers: 12, 24, 36, and 48. If we limit it to these four types, we can use std::array.

That makes sense; fine to leave it as a vector for this PR, though in a future update we may replace vector with a lighter dynamic allocation approach like unique_ptr.
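
For instance (hypothetical sketch, not committed code), a fixed-size heap allocation avoids std::vector's capacity bookkeeping while still deferring the size to runtime:

#include <cstddef>
#include <memory>

struct Tensor {};  // placeholder standing in for gpu.cpp's Tensor

// Hypothetical alternative to std::vector<Tensor>: allocate exactly num_layers
// tensors once, with no growth machinery or accidental reallocation.
struct LayerTensors {
    std::unique_ptr<Tensor[]> ln1w;
    size_t num_layers = 0;
};

inline LayerTensors makeLayerTensors(size_t num_layers) {
    return LayerTensors{std::make_unique<Tensor[]>(num_layers), num_layers};
}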

We can go ahead and merge to dev. Thanks!

@junjihashimoto (Collaborator, Author) commented:

Thank you, too!

austinvhuang merged commit 28c7062 into AnswerDotAI:dev on Nov 18, 2024.
junjihashimoto deleted the feature/aot branch on November 18, 2024.