Skip to content

Commit 6a20353

Browse files
authored
Update benchmark_triton.py
1 parent 0964923 commit 6a20353

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

examples/benchmark_triton.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
pass
1717

1818
#GemLite
19-
from gemlite.core import GemLiteLinearTriton, DType, set_autotune
19+
from gemlite.core import GemLiteLinearTriton, DType, set_autotune, GEMLITE_ACC_DTYPE
2020
set_autotune({'GEMV_REVSPLITK':True, 'GEMV_SPLITK': True, 'GEMV':True, 'GEMM_SPLITK':True, 'GEMM':True}, exhaustive=True, use_cuda_graph=False)
2121

22+
GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32 #For A100/H100
23+
#GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP16 #For 3090/4090
24+
2225
device = 'cuda:0'
2326
compute_dtype = torch.float16
2427

0 commit comments

Comments
 (0)