-
Notifications
You must be signed in to change notification settings - Fork 677
/
Copy pathkernels.cu
3909 lines (3346 loc) · 156 KB
/
kernels.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <kernels.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <mma.h>
// Largest finite half-precision value.
#define HLF_MAX 65504
// Threads per block for kQuantize / kOptimizer32bit2State (used in their
// __launch_bounds__ and cub BlockLoad/BlockStore instantiations).
#define TH 1024
// Items per thread in kQuantize (TH*NUM == NUM_BLOCK).
#define NUM 4
// Items handled by one block per loop iteration in kQuantize.
#define NUM_BLOCK 4096
// NF4 code book: 16 values in [-1, 1], sorted ascending and indexed by the
// 4-bit code; the entries equal the constants returned by
// dDequantizeNF4 / dhDequantizeNF4 for codes 0b0000 .. 0b1111.
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
// Atomic max for float built on a 32-bit compare-and-swap loop.
// The float slot is viewed as an int. The CAS is retried until no other thread
// changed the slot between our read and our write.
// Returns the value the slot held before the successful swap.
__device__ float atomicMax(float* address, float val) {
  int* const slot = reinterpret_cast<int*>(address);
  int observed = *slot;
  int expected;
  do {
    expected = observed;
    const float merged = fmaxf(val, __int_as_float(expected));
    observed = atomicCAS(slot, expected, __float_as_int(merged));
  } while (observed != expected);
  return __int_as_float(observed);
}
// Atomic min for float built on a 32-bit compare-and-swap loop (the mirror
// image of atomicMax above). Returns the value held before the successful swap.
__device__ float atomicMin(float* address, float val) {
  int* const slot = reinterpret_cast<int*>(address);
  int observed = *slot;
  int expected;
  do {
    expected = observed;
    const float merged = fminf(val, __int_as_float(expected));
    observed = atomicCAS(slot, expected, __float_as_int(merged));
  } while (observed != expected);
  return __int_as_float(observed);
}
__device__ float dDequantizeFP4(unsigned char val, float absmax)
{
  // Decode a 4-bit FP4 code and scale it by absmax.
  // Bit 3 is the sign. Bits 2..0 select the raw magnitude from
  // {0, 0.0625, 2, 3, 4, 6, 8, 12}; the value is not divided by 12.
  const float sign = (val & 0b1000) ? -1.0f : 1.0f;
  const bool low_bit = (val & 0b0001) != 0;
  if ((val & 0b0110) != 0)
  {
    // normal: base magnitude from bits 2..1, times 1.5 when bit 0 is set
    const float base = ((val & 0b0100) ? 2.0f : 8.0f) + ((val & 0b0010) ? 0.0f : 2.0f);
    const float mantissa = low_bit ? 1.5f : 1.0f;
    return sign*base*mantissa*absmax;
  }
  // subnormal: only 0 and +/-0.0625 exist
  if (!low_bit)
    return 0.0f;
  return sign*0.0625f*absmax;
}
__device__ float d2DequantizeFP4(unsigned char val)
{
  // Decode a 4-bit FP4 code to its raw (unscaled) value.
  // Bit 3 is the sign. Bits 2..0 select a magnitude from
  // {0, 0.0625, 2, 3, 4, 6, 8, 12}.
  const float sign = (val & 0b1000) ? -1.0f : 1.0f;
  const bool low_bit = (val & 0b0001) != 0;
  if ((val & 0b0110) == 0)
    return low_bit ? sign*0.0625f : 0.0f;  // subnormal
  // normal: base magnitude from bits 2..1, times 1.5 when bit 0 is set
  const float base = ((val & 0b0100) ? 2.0f : 8.0f) + ((val & 0b0010) ? 0.0f : 2.0f);
  return sign*base*(low_bit ? 1.5f : 1.0f);
}
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
  // Decode a 4-bit FP4 code to its normalized value (raw FP4 magnitude / 12,
  // so the largest code maps to 1.0), multiplied by absmax.
  // Bit 3 is the sign; bits 2..0 pick the magnitude.
  const float sign = (val & 0b1000) ? -1.0f : 1.0f;
  float magnitude;
  switch (val & 0b0111)
  {
    case 0b111: magnitude = 0.25000000f; break;      // 3/12
    case 0b110: magnitude = 0.16666667f; break;      // 2/12
    case 0b101: magnitude = 0.50000000f; break;      // 6/12
    case 0b100: magnitude = 0.33333333f; break;      // 4/12
    case 0b011: magnitude = 1.00000000f; break;      // 12/12
    case 0b010: magnitude = 0.66666667f; break;      // 8/12
    case 0b001: magnitude = 5.208333333e-03f; break; // 0.0625/12
    default:    magnitude = 0.00000000f; break;      // 0b000 -> 0
  }
  return magnitude*absmax*sign;
}
__device__ unsigned char dQuantizeFP4(float x)
{
  // FP4 with bias of 3: bit 3 is the sign, bits 2..0 encode the magnitude
  //   0b000 = 0      0b001 = 0.0625 (subnormals)
  //   0b110 = 2      0b111 = 3
  //   0b100 = 4      0b101 = 6
  //   0b010 = 8      0b011 = 12
  // Input is assumed to lie in [-1.0, 1.0], so the search works on the
  // magnitudes divided by 12 (the FP4 absmax). Every threshold below is the
  // midpoint between two adjacent normalized magnitudes.
  // Be careful editing these constants: a stray extra zero is easy to miss.
  const int sign = x < 0 ? 0b1000 : 0b0000;
  const float mag = fabsf(x);
  int code;
  if (mag > 0.29166667f)
  {
    // upper half: {4, 6, 8, 12} / 12
    if (mag > 0.583333f)
      code = (mag > 0.8333333f) ? 0b0011 : 0b0010;
    else
      code = (mag > 0.4166667f) ? 0b101 : 0b100;
  }
  else if (mag > 0.0859375f)
  {
    // {2, 3} / 12
    code = (mag > 0.20833333f) ? 0b0111 : 0b0110;
  }
  else
  {
    // {0, 0.0625} / 12
    code = (mag > 0.00260417f) ? 0b0001 : 0b0000;
  }
  return code + sign;
}
__device__ half dhDequantizeNF4(unsigned char val)
{
  // Half-precision variant of the NF4 decoder. It maps the low 4 bits of val
  // to its NF4 value in [-1, 1]; a larger code means a larger value, and
  // bits above 0x0F are ignored.
  // The values were generated by test_normal_map_tree
  // in the file tests/test_functional.py
  switch (val & 0x0F)
  {
    case 0b1111: return 1.0f;
    case 0b1110: return 0.7229568362236023f;
    case 0b1101: return 0.5626170039176941f;
    case 0b1100: return 0.44070982933044434f;
    case 0b1011: return 0.33791524171829224f;
    case 0b1010: return 0.24611230194568634f;
    case 0b1001: return 0.16093020141124725f;
    case 0b1000: return 0.07958029955625534f;
    case 0b0111: return 0.0f;
    case 0b0110: return -0.09105003625154495f;
    case 0b0101: return -0.18477343022823334f;
    case 0b0100: return -0.28444138169288635f;
    case 0b0011: return -0.39491748809814453f;
    case 0b0010: return -0.5250730514526367f;
    case 0b0001: return -0.6961928009986877f;
    default:     return -1.0f; // 0b0000
  }
}
__device__ float dDequantizeNF4(unsigned char val)
{
  // Maps the low 4 bits of val to its NF4 value in [-1, 1]. A larger code
  // means a larger value, and bits above 0x0F are ignored.
  // The values were generated by test_normal_map_tree
  // in the file tests/test_functional.py
  switch (val & 0x0F)
  {
    case 0b1111: return 1.0f;
    case 0b1110: return 0.7229568362236023f;
    case 0b1101: return 0.5626170039176941f;
    case 0b1100: return 0.44070982933044434f;
    case 0b1011: return 0.33791524171829224f;
    case 0b1010: return 0.24611230194568634f;
    case 0b1001: return 0.16093020141124725f;
    case 0b1000: return 0.07958029955625534f;
    case 0b0111: return 0.0f;
    case 0b0110: return -0.09105003625154495f;
    case 0b0101: return -0.18477343022823334f;
    case 0b0100: return -0.28444138169288635f;
    case 0b0011: return -0.39491748809814453f;
    case 0b0010: return -0.5250730514526367f;
    case 0b0001: return -0.6961928009986877f;
    default:     return -1.0f; // 0b0000
  }
}
__device__ unsigned char dQuantizeNF4(float x)
{
  // Encodes x as the 4-bit code of the nearest NF4 value. The search makes
  // four comparisons, and each threshold is the midpoint between two
  // neighbouring NF4 values.
  // NaN fails every comparison and therefore encodes as 0b0000.
  // The values were generated by test_normal_map_tree
  // in the file tests/test_functional.py
  if (x > 0.03979014977812767f)
  {
    // codes 0b1xxx
    if (x > 0.3893125355243683f)
    {
      if (x > 0.6427869200706482f)
        return (x > 0.8614784181118011f) ? 0b1111 : 0b1110;
      return (x > 0.5016634166240692f) ? 0b1101 : 0b1100;
    }
    if (x > 0.2035212516784668f)
      return (x > 0.2920137718319893f) ? 0b1011 : 0b1010;
    return (x > 0.1202552504837513f) ? 0b1001 : 0b1000;
  }
  // codes 0b0xxx
  if (x > -0.33967943489551544f)
  {
    if (x > -0.13791173323988914f)
      return (x > -0.045525018125772476f) ? 0b0111 : 0b0110;
    return (x > -0.23460740596055984f) ? 0b0101 : 0b0100;
  }
  if (x > -0.6106329262256622f)
    return (x > -0.4599952697753906f) ? 0b0011 : 0b0010;
  return (x > -0.8480964004993439f) ? 0b0001 : 0b0000;
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
// Returns +1 for val > 0, -1 for val < 0, and 0 otherwise (including NaN for
// floating-point T, since both comparisons are false).
template <typename T> __device__ int sgn(T val)
{
  const T zero = T(0);
  const int is_positive = zero < val;
  const int is_negative = val < zero;
  return is_positive - is_negative;
}
// Quantize x to an index (0..255) into the 256-entry code book smem_code,
// using a binary search over the table.
// The search assumes smem_code is sorted ascending (NOTE(review): not checked
// here; confirm with callers).
// STOCHASTIC == 0: return the closer of the two entries bracketing x
//                  (midpoint test).
// STOCHASTIC != 0: choose between the bracketing entries using `rand`. x is
//                  rounded away from `pivot` when rand >= (distance to the far
//                  entry) / (gap between entries). `rand` is presumably uniform
//                  in [0, 1); verify at call sites.
// rand is ignored when STOCHASTIC == 0.
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
  int pivot = 127;
  int upper_pivot = 255;
  int lower_pivot = 0;
  // placeholders; replaced by the real end values of the table below if the
  // search never moves the corresponding bound
  float lower = -1.0f;
  float upper = 1.0f;
  float val = smem_code[pivot];
  // i takes the values {64, 32, 16, 8, 4, 2, 1}; after the loop, pivot is the
  // last probed index and lower_pivot / upper_pivot are the probed indices
  // bracketing it (still 0 / 255 if the search ran off that end)
  for(int i = 64; i > 0; i>>=1)
  {
    if(x > val)
    {
      lower_pivot = pivot;
      lower = val;
      pivot+=i;
    }
    else
    {
      upper_pivot = pivot;
      upper = val;
      pivot-=i;
    }
    val = smem_code[pivot];
  }
  // index 255 is never probed by the loop, so fetch its value explicitly
  if(upper_pivot == 255)
    upper = smem_code[upper_pivot];
  if(lower_pivot == 0)
    lower = smem_code[lower_pivot];
  if(!STOCHASTIC)
  {
    // deterministic rounding to the nearer neighbour
    if(x > val)
    {
      float midpoint = (upper+val)*0.5f;
      if(x > midpoint)
      {
        return upper_pivot;
      }
      else
        return pivot;
    }
    else
    {
      float midpoint = (lower+val)*0.5f;
      if(x < midpoint)
        return lower_pivot;
      else
        return pivot;
    }
  }
  else
  {
    // stochastic rounding; NOTE(review): dist_full is 0 when two adjacent code
    // entries are equal, and the comparison then falls through to `pivot`
    if(x > val)
    {
      float dist_to_upper = fabsf(upper-x);
      float dist_full = upper-val;
      if(rand >= dist_to_upper/dist_full) return upper_pivot;
      else return pivot;
    }
    else
    {
      float dist_to_lower = fabsf(lower-x);
      float dist_full = val-lower;
      if(rand >= dist_to_lower/dist_full) return lower_pivot;
      else return pivot;
    }
  }
}
// Nearest-neighbour quantization of x against a 256-entry code book, like
// dQuantize<0>.
// The first two probes are served from `quadrants` instead of smem_code:
// quadrants[1] is used in place of smem_code[127], and quadrants[0] /
// quadrants[2] in place of smem_code[63] / smem_code[191]. The caller is
// expected to fill them with those values (NOTE(review): confirm at call site).
// SIGNED selects the initial lower bound (-1 for signed codes, 0 otherwise).
// Unlike dQuantize, the end bounds are not re-read from smem_code: if the
// search runs off an end, the 1.0f / (-1.0f or 0.0f) placeholders are used in
// the midpoint test.
template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x)
{
  int pivot = 127;
  int upper_pivot = 255;
  int lower_pivot = 0;
  float lower = SIGNED ? -1.0f : 0.0f;
  float upper = 1.0f;
  float midpoint;
  float val = quadrants[1];
  // index into quadrants; only read while i >= 64 (i.e. for the second probe)
  int local_pivot = 1;
  int offset = 1;
  // i takes the values {64, 32, 16, 8, 4, 2, 1}
  for(int i = 64; i > 0; i>>=1)
  {
    if(x > val)
    {
      lower_pivot = pivot;
      lower = val;
      pivot+=i;
      //val = i == 64 ? quadrants[2] : smem_code[pivot];
      local_pivot += offset;
    }
    else
    {
      upper_pivot = pivot;
      upper = val;
      pivot-=i;
      //val = i == 64 ? quadrants[0] : smem_code[pivot];
      local_pivot -= offset;
    }
    val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
    // after the first step offset/local_pivot are no longer consulted
    offset -= 1;
  }
  if(x > val)
  {
    midpoint = (upper+val)*0.5f;
    if(x > midpoint)
      return upper_pivot;
    else
      return pivot;
  }
  else
  {
    midpoint = (lower+val)*0.5f;
    if(x < midpoint)
      return lower_pivot;
    else
      return pivot;
  }
}
// Scatter-add: for every k in [0, n) adds src[k] into
// histogram[index1[k]*maxidx1 + index2[k]] with atomicAdd.
// Each thread starts at its global thread id and advances by the total number
// of launched threads, so any 1-D launch configuration covers all n inputs.
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
{
  const int stride = blockDim.x*gridDim.x;
  int k = blockIdx.x*blockDim.x + threadIdx.x;
  while (k < n)
  {
    const int bin = index1[k]*maxidx1 + index2[k];
    atomicAdd(histogram + bin, src[k]);
    k += stride;
  }
}
// Launch shape of kEstimateQuantiles: THREADS_ESTIMATE threads x NUM_ESTIMATE
// items = BLOCK_ESTIMATE values per tile.
#define THREADS_ESTIMATE 512
#define NUM_ESTIMATE 8
#define BLOCK_ESTIMATE 4096
// Estimates 256 quantiles of A and accumulates them into code[0..255].
// Each BLOCK_ESTIMATE tile is radix-sorted. The values at 256 evenly spaced
// ranks (from `offset` to `1-offset` of the tile) are scaled by 1/(number of
// full tiles) and atomicAdd-ed into code, so code ends up holding the average
// of the per-tile quantiles.
// Code is only added to; presumably the caller zero-initializes it (verify).
// If n > BLOCK_ESTIMATE, a trailing partial tile is skipped. If n <= BLOCK_ESTIMATE,
// the single (possibly partial) tile is used and padded with max_val so the
// padding sorts to the end.
// Launch: blockDim.x must be THREADS_ESTIMATE (cub types and the
// vals[i/THREADS_ESTIMATE] lookup below depend on it).
template<typename T>
__launch_bounds__(THREADS_ESTIMATE, 1)
__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
{
  const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
  int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
  const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
  // n/BLOCK_ESTIMATE counts only full tiles, matching the half-block skip below
  const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));
  T vals[NUM_ESTIMATE];
  typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
  typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  // the three phases (load, sort, rank marking) reuse the same shared memory,
  // separated by __syncthreads()
  __shared__ union {
    typename LoadFloat::TempStorage loadf;
    typename BlockRadixSort::TempStorage sort;
    int smem_qidx[BLOCK_ESTIMATE];
  } temp_storage;
  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
  {
    valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;
    // do not process half-blocks
    // (the condition depends only on i, so the whole block skips together and
    // no thread misses a barrier)
    if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }
    // pad slots beyond valid_items with max_val (Load without a default leaves them untouched)
    #pragma unroll 4
    for(int j = 0; j < NUM_ESTIMATE; j++)
      vals[j] = max_val;
    __syncthreads();
    LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);
    #pragma unroll 4
    for(int j = 0; j < NUM_ESTIMATE; j++)
      vals[j] = ((float)vals[j]) * reciprocal_num_blocks;
    __syncthreads();
    // sort into striped pattern to mitigate bank conflicts: afterwards thread t
    // holds ranks t + j*THREADS_ESTIMATE, i.e. with 512 threads
    // thread 0 holds ranks [0, 512, 1024, ..., 3584]
    // thread 1 holds ranks [1, 513, 1025, ..., 3585]
    BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);
    __syncthreads();
    // smem_qidx[rank] = quantile index owning that rank, or -1
    for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
      temp_storage.smem_qidx[j] = -1;
    __syncthreads();
    if(threadIdx.x < 256)
    {
      // thread q marks the rank of quantile q
      // NOTE(review): if two quantiles round to the same rank (small
      // valid_items), only one write survives and the other quantile gets no
      // contribution from this tile
      float q_interval = (1.0f-(2.0f*offset))/255.0f;
      int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
      temp_storage.smem_qidx[local_idx] = threadIdx.x;
    }
    __syncthreads();
    // this inner `i` (a rank within the tile) shadows the outer tile offset
    for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x)
    {
      // rank i lives in this thread's striped slot i/THREADS_ESTIMATE
      if(temp_storage.smem_qidx[i] != -1)
        atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
    }
  }
}
__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
  // Non-blockwise 8-bit quantization: each value of A is mapped (without any
  // rescaling) to the index of the nearest entry of the 256-entry code book.
  // The cub types are instantiated for TH threads x NUM items, so launch with
  // blockDim.x == TH. Each block processes NUM_BLOCK values per iteration and
  // advances by gridDim.x*NUM_BLOCK.
  typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ float smem_code[256];

  // end of the processed range, rounded up to a whole number of tiles
  const int n_full = NUM_BLOCK*(n/NUM_BLOCK) + ((n % NUM_BLOCK != 0) ? NUM_BLOCK : 0);

  float vals[NUM];
  unsigned char qvals[NUM];

  // stage the code book; the barrier at the top of the loop publishes it
  if(threadIdx.x < 256)
    smem_code[threadIdx.x] = code[threadIdx.x];

  const unsigned int tile_stride = gridDim.x*NUM_BLOCK;
  for (unsigned int tile_start = blockIdx.x*NUM_BLOCK; tile_start < n_full; tile_start += tile_stride)
  {
    // number of real values in this tile (the last tile may be partial)
    int valid_items = NUM_BLOCK;
    if (n - tile_start < NUM_BLOCK)
      valid_items = n - tile_start;

    __syncthreads();
    LoadFloat(loadf).Load(&(A[tile_start]), vals, valid_items);

    #pragma unroll 4
    for(int k = 0; k < NUM; k++)
      qvals[k] = dQuantize<0>(smem_code, 0.0f, vals[k]);

    __syncthreads();
    StoreChar(storec).Store(&(out[tile_start]), qvals, valid_items);
  }
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
  // Blockwise quantization. For every BLOCK_SIZE chunk of A:
  //   1. compute the chunk's absolute maximum and store it to absmax[chunk];
  //   2. scale the chunk by 1/absmax;
  //   3. encode it. For General8bit each value becomes one byte, the index of
  //      the nearest `code` entry (optionally stochastically rounded with
  //      `rand`). For the 4-bit formats (FP4 / NF4) two codes are packed per
  //      byte, the first value of each pair in the high nibble.
  // DATA_TYPE > 0 selects the packed 4-bit layout (assumes General8bit == 0).
  // The cub types are instantiated for BLOCK_SIZE/NUM_PER_TH threads, so the
  // launch must use that blockDim.x.
  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  const int base_idx = (blockIdx.x * BLOCK_SIZE);

  T vals[NUM_PER_TH];
  float rand_vals[NUM_PER_TH];
  // 4-bit formats pack two values per output byte
  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
  //float local_abs_max = -FLT_MAX;
  float local_abs_max = 0.0f;
  int local_rand_idx = 0;

  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

  __shared__ typename LoadT::TempStorage loadt;
  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ float smem_code[256];
  __shared__ float smem_absmax_value[1];

  // only the 8-bit path needs the code book; the barrier in the loop publishes it
  if(DATA_TYPE == General8bit)
    for(int i = threadIdx.x; i < 256; i+=blockDim.x)
      smem_code[i] = code[i];

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    __syncthreads();
    // out-of-range items read as 0 so they cannot raise the absmax
    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

    // 1. compute local max
    // 2. broadcast local max
    // 3. normalize inputs and quantize

    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH; j++)
      local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    // the reduced result is only valid in thread 0
    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);

    if(threadIdx.x == 0)
      smem_absmax_value[0] = local_abs_max;

    __syncthreads();

    if(threadIdx.x == 0)
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];

    __syncwarp();

    // from here on local_abs_max holds the reciprocal scale
    local_abs_max = 1.0f/local_abs_max;

    if(STOCHASTIC)
    {
      // NOTE(review): the offset is built from NUM_BLOCK/NUM rather than
      // BLOCK_SIZE/NUM_PER_TH. Each thread also passes its own base pointer to
      // a block-wide cub load of BLOCK_SIZE items. Confirm that the rand buffer
      // is large enough for this access pattern.
      local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
      LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
    }

    unsigned char packed_4bit = 0;
    switch(DATA_TYPE)
    {
      case General8bit:
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH; j++)
        {
          if(!STOCHASTIC)
            qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
          else
            qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
        }
        break;
      case FP4:
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH/2; j++)
        {
          // assign (not OR) the high nibble so every output byte starts clean;
          // packed_4bit lives across iterations and would otherwise keep the
          // bits of all earlier pairs
          packed_4bit = dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
          packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
          qvals[j] = packed_4bit;
        }
        break;
      case NF4:
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH/2; j++)
        {
          // same fresh-byte packing as the FP4 case
          packed_4bit = dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
          packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
          qvals[j] = packed_4bit;
        }
        break;
    }

    __syncthreads();
    // 4-bit output has half as many bytes; an odd tail rounds up to a full byte
    StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
  }
}
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
{
  // Blockwise dequantization (the inverse of kQuantizeBlockwise).
  // Each iteration loads TILE_SIZE bytes of A in a blocked arrangement: thread
  // t owns the NUM_PER_TH consecutive bytes starting at i + t*NUM_PER_TH and
  // scales all of them with a single absmax entry. The quantization block in
  // bytes must therefore be a multiple of NUM_PER_TH.
  // `blocksize` is divided into byte offsets of A. For the packed 4-bit
  // formats (DATA_TYPE > 0, two values per byte) that presumably means the
  // caller passes half the element blocksize (TODO confirm at the launch site).
  // Launch with THREADS threads per block; n is the number of output values.
  const int n_load = (gridDim.x * TILE_SIZE);
  int valid_items_load = 0;
  int valid_items_store = 0;
  const int base_idx = (blockIdx.x * TILE_SIZE);

  // 4-bit formats produce two outputs per loaded byte
  T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)];
  unsigned char qvals[NUM_PER_TH];
  float local_abs_max = -FLT_MAX;

  typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
  typedef cub::BlockStore<T, THREADS, NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

  __shared__ typename LoadChar::TempStorage loadchar;
  __shared__ typename StoreT::TempStorage storet;

  for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
  {
    if(DATA_TYPE > 0)
    {
      valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
      valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;
    }
    else
    {
      valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i;
      valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i;
    }

    // In the tail tile, a thread whose first byte is past valid_items_load only
    // decodes padding, and the store below discards it. Skip its absmax lookup:
    // the index computed for such a thread can point beyond the last entry of
    // absmax.
    const int thread_first_byte = threadIdx.x*NUM_PER_TH;
    local_abs_max = (thread_first_byte < valid_items_load)
        ? __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)])
        : 0.0f;

    __syncthreads();
    // out-of-range bytes are filled with 128
    LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

    switch(DATA_TYPE)
    {
      case General8bit:
        // load code through read-only cache via __ldg
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH; j++)
          vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
        break;
      case FP4:
        // high nibble first, matching the packing order of kQuantizeBlockwise
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH; j++)
        {
          vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
          vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
        }
        break;
      case NF4:
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH; j++)
        {
          vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
          vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
        }
        break;
    }

    __syncthreads();
    StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
  }
}
// Non-blockwise 8-bit dequantization: out[k] = code[A[k]] for k in [0, n).
// The 256-entry code book is first staged in shared memory by the whole block,
// then each thread walks the input with a stride of the total launched thread
// count, so any 1-D launch configuration covers all n values.
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
{
  const unsigned int numThreads = blockDim.x * gridDim.x;
  const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
  __shared__ float smem_code[256];
  // block-strided so that blocks with fewer than 256 threads still fill every
  // entry that A may index
  for(int c = threadIdx.x; c < 256; c += blockDim.x)
    smem_code[c] = code[c];
  __syncthreads();
  for (int i = idx; i < n; i += numThreads)
    out[i] = smem_code[A[i]];
}
// First pass for update-norm clipping with 32-bit two-state optimizers.
// For each element it computes the update the optimizer would apply and
// squares it. The squares are summed across the block, and thread 0
// atomicAdd-s the block sum into unorm[0].
// No global array other than unorm is written. p, weight_decay and lr are
// accepted but not used here.
// Only OPTIMIZER == ADAM has an update rule. NOTE(review): for any other
// OPTIMIZER value the loaded state1 values themselves get summed; confirm the
// kernel is only instantiated for ADAM.
// unorm[0] is only added to, so presumably the caller zeroes it first (verify).
// Launch: base_idx uses blockDim.x*NUM_VALS while the loop strides by
// gridDim.x*BLOCK_SIZE, so blockDim.x must equal BLOCK_SIZE/NUM_VALS.
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
                float* state1, float* state2, float *unorm,
                const float beta1, const float beta2, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const int n)
{
  const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
  const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
  int valid_items = 0;
  T g_vals[NUM_VALS];
  float s1_vals[NUM_VALS];
  float s2_vals[NUM_VALS];
  // bias-correction factors 1/(1 - beta^step)
  const float correction1 = 1.0f/(1.0f - powf(beta1, step));
  const float correction2 = 1.0f/(1.0f - powf(beta2, step));
  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
  // loads and the reduction share one shared buffer; each use is preceded by __syncthreads()
  __shared__ union {
    typename Load::TempStorage load;
    typename LoadFloat::TempStorage loadf;
    typename BlockReduce::TempStorage reduce;
  } temp_storage;
  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
    // out-of-range items load as 0; with eps > 0 they produce a 0 update and
    // so do not change the sum
    __syncthreads();
    Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
    __syncthreads();
    LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
    __syncthreads();
    LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);
    # pragma unroll NUM_VALS
    for(unsigned int j = 0; j < NUM_VALS; j++)
      g_vals[j] = gnorm_scale*((float)g_vals[j]);
    # pragma unroll NUM_VALS
    for(unsigned int j = 0; j < NUM_VALS; j++)
    {
      switch(OPTIMIZER)
      {
        case ADAM:
          // updated moments (kept in registers only), bias-corrected
          s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
          s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
          s1_vals[j] *= correction1;
          s2_vals[j] *= correction2;
          s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
          s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update)
          break;
      }
    }
    // per-thread partial sum of squared updates
    # pragma unroll NUM_VALS-1
    for(unsigned int j = 1; j < NUM_VALS; j++)
      s1_vals[0] += s1_vals[j];
    __syncthreads();
    // block sum is only valid in thread 0
    s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]);
    if(threadIdx.x == 0)
      atomicAdd(&unorm[0], s1_vals[0]);
    __syncwarp();
  }
}
// Items per thread for kOptimizer32bit2State.
#define NUM_PER_THREAD 4
// One fused optimizer step with 32-bit (float) states, for OPTIMIZER == ADAM
// or ADEMAMIX.
// Reads g, p, state1 and state2. For AdEMAMix it also reads a second momentum
// buffer packed at state1[n .. 2n). The states are updated in registers, then
// p, state1 and state2 (and state1[n .. 2n)) are written back. Gradients are
// multiplied by gnorm_scale first.
// Only the ADAM path uses update_scale (max_unorm clipping) and skip_zeros.
// Launch: the cub types are instantiated for TH threads x NUM_PER_THREAD items,
// and base_idx uses blockDim.x, so blockDim.x must equal TH.
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
                float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
  const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
  const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
  int valid_items = 0;
  float update_scale = 0.0f;
  T g_vals[NUM_PER_THREAD];
  T p_vals[NUM_PER_THREAD];
  float s1_vals[NUM_PER_THREAD];
  float s2_vals[NUM_PER_THREAD];
  // AdEMAMix has an additional state buffer, which we packed
  // into state1. We need thread-local storage here for these.
  // TODO: Mark with [[maybe_unused]] after upgrade to min compiler.
  float s3_vals[NUM_PER_THREAD];
  // correction1 = 1 - beta1^step, correction2 = sqrt(1 - beta2^step).
  // Adam below computes p += step_size*m/(sqrt(v) + eps*correction2), which is
  // algebraically p -= lr*m_hat/(sqrt(v_hat) + eps).
  const float correction1 = 1.0f - powf(beta1, step);
  const float correction2 = sqrtf(1.0f - powf(beta2, step));
  const float step_size = -lr*correction2/correction1;
  // Update-norm clipping. unorm[0] is expected to hold the squared norm of the
  // update (presumably accumulated by kPreconditionOptimizer32bit2State; confirm
  // at the call site). If its square root exceeds max_unorm*param_norm, the
  // update is scaled back to that norm. Inside this branch the inner
  // `max_unorm > 0.0f ?` test is always true.
  if(max_unorm > 0.0f)
  {
    update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
    if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
    else{ update_scale = 1.0f; }
  }
  else{ update_scale = 1.0f; }
  typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
  typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
  typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
  // all loads/stores share one shared buffer; each use is preceded by __syncthreads()
  __shared__ union {
    typename Load::TempStorage load;
    typename Store::TempStorage store;
    typename LoadFloat::TempStorage loadf;
    typename StoreFloat::TempStorage storef;
  } temp_storage;
  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
  {
    valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
    // loads have no out-of-bounds default: slots past valid_items hold
    // unspecified values, but the stores below never write them back
    __syncthreads();
    Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
    __syncthreads();
    LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
    __syncthreads();
    LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
    __syncthreads();
    Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
    // Load additional state1 data for AdEMAMix
    // TODO: Make constexpr after updating min compiler
    if (OPTIMIZER == ADEMAMIX) {
      __syncthreads();
      LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items);
    }
    # pragma unroll 4
    for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
      g_vals[j] = gnorm_scale*((float)g_vals[j]);
    # pragma unroll 4
    for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
    {
      switch(OPTIMIZER)
      {
        case ADEMAMIX:
          // m1 update: m1 = beta1 * m1 + (1-beta1) * g
          s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]);
          // m2 update: m2 = m2 * beta3 + (1-beta3) * g
          s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]);
          // nu update: nu = beta2 * nu + (1-beta2) * g^2
          s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]);
          // p -= lr * (m1_hat + alpha*m2) / (sqrt(nu_hat) + eps); m2 is not bias-corrected
          p_vals[j] = (float)p_vals[j] - lr * (
            ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
              (sqrtf(s2_vals[j]) / correction2) + eps
            )
          );
          // decoupled (multiplicative) weight decay after the update
          if (weight_decay > 0.0f)
            p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
          break;
        case ADAM:
          // with skip_zeros, elements whose gradient is exactly 0 keep their state and parameter
          if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
          {
            s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
            s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
            p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
            // decoupled (multiplicative) weight decay after the update
            if(weight_decay > 0.0f)
              p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
          }
          break;
      }
    }
    __syncthreads();
    Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
    __syncthreads();
    StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
    __syncthreads();
    StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items);
    if (OPTIMIZER == ADEMAMIX) {
      __syncthreads();
      StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items);
    }
  }
}