forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_gpt2.cu
2702 lines (2409 loc) · 120 KB
/
train_gpt2.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
GPT-2 Transformer Neural Net trained in raw CUDA
Non-trivial notes to be aware of:
We are being clever in the backward pass to conserve memory.
In particular, all parameters use a += in the backward pass, so we
can later do gradient accumulation. But all activations have = instead of +=
because these are faster (just a write, no read-modify-write). This is okay for all activations
except for those in the residual stream, where the gradients have to add. We make
sure that those parts work out ok and that we do a += as necessary. E.g.,
the layernorms are connected to the residuals so we += in layernorm backward.
In this file we are using Mixed Precision training, so different activations,
parameters, grads and buffers may be kept at different precisions, to take
advantage of the fast low-precision hardware in the latest GPUs (bf16/fp16),
and fp8 (coming soon^TM).
Compile:
make train_gpt2cu
Example launch using bfloat16 on 1 GPU with batch size 8, sampling/eval every 200 steps.
We use TinyStories here as an example because it is a bigger dataset:
./train_gpt2cu -b 8 -v 200 -s 200 -i data/TinyStories
Example launch using bfloat16 on 4 GPUs, same as above:
mpirun -np 4 ./train_gpt2cu -b 8 -v 200 -s 200 -i data/TinyStories
If you'd like to see train_gpt2.cu produce identical results to
`python train_gpt2.py`, you can run it like this:
make train_gpt2cu PRECISION=FP32
./train_gpt2cu -b 4 -t 64 -l 1e-4 -v 200 -s 200 -a 1 -x 10 -f 0
This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample every 200 steps (i.e. never, since only 10 steps run),
-a 1 is "overfit single batch", -x 10 is 10 iterations, and -f 0 disables tf32
*/
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <math.h>
#include <time.h>
#include <assert.h>
#include <float.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
// GPU / CUDA related
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
// Multi-GPU related
#ifdef MULTI_GPU
#include <mpi.h>
#include <nccl.h>
#endif
// our own utilities
// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
#include "utils.h"
// defines: tokenizer_init, tokenizer_decode, tokenizer_free
#include "tokenizer.h"
// ----------------------------------------------------------------------------
// CUDA precision settings
// Numeric precision the trainer is compiled for. Exactly one of these is selected at build
// time through the PRECISION_MODE macro below (ENABLE_FP32, ENABLE_FP16, or bf16 by default).
enum PrecisionMode {
    PRECISION_FP32 = 0, // 32-bit IEEE float
    PRECISION_FP16 = 1, // 16-bit IEEE half
    PRECISION_BF16 = 2  // bfloat16
};
// Default Properties
// floatN: plain fp32 type that does not depend on the PRECISION build setting; ncclFloatN is
// its NCCL counterpart. Presumably used for buffers that must stay in fp32 regardless of
// floatX -- TODO confirm at the use sites.
typedef float floatN;
// CUBLAS_LOWP_COMPUTE expands to the runtime global `cublas_compute_type` (declared further
// down), so the cuBLAS compute type is a variable, not a compile-time constant; it is
// presumably assigned during startup -- confirm where.
#define CUBLAS_LOWP_COMPUTE cublas_compute_type
#ifdef MULTI_GPU
const ncclDataType_t ncclFloatN = ncclFloat;
#endif
// Specific configurations based on the enabled precision
// Each branch defines, for the selected precision:
//   floatX             - element type the kernels below use for weights/activations
//   CUBLAS_LOWP        - matching cudaDataType for cuBLAS calls
//   PRECISION_MODE     - matching PrecisionMode enumerator
//   load_filename      - checkpoint file name (the bf16 build uses a bf16-specific file)
//   precision_mode_str - human-readable name of the precision
//   ncclFloatX         - matching NCCL data type (MULTI_GPU builds only)
// ENABLE_FP32 / ENABLE_FP16 select fp32 / fp16; with neither defined the build uses bf16.
#if defined(ENABLE_FP32)
typedef float floatX;
#define CUBLAS_LOWP CUDA_R_32F
#define PRECISION_MODE PRECISION_FP32
const char* load_filename = "gpt2_124M.bin";
const char* precision_mode_str = "fp32";
#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclFloat;
#endif
// use fp16 (note: this may require gradient scaler, currently not implemented!)
#elif defined(ENABLE_FP16)
typedef half floatX;
#define CUBLAS_LOWP CUDA_R_16F
#define PRECISION_MODE PRECISION_FP16
const char* load_filename = "gpt2_124M.bin";
const char* precision_mode_str = "fp16";
#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclHalf;
#endif
#else // Default to bfloat16
typedef __nv_bfloat16 floatX;
#define CUBLAS_LOWP CUDA_R_16BF
#define PRECISION_MODE PRECISION_BF16
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights specific filename
const char* precision_mode_str = "bf16";
#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclBfloat16;
#endif
#endif
#ifdef ENABLE_CUDNN
#include <cudnn_frontend.h>
namespace fe = cudnn_frontend;
// cuDNN 16-bit I/O type matching floatX.
// NOTE: this must test the ENABLE_* build macros. The previous test
// `#if CUBLAS_LOWP == CUDA_R_16BF` compared enum constants (not macros) inside the
// preprocessor, where unknown identifiers evaluate to 0; it was therefore always true and
// FP16 builds silently described their half-precision tensors to cuDNN as BFLOAT16.
#if defined(ENABLE_FP16)
#define CUDNN_16BIT fe::DataType_t::HALF
#else
#define CUDNN_16BIT fe::DataType_t::BFLOAT16
#endif
static cudnnHandle_t cudnn_handle;
static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)
static void* cudnn_workspace = NULL;
// Check a raw cuDNN status (0 == success); prints the call site and exits on failure.
// Unlike the former assert()-based version, `err` is always evaluated, even with NDEBUG.
// The trailing ';' is kept so existing call sites expand exactly as before.
#define checkCudnnErr(err) do { if ((int)(err) != 0) { printf("[cuDNN ERROR] at file %s:%d\n", __FILE__, __LINE__); exit(EXIT_FAILURE); } } while (0);
#endif // ENABLE_CUDNN
// ----------------------------------------------------------------------------
// CUDA utils
// ---- cuBLAS state ----
// cuBLASLt workspace, fixed at 32 MiB: only Hopper needs the full 32 MiB,
// 4 MiB would be enough on other architectures.
static size_t cublaslt_workspace_size = (size_t)32 * 1024 * 1024;
static void* cublaslt_workspace = NULL;
static cublasComputeType_t cublas_compute_type;
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
// ---- device properties (0 until filled in at startup) ----
int cuda_arch_major = 0;
int cuda_arch_minor = 0;
int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM
namespace cg = cooperative_groups;
// Integer ceil(M / N), for sizing kernel grids; M and N are expected to be non-negative / positive.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
// CUDA error checking: on any status other than cudaSuccess, print the call site and
// CUDA's description of the error, then terminate the process.
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error == cudaSuccess) {
        return;
    }
    const char* reason = cudaGetErrorString(error);
    printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, reason);
    exit(EXIT_FAILURE);
}
// Wraps a CUDA runtime call and records where it was made.
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
// cuBLAS error checking: on any status other than CUBLAS_STATUS_SUCCESS, print the numeric
// status and the call site, then terminate the process.
void cublasCheck(cublasStatus_t status, const char *file, int line)
{
    if (status != CUBLAS_STATUS_SUCCESS) {
        printf("[cuBLAS ERROR]: %d %s %d\n", status, file, line);
        exit(EXIT_FAILURE);
    }
}
// Expression form, matching cudaCheck. The former `{ ... }` block form broke
// `if (c) cublasCheck(x); else ...` because the `;` after the braces ended the if statement.
#define cublasCheck(status) (cublasCheck((status), __FILE__, __LINE__))
#ifdef MULTI_GPU
// NCCL error checking: on failure print the call site and NCCL's error string, then exit.
void nccl_check(ncclResult_t status, const char *file, int line) {
    if (status != ncclSuccess) {
        printf("[NCCL ERROR] at file %s:%d:\n%s\n", file, line, ncclGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}
#define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__))
// MPI error checking: on failure print the call site and MPI's error string, then exit.
void mpi_check(int status, const char *file, int line) {
    if (status != MPI_SUCCESS) {
        // MPI_Error_string requires a buffer of at least MPI_MAX_ERROR_STRING characters.
        char mpi_error[MPI_MAX_ERROR_STRING];
        int mpi_error_len = 0;
        // Deliberately NOT wrapped in assert(): with NDEBUG the call would be compiled out
        // and the printed message would always be empty.
        if (MPI_Error_string(status, mpi_error, &mpi_error_len) != MPI_SUCCESS) {
            mpi_error_len = snprintf(mpi_error, sizeof(mpi_error), "unknown MPI error code %d", status);
        }
        printf("[MPI ERROR] at file %s:%d:\n%.*s\n", file, line, mpi_error_len, mpi_error);
        exit(EXIT_FAILURE);
    }
}
#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__))
#endif
// GPU helper functions for atomicAdd on smaller than 32-bit types.
// A 16-bit value is added with a packed 32-bit atomicAdd on the 4-byte-aligned pair that
// contains `addr`, adding zero to the neighbouring element.
// The bf16 overload is needed whenever floatX is __nv_bfloat16. That includes the default
// build where none of the ENABLE_* precision macros is defined; previously that build had no
// bf16 overload, and calls such as atomicAddX(floatX*, floatX) failed to resolve.
#if defined(ENABLE_BF16) || (!defined(ENABLE_FP32) && !defined(ENABLE_FP16))
__device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
    uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
    __nv_bfloat162* ptr_bf16 = reinterpret_cast<__nv_bfloat162*>(ptr_val & ~uintptr_t(0x3));
    // Prepare the value to add, setting the other half to zero:
    // a misaligned address ((ptr_val & 0x3) != 0) is the second element of the pair
    __nv_bfloat162 add_val = (ptr_val & 0x3) ? __halves2bfloat162(__ushort_as_bfloat16(0), val)
                                             : __halves2bfloat162(val, __ushort_as_bfloat16(0));
    atomicAdd(ptr_bf16, add_val);
}
#endif
#ifdef ENABLE_FP16
__device__ void atomicAddX(half* addr, half val) {
    uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
    half2* ptr_fp16 = reinterpret_cast<half2*>(ptr_val & ~uintptr_t(0x3));
    // Prepare the value to add, setting the other half to zero
    half2 add_val = (ptr_val & 0x3) ? __halves2half2(__ushort_as_half(0), val)
                                    : __halves2half2(val, __ushort_as_half(0));
    atomicAdd(ptr_fp16, add_val);
}
#endif
// fp32: native atomicAdd, no packing needed
__device__ void atomicAddX(float* addr, float val) {
    atomicAdd(addr, val);
}
// ----------------------------------------------------------------------------
// Packed128: a 16-byte, 16-byte-aligned bundle of elements. Moving it through an int4 lets
// the compiler emit single 128-bit load/store instructions where the GPU supports them --
// conceptually like float4 for 32-bit floats, but usable with any element type.
template<class T>
struct alignas(16) Packed128 {
    // how many T fit into 128 bits (e.g. 4 floats, 8 half/bf16 values)
    static constexpr const size_t size = sizeof(int4) / sizeof(T);
    T payload[size];

    __device__ Packed128() = default;

    // reinterpret 128 raw bits as `size` elements
    __device__ explicit Packed128(int4 raw) {
        static_assert(sizeof(raw) == sizeof(payload), "Size mismatch.");
        memcpy(&payload, &raw, sizeof(raw));
    }

    __device__ T& operator[](int i) { return payload[i]; }
    __device__ const T& operator[](int i) const { return payload[i]; }

    // view the elements as 128 raw bits
    __device__ int4 get_bits() const {
        int4 raw;
        static_assert(sizeof(raw) == sizeof(payload), "Size mismatch.");
        memcpy(&raw, &payload, sizeof(raw));
        return raw;
    }
};
// short-form aliases: f128 always holds 4 floats, x128 holds floatX elements
using f128 = Packed128<float>;
using x128 = Packed128<floatX>;
// 128-bit load; `address` must be 16-byte aligned
template<class ElementType>
__device__ Packed128<ElementType> load128(const ElementType* address) {
    const int4 raw = *reinterpret_cast<const int4*>(address);
    return Packed128<ElementType>{raw};
}
// 128-bit load with the streaming (evict-first, __ldcs) cache hint; `address` must be 16-byte aligned
template<class ElementType>
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
    const int4 raw = __ldcs(reinterpret_cast<const int4*>(address));
    return Packed128<ElementType>{raw};
}
// 128-bit store; `target` must be 16-byte aligned
template<class ElementType>
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
    int4* dst = reinterpret_cast<int4*>(target);
    *dst = value.get_bits();
}
// 128-bit store with the streaming (__stcs) cache hint; `target` must be 16-byte aligned
template<class ElementType>
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
    int4* dst = reinterpret_cast<int4*>(target);
    __stcs(dst, value.get_bits());
}
// ----------------------------------------------------------------------------
// Random Number Generation
// Simple xorshift* RNG: three xorshift steps on the 64-bit state, then a multiplicative
// scramble whose top 32 bits are returned. Usable from host and device.
__device__ __host__ unsigned int random_u32(unsigned long long *state) {
    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
    unsigned long long s = *state;
    s ^= s >> 12;
    s ^= s << 25;
    s ^= s >> 27;
    *state = s;
    const unsigned long long scrambled = s * 0x2545F4914F6CDD1Dull;
    return (unsigned int)(scrambled >> 32);
}
// random float32 in [0,1): the top 24 random bits divided by 2^24
__device__ __host__ float random_f32(unsigned long long *state) {
    const unsigned int top24 = random_u32(state) >> 8;
    return top24 / 16777216.0f;
}
// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5)
// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU
// todo - possibly overkill and we don't need such high quality random numbers? (tbd)
// http://eiserloh.net/noise/SquirrelNoise5.hpp
// Hashes (positionX, seed) into 32 pseudo-random bits; all arithmetic is unsigned (wrapping).
__device__ __host__ constexpr unsigned int SquirrelNoise5(int positionX, unsigned int seed)
{
    constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
    constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
    constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011
    constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011
    constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101
    unsigned int mangledBits = (unsigned int) positionX;
    mangledBits *= SQ5_BIT_NOISE1;
    mangledBits += seed;
    mangledBits ^= (mangledBits >> 9);
    mangledBits += SQ5_BIT_NOISE2;
    mangledBits ^= (mangledBits >> 11);
    mangledBits *= SQ5_BIT_NOISE3;
    mangledBits ^= (mangledBits >> 13);
    mangledBits += SQ5_BIT_NOISE4;
    mangledBits ^= (mangledBits >> 15);
    mangledBits *= SQ5_BIT_NOISE5;
    mangledBits ^= (mangledBits >> 17);
    return mangledBits;
}
// 1D noise: 32 pseudo-random bits for (positionX, seed)
__device__ __host__ constexpr unsigned int Get1dNoiseUint(int positionX, unsigned int seed)
{
    return SquirrelNoise5(positionX, seed);
}
// 2D noise: folds (indexX, indexY) into one position, then hashes it.
__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed)
{
    constexpr unsigned int PRIME_NUMBER = 198491317u; // Large prime number with non-boring bits
    // Combine in unsigned arithmetic. In signed int, PRIME_NUMBER * indexY overflows as soon as
    // indexY >= 11 (e.g. blockIdx.x), which is undefined behaviour and is rejected in constant
    // evaluation. Unsigned arithmetic wraps modulo 2^32, i.e. the two's-complement wraparound
    // the signed expression relied on, but with defined semantics.
    const unsigned int position = (unsigned int)indexX + PRIME_NUMBER * (unsigned int)indexY;
    return SquirrelNoise5((int)position, seed);
}
// 1D noise mapped to a float in [0, 1] (both ends reachable)
__device__ __host__ constexpr float Get1dNoiseZeroToOne(int index, unsigned int seed)
{
    constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
    return (float)(ONE_OVER_MAX_UINT * (double) SquirrelNoise5(index, seed));
}
// 2D noise mapped to a float in [0, 1] (both ends reachable)
__device__ __host__ constexpr float Get2dNoiseZeroToOne(int indexX, int indexY, unsigned int seed)
{
    constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
    return (float)(ONE_OVER_MAX_UINT * (double) Get2dNoiseUint(indexX, indexY, seed));
}
// Stochastic rounding built on top of the Squirrel noise above (with seed updated per step via xorshift).
// float -> bf16: bf16 keeps the top 16 bits of a float. The 16 dropped bits are compared with
// 16 pseudo-random bits: if they are larger, the dropped bits are saturated to 0xFFFF so the
// round-to-nearest conversion rounds away from zero; otherwise they are cleared so the value
// is truncated exactly. Net effect: rounds away from zero with probability (dropped bits) / 65536.
// The random bits depend only on (threadIdx.x, blockIdx.x, seed).
__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) {
    // todo - is this stochastic rounding *too good*? can we cut any corners?
    const unsigned int noise16 = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed) & 0xFFFFu;
    const unsigned int bits = __float_as_uint(in);
    const unsigned int dropped16 = bits & 0x0000FFFFu;
    unsigned int biased;
    if (dropped16 > noise16) {
        biased = bits | 0xFFFFu;
    } else {
        biased = bits & ~0xFFFFu;
    }
    *out = __float2bfloat16_rn(__uint_as_float(biased));
}
// float -> fp16: not stochastic yet, plain conversion
__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) {
    *out = (float)in; // todo - implement this...
}
// float -> float: identity, so FP32 builds can share the same call sites
__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) {
    *out = in;
}
// ----------------------------------------------------------------------------
// MPI / multi-processing setup
// Parameters specific to training on multiple GPUs.
// Populated by multi_gpu_config_init() and released by multi_gpu_config_free();
// the nccl_comm member only exists in MULTI_GPU builds.
typedef struct {
int process_rank; // Rank of this process among all MPI processes. 0 if no multi-GPU.
int num_processes; // Total number of processes. 1 if no multi-GPU.
int local_device_idx; // This process GPU index on current machine. 0 if no multi-GPU.
#ifdef MULTI_GPU
ncclComm_t nccl_comm; // NCCL communication primitive, used for collective multi-GPU work.
#endif
} MultiGpuConfig;
// one global variable to hold the multi-GPU configuration for this process
// (printf0() reads its process_rank; presumably assigned from multi_gpu_config_init() at startup -- confirm in main)
MultiGpuConfig multi_gpu_config;
#ifdef MULTI_GPU
// Determine which GPU this process should use.
// Processes on the same machine must use different GPU indices; processes on other machines
// don't affect each other. Each process hashes its short hostname (text before the first '.'),
// all hashes are exchanged with MPI_Allgather, and this process's local index is the number of
// lower-ranked processes whose hash matches its own.
// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread
int multi_gpu_get_local_device_idx(int process_rank, int num_processes) {
    char hostname[1024];
    hostname[1023] = '\0'; // a truncated name is not guaranteed to be terminated by gethostname()
    // All processes on the same machine will share the same hostname.
    if (gethostname(hostname, 1023) != 0) {
        printf("[MULTI-GPU ERROR] gethostname failed on rank %d\n", process_rank);
        exit(EXIT_FAILURE);
    }
    // Keep only the short host name: cut at the first '.', never scanning past the terminator
    // (the previous loop walked all 1024 bytes, including uninitialized ones).
    for (int i = 0; hostname[i] != '\0'; i++) {
        if (hostname[i] == '.') {
            hostname[i] = '\0';
            break;
        }
    }
    // djb2 (xor variant) hash of the short host name
    uint64_t hostname_hash = 5381;
    for (int c = 0; hostname[c] != '\0'; c++) {
        hostname_hash = ((hostname_hash << 5) + hostname_hash) ^ hostname[c];
    }
    // Distribute all hostname hashes to all processes.
    uint64_t* all_hostsname_hashes = (uint64_t*)mallocCheck(num_processes * sizeof(uint64_t));
    all_hostsname_hashes[process_rank] = hostname_hash;
    mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
    // Our GPU index = number of lower-ranked processes running on the same machine.
    int local_device_idx = 0;
    for (int current_process = 0; current_process < process_rank; ++current_process) {
        if (all_hostsname_hashes[current_process] == hostname_hash) {
            local_device_idx++;
        }
    }
    free(all_hostsname_hashes);
    return local_device_idx;
}
#endif
// Set up the process's multi-GPU configuration.
// MULTI_GPU builds: initialize MPI, query rank/size, bind this process to its local GPU, and
// create an NCCL communicator from an id generated on rank 0 and broadcast over MPI.
// Other builds: print a notice and return a single-process config (rank 0 of 1, device 0).
MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
    MultiGpuConfig config;
