Skip to content

Commit 7ef40b0

Browse files
Add a flag to disable backward-pass
1 parent 4985930 commit 7ef40b0

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

experimental/kernels/gpt2_webgpu_aot.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ typedef struct {
278278

279279
// kernels
280280
Kernels kernels;
281+
bool backward_enabled;
281282
} GPT2;
282283

283284
void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoint_path) {
@@ -379,6 +380,7 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin
379380
// Allocate B * C buffer for mean loss
380381
model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len);
381382
model->probs_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len * Vp);
383+
model->backward_enabled = false;
382384

383385
printf("Model build complete\n");
384386

@@ -476,20 +478,24 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
476478
kernels.crossentropy_forward = crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp);
477479

478480
kernels.encoder_forward = encoder_forward(ctx, model->acts.encoded, inputs, model->params.wte, model->params.wpe, B, T, C); // encoding goes into residual[0]
479-
kernels.encoder_backward = encoder_backward(ctx, model->params.wte, model->params.wpe, model->acts.encoded, inputs, B, T, C);
481+
if(model->backward_enabled)
482+
kernels.encoder_backward = encoder_backward(ctx, model->params.wte, model->params.wpe, model->acts.encoded, inputs, B, T, C);
480483
kernels.layernorm_final_forward = layernorm_forward(ctx, model->acts.lnf, model->acts.lnf_mean, model->acts.lnf_rstd,
481484
/*input=*/ model->acts.residual3[L-1], /*weight=*/ model->params.lnfw, /*bias=*/ model->params.lnfb,
482485
B, T, C);
483486
Tensor nullTensor = createTensor(ctx, Shape{1}, kf32);
484487
model->nullTensor = nullTensor;
485488
kernels.matmul_final_forward = matmul_forward(ctx, model->acts.logits, model->acts.lnf, model->params.wte, nullTensor, B, T, C, Vp);
486489
kernels.softmax_final_forward = softmax_forward(ctx, model->acts.probs, model->acts.logits, B, T, V, Vp);
487-
kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp);
488-
kernels.matmul_final_backward = matmul_backward(ctx, model->acts.lnf, model->params.wte, nullTensor, model->acts.logits,
489-
model->acts.lnf, model->params.wte, B, T, C, Vp);
490-
kernels.layernorm_final_backward = layernorm_backward(ctx, model->acts.residual3[L-1], model->params.lnfw, model->params.lnfb,
491-
model->acts.lnf, model->acts.residual3[L-1], model->params.lnfw,
492-
model->acts.lnf_mean, model->acts.lnf_rstd, B, T, C);
490+
if(model->backward_enabled)
491+
kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp);
492+
if(model->backward_enabled)
493+
kernels.matmul_final_backward = matmul_backward(ctx, model->acts.lnf, model->params.wte, nullTensor, model->acts.logits,
494+
model->acts.lnf, model->params.wte, B, T, C, Vp);
495+
if(model->backward_enabled)
496+
kernels.layernorm_final_backward = layernorm_backward(ctx, model->acts.residual3[L-1], model->params.lnfw, model->params.lnfb,
497+
model->acts.lnf, model->acts.residual3[L-1], model->params.lnfw,
498+
model->acts.lnf_mean, model->acts.lnf_rstd, B, T, C);
493499
printf("Created Kernels\n");
494500
}
495501

@@ -557,7 +563,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
557563
{
558564
std::promise<void> promise;
559565
std::future<void> future = promise.get_future();
560-
dispatchKernel(ctx, model->kernels.layernorm2_backward[l], promise);
566+
dispatchKernel(ctx, model->kernels.layernorm_forward[l], promise);
561567
wait(ctx, future);
562568
}
563569
printf(" [Forward] : FF Up\n");
@@ -1061,9 +1067,11 @@ int main() {
10611067
toGPU(ctx, train_loader.inputs, inputs);
10621068
toGPU(ctx, train_loader.targets, targets);
10631069
gpt2_forward(ctx, &model, inputs, targets, B, T);
1064-
gpt2_zero_grad(&model);
1065-
gpt2_backward(ctx, &model);
1066-
gpt2_update(ctx, &model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
1070+
if (model.backward_enabled) {
1071+
gpt2_zero_grad(&model);
1072+
gpt2_backward(ctx, &model);
1073+
gpt2_update(ctx, &model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
1074+
}
10671075
clock_gettime(CLOCK_MONOTONIC, &end);
10681076
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
10691077
printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000);

0 commit comments

Comments (0)