@@ -278,6 +278,7 @@ typedef struct {
278
278
279
279
// kernels
280
280
Kernels kernels;
281
+ bool backward_enabled;
281
282
} GPT2;
282
283
283
284
void gpt2_build_from_checkpoint (Context& ctx, GPT2 *model, const char * checkpoint_path) {
@@ -379,6 +380,7 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin
379
380
// Allocate B * T buffer (batch_size * seq_len floats) for mean loss
380
381
model->mean_loss_buffer = (float *)mallocCheck (sizeof (float ) * model->batch_size * model->seq_len );
381
382
model->probs_buffer = (float *)mallocCheck (sizeof (float ) * model->batch_size * model->seq_len * Vp);
383
+ model->backward_enabled = false ;
382
384
383
385
printf (" Model build complete\n " );
384
386
@@ -476,20 +478,24 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
476
478
kernels.crossentropy_forward = crossentropy_forward (ctx, model->acts .losses , model->acts .probs , targets, B, T, Vp);
477
479
478
480
kernels.encoder_forward = encoder_forward (ctx, model->acts .encoded , inputs, model->params .wte , model->params .wpe , B, T, C); // encoding goes into residual[0]
479
- kernels.encoder_backward = encoder_backward (ctx, model->params .wte , model->params .wpe , model->acts .encoded , inputs, B, T, C);
481
+ if (model->backward_enabled )
482
+ kernels.encoder_backward = encoder_backward (ctx, model->params .wte , model->params .wpe , model->acts .encoded , inputs, B, T, C);
480
483
kernels.layernorm_final_forward = layernorm_forward (ctx, model->acts .lnf , model->acts .lnf_mean , model->acts .lnf_rstd ,
481
484
/* input=*/ model->acts .residual3 [L-1 ], /* weight=*/ model->params .lnfw , /* bias=*/ model->params .lnfb ,
482
485
B, T, C);
483
486
Tensor nullTensor = createTensor (ctx, Shape{1 }, kf32);
484
487
model->nullTensor = nullTensor;
485
488
kernels.matmul_final_forward = matmul_forward (ctx, model->acts .logits , model->acts .lnf , model->params .wte , nullTensor, B, T, C, Vp);
486
489
kernels.softmax_final_forward = softmax_forward (ctx, model->acts .probs , model->acts .logits , B, T, V, Vp);
487
- kernels.crossentropy_softmax_backward = crossentropy_softmax_backward (ctx, model->acts .logits , model->acts .losses , model->acts .probs , targets, B, T, V, Vp);
488
- kernels.matmul_final_backward = matmul_backward (ctx, model->acts .lnf , model->params .wte , nullTensor, model->acts .logits ,
489
- model->acts .lnf , model->params .wte , B, T, C, Vp);
490
- kernels.layernorm_final_backward = layernorm_backward (ctx, model->acts .residual3 [L-1 ], model->params .lnfw , model->params .lnfb ,
491
- model->acts .lnf , model->acts .residual3 [L-1 ], model->params .lnfw ,
492
- model->acts .lnf_mean , model->acts .lnf_rstd , B, T, C);
490
+ if (model->backward_enabled )
491
+ kernels.crossentropy_softmax_backward = crossentropy_softmax_backward (ctx, model->acts .logits , model->acts .losses , model->acts .probs , targets, B, T, V, Vp);
492
+ if (model->backward_enabled )
493
+ kernels.matmul_final_backward = matmul_backward (ctx, model->acts .lnf , model->params .wte , nullTensor, model->acts .logits ,
494
+ model->acts .lnf , model->params .wte , B, T, C, Vp);
495
+ if (model->backward_enabled )
496
+ kernels.layernorm_final_backward = layernorm_backward (ctx, model->acts .residual3 [L-1 ], model->params .lnfw , model->params .lnfb ,
497
+ model->acts .lnf , model->acts .residual3 [L-1 ], model->params .lnfw ,
498
+ model->acts .lnf_mean , model->acts .lnf_rstd , B, T, C);
493
499
printf (" Created Kernels\n " );
494
500
}
495
501
@@ -557,7 +563,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
557
563
{
558
564
std::promise<void > promise;
559
565
std::future<void > future = promise.get_future ();
560
- dispatchKernel (ctx, model->kernels .layernorm2_backward [l], promise);
566
+ dispatchKernel (ctx, model->kernels .layernorm_forward [l], promise);
561
567
wait (ctx, future);
562
568
}
563
569
printf (" [Forward] : FF Up\n " );
@@ -1061,9 +1067,11 @@ int main() {
1061
1067
toGPU (ctx, train_loader.inputs , inputs);
1062
1068
toGPU (ctx, train_loader.targets , targets);
1063
1069
gpt2_forward (ctx, &model, inputs, targets, B, T);
1064
- gpt2_zero_grad (&model);
1065
- gpt2_backward (ctx, &model);
1066
- gpt2_update (ctx, &model, 1e-4f , 0 .9f , 0 .999f , 1e-8f , 0 .0f , step+1 );
1070
+ if (model.backward_enabled ) {
1071
+ gpt2_zero_grad (&model);
1072
+ gpt2_backward (ctx, &model);
1073
+ gpt2_update (ctx, &model, 1e-4f , 0 .9f , 0 .999f , 1e-8f , 0 .0f , step+1 );
1074
+ }
1067
1075
clock_gettime (CLOCK_MONOTONIC, &end);
1068
1076
double time_elapsed_s = (end.tv_sec - start.tv_sec ) + (end.tv_nsec - start.tv_nsec ) / 1e9 ;
1069
1077
printf (" step %d: train loss %f (took %f ms)\n " , step, model.mean_loss , time_elapsed_s * 1000 );
0 commit comments