#ifdef MULTI_GPU
    mpiCheck(MPI_Init(argc, argv));
    mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &config.process_rank));
    mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &config.num_processes));
    config.local_device_idx = multi_gpu_get_local_device_idx(config.process_rank, config.num_processes);
    cudaCheck(cudaSetDevice(config.local_device_idx));
    // every rank must join the NCCL communicator with the same unique id
    ncclUniqueId nccl_id;
    if (config.process_rank == 0) {
        ncclCheck(ncclGetUniqueId(&nccl_id));
    }
    mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD));
    ncclCheck(ncclCommInitRank(&config.nccl_comm, config.num_processes, nccl_id, config.process_rank));
#else
    printf("Multi-GPU support is disabled. Using a single GPU.\n");
    config.process_rank = 0;
    config.num_processes = 1;
    config.local_device_idx = 0;
#endif
    return config;
}
// Tear down what multi_gpu_config_init() set up: destroy the NCCL communicator, then
// finalize MPI. Nothing to release in single-GPU builds.
void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) {
#ifndef MULTI_GPU
    (void)multi_gpu_config; // unused without MULTI_GPU
#else
    ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm));
    mpiCheck(MPI_Finalize());
#endif
}
// printf() that only produces output on the rank-0 process (per the global multi_gpu_config),
// so that multi-process runs print each line once.
void printf0(const char *format, ...) {
    if (multi_gpu_config.process_rank != 0) {
        return;
    }
    va_list ap;
    va_start(ap, format);
    vprintf(format, ap);
    va_end(ap);
}
// ----------------------------------------------------------------------------
// cuDNN path
#ifdef ENABLE_CUDNN
// Forward SDPA graph plus its tensor handles, in the order built by
// lookup_cache_or_build_graph_fwd(): (graph, Q, K, V, attn_scale, O, stats).
// The handles are the keys of the variant pack that binds device pointers at execute time.
using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>>; // Stats
// Backward SDPA graph plus its tensor handles, in the order built by
// lookup_cache_or_build_graph_bwd(): (graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV).
using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
std::shared_ptr<fe::graph::Tensor_attributes>>; // dV
// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
// Both caches are keyed by graph->key() of an un-built graph describing the same problem.
using cache_type_fwd = std::unordered_map<std::size_t, graph_tensors_fwd>;
using cache_type_bwd = std::unordered_map<std::size_t, graph_tensors_bwd>;
// Build the cuDNN SDPA (flash attention) forward graph for one problem shape, or return the
// cached copy. Loosely based on cuDNN frontend samples functions and massively simplified.
// Args: (B, H, T, HS, is_inference_only).
// Returns (graph, Q, K, V, attn_scale, O, stats); no stats output is requested when
// is_inference_only is true.
// Built graphs live in a function-local static map keyed by graph->key(), because building
// is the slow part. The map has no locking, so this is presumably only called from a single
// host thread -- TODO confirm.
// Every frontend call is checked explicitly instead of inside assert(): the build steps have
// side effects that must happen, and assert() removes them entirely when NDEBUG is defined.
template <typename... Args>
auto lookup_cache_or_build_graph_fwd(Args... args) {
    static cache_type_fwd user_maintained_cache_fwd;
    // abort with a message if a cuDNN frontend call reports failure
    auto fe_check = [](auto status, const char* what) {
        if (!status.is_good()) {
            printf("[cuDNN ERROR] lookup_cache_or_build_graph_fwd: %s failed\n", what);
            exit(EXIT_FAILURE);
        }
    };
    auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...);
    auto graph = std::make_shared<fe::graph::Graph>();
    graph->set_io_data_type(CUDNN_16BIT)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);
    // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute:
    // Q, K and V share these strides; only their base pointers differ at execute time
    auto Q = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("Q")
                               .set_dim({B, H, T, HS})
                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
    auto K = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("K")
                               .set_dim({B, H, T, HS})
                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
    auto V = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("V")
                               .set_dim({B, H, T, HS})
                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
    // softmax scale, passed by value from the host at execute time
    auto attn_scale = graph->tensor(fe::graph::Tensor_attributes()
                                        .set_name("attn_scale")
                                        .set_dim({1, 1, 1, 1})
                                        .set_stride({1, 1, 1, 1})
                                        .set_is_pass_by_value(true)
                                        .set_data_type(fe::DataType_t::FLOAT));
    auto sdpa_options = fe::graph::SDPA_attributes().set_name("flash_attention");
    sdpa_options.set_is_inference(is_inference_only);
    sdpa_options.set_attn_scale(attn_scale);
    sdpa_options.set_causal_mask(true);
    // Create the graph operation and get the output tensors back
    auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);
    // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
    O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1});
    assert(stats == nullptr || is_inference_only == false);
    if (is_inference_only == false) {
        stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
            .set_dim({B, H, T, 1})
            .set_stride({H * T, T, 1, 1});
    }
    fe_check(graph->validate(), "graph->validate()");
    auto key = graph->key();
    auto it = user_maintained_cache_fwd.find(key);
    if (it != user_maintained_cache_fwd.end()) {
        return it->second;
    }
    // Build the operation graph and execution part (this is the VERY SLOW PART)
    fe_check(graph->build_operation_graph(cudnn_handle), "graph->build_operation_graph()");
    fe_check(graph->create_execution_plans({fe::HeurMode_t::A}), "graph->create_execution_plans()");
    fe_check(graph->check_support(cudnn_handle), "graph->check_support()");
    fe_check(graph->build_plans(cudnn_handle), "graph->build_plans()");
    auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats);
    user_maintained_cache_fwd.insert({key, tuple});
    return tuple;
}
// Build the cuDNN SDPA (flash attention) backward graph for one problem shape, or return the
// cached copy. Args: (B, NH, T, HS).
// Returns (graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV).
// Q/K/V and dQ/dK/dV use the packed (B, T, 3, NH, HS) layout; O and dO are (B, T, NH, HS);
// stats is (B, NH, T) fp32.
// Built graphs live in a function-local static map keyed by graph->key(); the map has no
// locking, so this is presumably only called from a single host thread -- TODO confirm.
// Every frontend call is checked explicitly instead of inside assert(), because assert()
// removes the (side-effecting) build steps entirely when NDEBUG is defined.
template <typename... Args>
auto lookup_cache_or_build_graph_bwd(Args... args) {
    static cache_type_bwd user_maintained_cache_bwd;
    // abort with a message if a cuDNN frontend call reports failure
    auto fe_check = [](auto status, const char* what) {
        if (!status.is_good()) {
            printf("[cuDNN ERROR] lookup_cache_or_build_graph_bwd: %s failed\n", what);
            exit(EXIT_FAILURE);
        }
    };
    auto [B, NH, T, HS] = std::make_tuple(args...);
    auto graph = std::make_shared<fe::graph::Graph>();
    graph->set_io_data_type(CUDNN_16BIT)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);
    // (B, N, 3, NH, HS)
    // must come from inp (which means we also need to convert THAT to FP16)
    auto Q = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("Q")
                               .set_dim({B, NH, T, HS})
                               .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
    auto K = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("K")
                               .set_dim({B, NH, T, HS})
                               .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
    auto V = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("V")
                               .set_dim({B, NH, T, HS})
                               .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
    auto O = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("O")
                               .set_dim({B, NH, T, HS})
                               .set_stride({NH * HS * T, HS, NH * HS, 1}));
    auto dO = graph->tensor(fe::graph::Tensor_attributes()
                                .set_name("dO")
                                .set_dim({B, NH, T, HS})
                                .set_stride({NH * HS * T, HS, NH * HS, 1}));
    auto stats = graph->tensor(fe::graph::Tensor_attributes()
                                   .set_name("stats")
                                   .set_dim({B, NH, T, 1})
                                   .set_stride({NH * T, T, 1, 1})
                                   .set_data_type(fe::DataType_t::FLOAT));
    // softmax scale, passed by value from the host at execute time
    auto attn_scale = graph->tensor(fe::graph::Tensor_attributes()
                                        .set_name("attn_scale")
                                        .set_dim({1, 1, 1, 1})
                                        .set_stride({1, 1, 1, 1})
                                        .set_is_pass_by_value(true)
                                        .set_data_type(fe::DataType_t::FLOAT));
    auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
                                     .set_name("flash_attention_backward")
                                     .set_causal_mask(true)
                                     .set_attn_scale(attn_scale);
    // Create the graph operation and get the output tensors back
    auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);
    // gradients are written in the same packed (B, T, 3, NH, HS) layout as the QKV input
    dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
    dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
    dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
    fe_check(graph->validate(), "graph->validate()");
    auto key = graph->key();
    auto it = user_maintained_cache_bwd.find(key);
    if (it != user_maintained_cache_bwd.end()) {
        return it->second;
    }
    // Build the operation graph and execution part (this is the VERY SLOW PART)
    fe_check(graph->build_operation_graph(cudnn_handle), "graph->build_operation_graph()");
    fe_check(graph->create_execution_plans({fe::HeurMode_t::A}), "graph->create_execution_plans()");
    fe_check(graph->check_support(cudnn_handle), "graph->check_support()");
    fe_check(graph->build_plans(cudnn_handle), "graph->build_plans()");
    auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
    user_maintained_cache_bwd.insert({key, tuple});
    return tuple;
}
// Causal flash-attention forward pass through cuDNN.
//   out:   output, (B, T, NH, HS)
//   stats: output softmax statistics for the backward pass, (B, NH, T) fp32;
//          pass nullptr to run inference only (no statistics are produced)
//   inp:   input, (B, T, 3, NH, HS) packed QKV
// The shared cuDNN workspace is grown (never shrunk) to what the graph requires.
void attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)
                             float* stats, // output for backward pass: (B, NH, T)
                             floatX* inp,  // input: (B, T, 3, NH, HS) QKV
                             int B, int T, int NH, int C) {
    int HS = C / NH; // number of features per head
    bool is_inference_only = (stats == nullptr);
    // Get graph and tensors from cache (or generate it on first use)
    auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
        lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
    // Q, K and V are interleaved in inp: K starts C elements after Q and V 2*C after Q
    // (C == NH * HS, assuming C is a multiple of NH); the graph strides skip the other two
    void* devPtrQ = inp;
    void* devPtrK = (inp + C);
    void* devPtrV = (inp + 2 * C);
    float attn_scale_cpu = 1.0f / sqrtf(HS); // softmax scale 1/sqrt(HS)
    void* devPtrO = out;
    // Build variant pack
    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}};
    // Add the stats tensor unless we are only doing inference (only needed for backward pass)
    if (is_inference_only == false) {
        variant_pack[softmax_stats] = stats;
    }
    // Reallocate the workspace if the required size is greater than the current workspace
    // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum
    if (graph->get_workspace_size() > cudnn_workspace_size) {
        if (cudnn_workspace_size > 0) {
            cudaCheck(cudaFree(cudnn_workspace));
        }
        cudnn_workspace_size = graph->get_workspace_size();
        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
    }
    // Execute graph. Checked explicitly: inside assert() the call itself vanished under NDEBUG.
    auto exec_status = graph->execute(cudnn_handle, variant_pack, cudnn_workspace);
    if (!exec_status.is_good()) {
        printf("[cuDNN ERROR] attention_forward_cudnn: graph->execute() failed\n");
        exit(EXIT_FAILURE);
    }
    cudaCheck(cudaGetLastError());
}
// Causal flash-attention backward pass through cuDNN.
//   dqkvr: output gradient w.r.t. the packed QKV input, (B, T, 3, NH, HS)
//   dout:  gradient w.r.t. the attention output, (B, T, NH, HS)
//   qkvr:  packed QKV input of the forward pass, (B, T, 3, NH, HS)
//   o:     attention output of the forward pass, (B, T, NH, HS)
//   stats: softmax statistics saved by the forward pass, (B, NH, T) fp32
// The shared cuDNN workspace is grown (never shrunk) to what the graph requires.
void attention_backward_cudnn(floatX* dqkvr,                                       // output
                              floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
                              int B, int T, int NH, int C) {
    int HS = C / NH; // number of features per head
    // Get graph and tensors from cache (or generate it on first use)
    auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
        lookup_cache_or_build_graph_bwd(B, NH, T, HS);
    // Q/K/V (and dQ/dK/dV) are interleaved: each starts NH*HS elements after the previous one
    void* devPtrQ = qkvr;
    void* devPtrK = (qkvr + NH * HS);
    void* devPtrV = (qkvr + 2 * NH * HS);
    void* devPtrO = o;
    void* devPtrdO = dout;
    void* devPtrStats = stats;
    float attn_scale_cpu = 1.0f / sqrtf(HS); // softmax scale 1/sqrt(HS)
    void* devPtrdQ = dqkvr;
    void* devPtrdK = (dqkvr + NH * HS);
    void* devPtrdV = (dqkvr + 2 * NH * HS);
    // Build variant pack that links each tensor to its data pointer
    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {O, devPtrO}, {dO, devPtrdO}, {Stats, devPtrStats},
        {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV},
        {attn_scale, &attn_scale_cpu}};
    // Reallocate the workspace if the required size is greater than the current workspace
    // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum
    if (graph->get_workspace_size() > cudnn_workspace_size) {
        if (cudnn_workspace_size > 0) {
            cudaCheck(cudaFree(cudnn_workspace));
        }
        cudnn_workspace_size = graph->get_workspace_size();
        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
    }
    // Execute graph. Checked explicitly: inside assert() the call itself vanished under NDEBUG.
    auto exec_status = graph->execute(cudnn_handle, variant_pack, cudnn_workspace);
    if (!exec_status.is_good()) {
        printf("[cuDNN ERROR] attention_backward_cudnn: graph->execute() failed\n");
        exit(EXIT_FAILURE);
    }
    cudaCheck(cudaGetLastError());
}
#endif // ENABLE_CUDNN
// ----------------------------------------------------------------------------
// all the kernels
// Token + position embedding: out[b, t, :] = wte[inp[b, t], :] + wpe[t, :].
// One thread per output element (B*T*C of them, 1D launch); threads past the end exit.
// The sum is formed in fp32 and cast back to floatX.
__global__ void encoder_forward_kernel2(floatX* out,
                                        int* inp, floatX* wte, floatX* wpe,
                                        int B, int T, int C) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * T * C) {
        return;
    }
    const int bt = idx / C; // flattened (b, t) position
    const int t = bt % T;
    const int c = idx % C;
    const int token = inp[bt];
    const float tok_emb = (float)wte[token * C + c];
    const float pos_emb = (float)wpe[t * C + c];
    out[idx] = (floatX)(tok_emb + pos_emb);
}
// Embedding backward, naive (and slow) version: one thread per element of dout (B*T*C, 1D
// launch) scatters its gradient into dwte[token] and dwpe[t] with atomicAddX, so contributions
// accumulate (+=) into both gradient tables.
__global__ void encoder_backward_kernel(floatX* dwte, floatX* dwpe,
                                        const floatX* dout, const int* inp,
                                        int B, int T, int C) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * T * C) {
        return;
    }
    const int bt = idx / C; // flattened (b, t) position
    const int t = bt % T;
    const int c = idx % C;
    const int token = inp[bt];
    const floatX grad = dout[idx];
    atomicAddX(dwte + token * C + c, grad);
    atomicAddX(dwpe + t * C + c, grad);
}
// LayerNorm forward, one 32-thread warp per row of C channels (N rows in total).
//   out = (inp - mean) * rstd * weight + bias,  rstd = 1 / sqrt(var + 1e-5)
// mean and (biased, divide-by-C) variance are computed in fp32. mean / rstd are written
// (one value per row) only when the pointer is non-null.
// Row index = blockIdx.x * (warps per block) + warp index within the block; extra warps exit.
__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
                                          const floatX* __restrict__ inp, const floatX* __restrict__ weight,
                                          const floatX* __restrict__ bias, int N, int C) {
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    const int row = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if (row >= N) {
        return;
    }
    const int lane = warp.thread_rank();
    const int stride = warp.size();
    const floatX* x = inp + row * C; // the input row owned by this warp

    // pass 1: mean
    float acc = 0.0f;
    for (int i = lane; i < C; i += stride) {
        acc += (float)x[i];
    }
    acc = cg::reduce(warp, acc, cg::plus<float>{});
    const float m = acc / C;
    if (lane == 0 && mean != nullptr) {
        __stcs(mean + row, (floatX)m);
    }

    // pass 2: reciprocal standard deviation
    acc = 0.0f;
    for (int i = lane; i < C; i += stride) {
        float diff = (float)x[i] - m;
        acc += diff * diff;
    }
    acc = cg::reduce(warp, acc, cg::plus<float>{});
    const float s = rsqrtf(acc / C + 1e-5f);
    if (lane == 0 && rstd != nullptr) {
        __stcs(rstd + row, (floatX)s);
    }

    // pass 3: normalize, scale and shift. x/out go through the streaming (.cs) cache hint
    // since they are not reused soon, leaving more cache for the shared weight and bias.
    floatX* o = out + row * C;
    for (int c = lane; c < C; c += stride) {
        float n = s * ((float)__ldcs(x + c) - m);
        __stcs(o + c, (floatX)(n * (float)weight[c] + (float)bias[c]));
    }
}
// Splits the fused QKV activation into separate Q, K and V tensors (all floatX).
// inp is read as (B, N, 3, NH, d); q, k and v are each written as (B, NH, N, d).
// One thread per element of q in a 1D launch of at least B*NH*N*d threads; surplus threads
// do nothing. Source reads use the streaming __ldcs hint.
__global__ void permute_kernel(floatX* q, floatX* k, floatX* v,
                               const floatX* inp,
                               int B, int N, int NH, int d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * N * d) {
        return;
    }
    // idx = ((b * NH + head) * N + n) * d + j, unpacked from the innermost dimension outwards
    const int j = idx % d;
    const int n = (idx / d) % N;
    const int head = (idx / (d * N)) % NH;
    const int b = idx / (d * N * NH);
    // offset of inp[b][n][0][head][j]; the K and V slots follow at +NH*d and +2*NH*d
    const int q_src = ((b * N + n) * 3 * NH + head) * d + j;
    const int qkv_stride = NH * d;
    q[idx] = __ldcs(inp + q_src);
    k[idx] = __ldcs(inp + q_src + qkv_stride);
    v[idx] = __ldcs(inp + q_src + 2 * qkv_stride);
}
// Inverse of the QKV split: scatters dq, dk and dv, each laid out (B, NH, N, d), back into the
// fused (B, N, 3, NH, d) layout of dinp. One thread per element of dq; surplus threads do nothing.
__global__ void permute_kernel_backward(floatX* dinp,
                                        const floatX* dq, const floatX* dk, const floatX* dv,
                                        int B, int N, int NH, int d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * N * d) {
        return;
    }
    // idx = ((b * NH + head) * N + n) * d + j
    const int j = idx % d;
    const int n = (idx / d) % N;
    const int head = (idx / (d * N)) % NH;
    const int b = idx / (d * N * NH);
    const int qkv_stride = NH * d;
    // dst points at dinp[b][n][0][head][j]; the K and V slots follow at +NH*d and +2*NH*d
    floatX* dst = dinp + ((b * N + n) * 3 * NH + head) * d + j;
    dst[0] = dq[idx];
    dst[qkv_stride] = dk[idx];
    dst[2 * qkv_stride] = dv[idx];
}
// Transposes the attention output from (B, NH, N, d) back to (B, N, NH, d).
// Note the argument order: `inp` is the (B, NH, N, d) source and `out` is the destination.
// One thread per element (surplus threads do nothing); reads use the streaming __ldcs hint.
__global__ void unpermute_kernel(floatX* inp, floatX *out, int B, int N, int NH, int d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * N * d) {
        return;
    }
    // idx = ((b * NH + head) * N + n) * d + j
    const int j = idx % d;
    const int n = (idx / d) % N;
    const int head = (idx / (d * N)) % NH;
    const int b = idx / (d * N * NH);
    out[((b * N + n) * NH + head) * d + j] = __ldcs(inp + idx);
}
// Gradient of the (B, NH, N, d) -> (B, N, NH, d) transpose: gathers dout, laid out
// (B, N, NH, d), into dinp, laid out (B, NH, N, d). One thread per element of dinp.
__global__ void unpermute_kernel_backward(floatX* dinp, const floatX *dout, int B, int N, int NH, int d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * N * d) {
        return;
    }
    // idx = ((b * NH + head) * N + n) * d + j
    const int j = idx % d;
    const int n = (idx / d) % N;
    const int head = (idx / (d * N)) % NH;
    const int b = idx / (d * N * NH);
    dinp[idx] = dout[((b * N + n) * NH + head) * d + j];
}
// Causal (autoregressive) softmax with the attention scale fused in. For each row:
//   out[j] = exp(inv_temperature * (inp[j] - max)) / sum_k exp(inv_temperature * (inp[k] - max))
// over columns j, k in [0, own_pos].
// inp, out shape: (N, T, T), where N = B * NH, viewed as N * T rows of length T.
// One 32-thread warp handles one row. Row idx sits at query position own_pos = idx % T, and
// only columns 0..own_pos are read or written; the upper triangle of `out` is not touched.
// Uses the online softmax algorithm: each lane keeps a running max and a running sum that is
// rescaled whenever the max grows. The per-lane partials are then merged with warp reductions.
// Precondition: T % 4 == 0 (asserted).
__global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, const floatX* inp, int N, int T) {
    assert(T % 4 == 0);
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // micro-optimization: we iterate backwards so that after this softmax completes, the cache
    // retains the part of the matrix close to the upper left corner, which benefits the
    // matmul operation that immediately follows.
    // int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); // forward order
    int idx = (gridDim.x - blockIdx.x - 1) * warp.meta_group_size() + warp.meta_group_rank(); // backward order
    if(idx >= N * T) {
        return;
    }
    int own_pos = idx % T;
    int pos_by_4 = own_pos / 4;
    // The row offset is computed in size_t: idx * T can reach N * T * T, which overflows int
    // for large B * NH * T^2 even though idx itself fits in an int.
    const size_t row_offset = (size_t)idx * T;
    // one row of inp, i.e. inp[idx, :] of shape (T,), and the matching row of out
    const floatX* x = inp + row_offset;
    floatX* out_row = out + row_offset;
    // start from -FLT_MAX rather than -INFINITY, so subtracting two untouched maxima gives 0, not NaN
    float maxval = -FLT_MAX;
    float sumval = 0.0f;
    // T % 4 == 0 means every row starts on a 4-element boundary, i.e. 4 * sizeof(floatX) bytes
    // (given that inp itself is at least that aligned). Promising more than that, e.g. 16 bytes
    // for a 2-byte floatX, would be a false alignment assumption whenever T % 8 != 0.
    const floatX* x_aligned = reinterpret_cast<const floatX*>(__builtin_assume_aligned(x, 4 * sizeof(floatX)));
    // main loop: each lane consumes groups of 4 columns, strided across the warp
    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
        float regarray[4];
        #pragma unroll
        for (int k = 0; k < 4; ++k) {
            regarray[k] = (float)x_aligned[4*i + k];
        }
        float old_maxval = maxval;
        for(int k = 0; k < 4; ++k) {
            maxval = fmaxf(maxval, regarray[k]);
        }
        // rescale the running sum to the new maximum before adding the new terms
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        for(int k = 0; k < 4; ++k) {
            sumval += expf(inv_temperature * (regarray[k] - maxval));
        }
    }
    // tail: the final 1..4 columns (4*pos_by_4 .. own_pos), one per lane among lanes 0..3
    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, (float)x[4*pos_by_4 + warp.thread_rank()]);
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        sumval += expf(inv_temperature * ((float)x[4*pos_by_4 + warp.thread_rank()] - maxval));
    }
    // merge lanes: bring every lane's sum to the common (row) maximum, then add them up
    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
    sumval *= expf(inv_temperature * (maxval - global_maxval));
    float sum = cg::reduce(warp, sumval, cg::plus<float>{});
    float norm = 1.f / sum;
    // divide the whole row by the sum
    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
        // recalculation is faster than doing the round-trip through memory.
        float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval));
        __stcs(out_row + i, (floatX)(ev * norm));
    }
}
// Elementwise residual add, summed in float: out[i] = inp1[i] + inp2[i] for i < N.
// One thread per element; both inputs are read with the streaming __ldcs hint.
__global__ void residual_forward_kernel(floatX* out, floatX* inp1, floatX* inp2, int N) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) {
        return;
    }
    const float lhs = (float)__ldcs(inp1 + i);
    const float rhs = (float)__ldcs(inp2 + i);
    out[i] = (floatX)(lhs + rhs);
}
// sqrt(2 / pi): the scale applied to the tanh argument in the tanh approximation of GELU.
// Used by gelu_forward_kernel2 and gelu_backward_kernel.
// NOTE: M_PI is a double, so 2.0f / M_PI is divided in double precision and then converted
// to float when passed to sqrtf.
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
// GeLU forward, tanh approximation:
//   out[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),  x = inp[i]
// Each thread owns x128::size consecutive elements, starting at
// (blockIdx.x * blockDim.x + threadIdx.x) * x128::size, and processes them with packed
// load128cs/store128 accesses. x128 is defined elsewhere; presumably a 128-bit pack, so inp and
// out should be 16-byte aligned — confirm.
// If N is not a multiple of x128::size, the thread owning the final partial pack falls back to
// scalar accesses. Previously the `i < N` guard only checked the first element of the pack, so
// that thread loaded and stored past the end of the buffers.
__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) {
    // scalar GeLU in float; shared by the packed path and the tail path
    auto gelu = [](float xi) {
        float cube = 0.044715f * xi * xi * xi;
        return 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
    };
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    if (i >= N) {
        return;
    }
    if (N - i < (int)x128::size) {
        // tail: fewer than x128::size elements remain, so use element-wise accesses
        for (int k = i; k < N; ++k) {
            out[k] = (floatX)gelu((float)inp[k]);
        }
        return;
    }
    x128 packed_out;
    x128 packed_inp = load128cs(inp + i); // load and do not keep in cache
    for(int k = 0; k < packed_inp.size; ++k) {
        packed_out[k] = (floatX)gelu((float)packed_inp[k]);
    }
    // store instead of storecs (without cache streaming) in case it is useful for the
    // data to be in the cache for the next operation after this GeLU
    store128(out + i, packed_out);
}
// GeLU backward for the tanh approximation, one thread per element (i < N):
//   dinp[i] = dout[i] * gelu'(x),  x = inp[i]
//   gelu'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
// where u = sqrt(2/pi) * (x + 0.044715 * x^3).
__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) {
        return;
    }
    const float x = (float)inp[i];
    const float cubic = 0.044715f * x * x * x;
    const float u = GELU_SCALING_FACTOR * (x + cubic);   // tanh argument
    const float th = tanhf(u);
    const float ch = coshf(u);
    const float sech_sq = 1.0f / (ch * ch);              // sech^2(u)
    const float grad = 0.5f * (1.0f + th) + x * 0.5f * sech_sq * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
    dinp[i] = (floatX)(grad * (float)dout[i]);
}
// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
// dbias = dout.sum((0,1))
// the idea is to employ one block to reduce along several columns,
// where each block has a width of 32 columns to ensure coalesced access.
// at the end we accumulate the reductions performed by the warps in each block via shared memory
__global__ void matmul_backward_bias_kernel4(floatX* dbias, const floatX* dout, int B, int T, int OC) {
// this kernel is launched with 1D grid_dim of OC/32
// for example let's say block_size is 128
extern __shared__ float smem[]; // of size block_size (128)
const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3
const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31
const int tl = blockIdx.x * warpSize; // pointer to the start column for this block
const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4
// pointer to the start of the column for one lane of threads
// so e.g. 4 threads (of the same lane_id) will reduce this one column
const floatX* dout_col = dout + tl + lane_id;
// column reductions by looping through the rows
// each of the 4 threads offsets by its warp_id and then skips by vstep
// together these 4 threads cover all B*T rows of this (lane_id) column
// importantly, consecutive threads (in threadId) are processing adjacent columns,
// leading to a coalesced memory access pattern
float dout_sum = 0.0f;
for (int row = warp_id; row < B * T; row += vstep) {
dout_sum += (float)dout_col[row * OC];