Skip to content

Commit 5ebcca0

Browse files
committed
Use M rounded up to a power of 2 in the cache key for GEMM configs
1 parent 245c9bc commit 5ebcca0

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,19 @@
1111
KEYS = ['M', 'N', 'K', 'group_size', 'elements_per_sample']
1212
MATMUL_TYPE = "GEMM"
1313

14-
# Code based on https://github.com/fpgaminer/GPTQ-triton
1514
def kernel_config_pruner(configs, nargs, **kwargs):
1615
global KEYS
1716
from ..core import GEMLITE_TRITON_CONFIG_CACHE
1817

19-
m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) #Need at least 16 here for tl.dot
18+
m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16)
2019
n = nargs['N']
2120
k = nargs['K']
2221
g = nargs['group_size']
2322
e = nargs['elements_per_sample']
2423

2524
#Check cache
2625
if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE):
27-
_signature = str(tuple([nargs[i] for i in KEYS]))
26+
_signature = str(tuple([m, n, k, g, e]))
2827
if(_signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]):
2928
_config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][_signature])
3029
_num_stages = _config.pop('num_stages')

gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ def kernel_config_pruner(configs, nargs, **kwargs):
1515
global KEYS
1616
from ..core import GEMLITE_TRITON_CONFIG_CACHE
1717

18-
m = nargs['M']
18+
m = 2 ** int(math.ceil(math.log2(nargs['M'])))
1919
n = nargs['N']
2020
k = nargs['K']
2121
g = nargs['group_size']
2222
e = nargs['elements_per_sample']
2323

2424
#Check cache
2525
if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE):
26-
_signature = str(tuple([nargs[i] for i in KEYS]))
26+
_signature = str(tuple([m, n, k, g, e]))
2727
if(_signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]):
2828
_config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][_signature])
2929
_num_stages = _config.pop('num_stages')

0 commit comments

Comments
 (0)