Skip to content

Commit 189375f

Browse files
Merge branch 'dev' into feature/reduce
2 parents c13833f + 0e89e65 commit 189375f

File tree

7 files changed

+1437
-432
lines changed

7 files changed

+1437
-432
lines changed

Makefile

+44-6
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,37 @@ pch:
1919
mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -x c++-header gpu.hpp -o build/gpu.hpp.pch
2020

2121
# TODO(avh): change extension based on platform
22-
lib:
23-
mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -ldawn -ldl -shared -fPIC gpu.cpp -o build/libgpucpp.dylib
22+
# Get the current OS name
23+
# Evaluated once at parse time (:=) so uname is not re-run on every expansion.
# $(shell) already removes the trailing newline, so no extra tr is needed.
OS := $(shell uname)
24+
# Set the specific variables for each platform
25+
LIB_PATH ?= /usr/lib
26+
HEADER_PATH ?= /usr/include
27+
ifeq ($(OS), Linux)
28+
OS_TYPE ?= Linux
29+
GPU_CPP_LIB_NAME ?= libgpucpp.so
30+
DAWN_LIB_NAME ?= libdawn.so
31+
else ifeq ($(OS), Darwin)
32+
OS_TYPE ?= macOS
33+
GPU_CPP_LIB_NAME ?= libgpucpp.dylib
34+
DAWN_LIB_NAME ?= libdawn.dylib
35+
else
36+
OS_TYPE ?= unknown
37+
endif
38+
39+
# check-python is a prerequisite because the recipe runs build.py with python3.
lib: check-clang check-python dawnlib
40+
mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -ldawn -ldl -shared -fPIC gpu.cpp -o build/$(GPU_CPP_LIB_NAME)
41+
python3 build.py
42+
cp third_party/lib/$(DAWN_LIB_NAME) build/
43+
44+
install:
45+
cp build/$(GPU_CPP_LIB_NAME) $(LIB_PATH)
46+
cp build/$(DAWN_LIB_NAME) $(LIB_PATH)
47+
cp build/gpu.hpp $(HEADER_PATH)
48+
49+
uninstall:
50+
rm $(LIB_PATH)/$(GPU_CPP_LIB_NAME)
51+
rm $(LIB_PATH)/$(DAWN_LIB_NAME)
52+
rm $(HEADER_PATH)/gpu.hpp
2453

2554
examples/hello_world/build/hello_world: check-clang dawnlib examples/hello_world/run.cpp check-linux-vulkan
2655
$(LIBSPEC) && cd examples/hello_world && make build/hello_world && ./build/hello_world
@@ -96,15 +125,24 @@ clean-all:
96125
# Checks
97126
################################################################################
98127

128+
# Check all
129+
check-all: check-os check-clang check-cmake check-python
130+
131+
# check the os
132+
check-os:
133+
ifeq ($(OS_TYPE), unknown)
134+
$(error Unsupported operating system)
135+
endif
136+
99137
# check for the existence of clang++, cmake and python3
100138
check-clang:
101-
@command -v clang++ >/dev/null 2>&1 || { echo >&2 "Please install clang++ with 'sudo apt-get install clang' or 'brew install llvm'"; exit 1; }
139+
# printf is used instead of 'echo -e': under a POSIX /bin/sh such as dash,
# echo prints the '-e' literally. Each argument is emitted on its own line.
@command -v clang++ >/dev/null 2>&1 || { printf '%s\n' "Clang++ is not installed. Please install clang++ to continue." "On Debian / Ubuntu: 'sudo apt-get install clang'" "On macOS: 'brew install llvm'" "On Centos: 'sudo yum install clang'" >&2; exit 1; }
102140

103141
check-cmake:
104-
@command -v cmake >/dev/null 2>&1 || { echo >&2 "Please install cmake with 'sudo apt-get install cmake' or 'brew install cmake'"; exit 1; }
142+
# printf is used instead of 'echo -e': under a POSIX /bin/sh such as dash,
# echo prints the '-e' literally. Each argument is emitted on its own line.
@command -v cmake >/dev/null 2>&1 || { printf '%s\n' "Cmake is not installed. Please install cmake to continue." "On Debian / Ubuntu: 'sudo apt-get install cmake'" "On macOS: 'brew install cmake'" "On Centos: 'sudo yum install cmake'" >&2; exit 1; }
105143

106144
check-python:
107-
@command -v python3 >/dev/null 2>&1 || { echo >&2 "Python needs to be installed and in your path."; exit 1; }
145+
# printf is used instead of 'echo -e': under a POSIX /bin/sh such as dash,
# echo prints the '-e' literally. The hints name python3 because that is the
# command this check looks for.
@command -v python3 >/dev/null 2>&1 || { printf '%s\n' "Python is not installed. Please install python3 to continue." "On Debian / Ubuntu: 'sudo apt-get install python3'" "On Centos: 'sudo yum install python3'" >&2; exit 1; }
108146

109147
check-linux-vulkan:
110148
@echo "Checking system type and Vulkan availability..."
@@ -113,7 +151,7 @@ check-linux-vulkan:
113151
echo "Vulkan is installed."; \
114152
vulkaninfo; \
115153
else \
116-
echo "Vulkan is not installed. Please install Vulkan drivers to continue. On Debian / Ubuntu: sudo apt install libvulkan1 mesa-vulkan-drivers vulkan-tools"; \
154+
printf '%s\n' "Vulkan is not installed. Please install Vulkan drivers to continue." "On Debian / Ubuntu: 'sudo apt install libvulkan1 mesa-vulkan-drivers vulkan-tools'." "On Centos: 'sudo yum install vulkan vulkan-tools'."; \
117155
exit 1; \
118156
fi \
119157
else \

build.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Dictionary of header files and their relative paths
2+
header_files = {
3+
"#include \"webgpu/webgpu.h\"": "third_party/headers/webgpu/webgpu.h",
4+
"#include \"numeric_types/half.hpp\"": "numeric_types/half.hpp",
5+
"#include \"utils/logging.hpp\"": "utils/logging.hpp"
6+
}
7+
8+
def main():
9+
# File paths
10+
source_file_path = "gpu.hpp"
11+
output_file_path = "build/gpu.hpp"
12+
13+
# Open source file and read contents
14+
with open(source_file_path, "r") as source:
15+
file_contents = source.read()
16+
17+
# Iterate over each include directive and the path of the header it refers to
18+
for key, value in header_files.items():
19+
20+
# Replace header files
21+
with open(value, "r") as header_file:
22+
header_file_contents = header_file.read()
23+
file_contents = file_contents.replace(key, header_file_contents)
24+
25+
26+
# Open output file
27+
with open(output_file_path, "w") as output:
28+
# Write contents to output file
29+
output.write(file_contents)
30+
31+
if __name__ == "__main__":
32+
main()

experimental/kernels/Makefile

+3-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ endef
8383
build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
8484
mkdir -p build
8585
$(call preprocess_file)
86-
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o
86+
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g
8787

8888
build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin
8989
mkdir -p build
@@ -94,12 +94,12 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bi
9494
build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin
9595
mkdir -p build
9696
$(call preprocess_file)
97-
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o
97+
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g
9898

9999
build/ops.o: ops.cpp ops.hpp kernels.h llm.c
100100
mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $<
101101

102-
build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c
102+
# Rebuild when the llm.c checkout, the model weights, or either C++ source changes.
build/gpt2_webgpu: llm.c gpt2_124M.bin gpt2_webgpu.cpp ops.cpp
103103
mkdir -p build
104104
$(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp
105105

experimental/kernels/kernels.h

+98
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
309309
}
310310
}
311311
}
312+
313+
)";
314+
315+
316+
static const char *kShaderMatmul2DTiling = R"(
317+
@group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
318+
@group(0) @binding(1) var<storage, read_write> weight : array<{{precision}}>;
319+
@group(0) @binding(2) var<storage, read_write> bias : array<{{precision}}>;
320+
@group(0) @binding(3) var<storage, read_write> out : array<{{precision}}>;
321+
@group(0) @binding(4) var<uniform> params : Params;
322+
struct Params {
323+
B: u32,
324+
T: u32,
325+
C: u32,
326+
OC: u32,
327+
};
328+
var<workgroup> tileInp: array<{{precision}}, {{BT}} * {{BC}}>;
329+
var<workgroup> tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>;
330+
331+
@compute @workgroup_size({{workgroupSize}})
332+
fn main(
333+
@builtin(local_invocation_id) localID : vec3<u32>,
334+
@builtin(workgroup_id) groupid : vec3<u32>) {
335+
let B : u32 = params.B;
336+
let T : u32 = params.T;
337+
let C : u32 = params.C;
338+
let OC : u32 = params.OC;
339+
340+
var localT: array<{{precision}}, {{TT}}>;
341+
var localOC: array<{{precision}}, {{TOC}}>;
342+
343+
let outB: u32 = groupid.x;
344+
let outT: u32 = groupid.y;
345+
let outOC: u32 = groupid.z;
346+
let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}});
347+
348+
// position of the first c element computed by the thread
349+
let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}};
350+
let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}};
351+
352+
// inpPtr and weightPtr are the starting positions of the tiles in a and b,
353+
// incremented in the bkidx loop.
354+
// outPtr is the starting position of the tile in c which is fixed.
355+
356+
var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC
357+
var weightPtr = outOC * {{BOC}} * C; //OCC
358+
var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>;
359+
let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC
360+
let biasPtr = outOC * {{BOC}};
361+
362+
for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) {
363+
// Load BC x BOC by numThread(BT * BOC / (TT * TOC))
364+
// The number of iteration == BC * BOC / (BT * BOC / (TT * TOC))
365+
for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) {
366+
tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})];
367+
}
368+
weightPtr += {{BC}};
369+
370+
// Load tile
371+
// Load BT x BC by numThread(BT * BOC / (TT * TOC))
372+
// The number of iteration == BT * BC / (BT * BOC / (TT * TOC))
373+
for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) {
374+
tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}];
375+
}
376+
inpPtr += {{BC}};
377+
378+
workgroupBarrier();
379+
// Compute tile
380+
for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) {
381+
for (var idx: u32 = 0; idx < {{TT}}; idx++) {
382+
localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx];
383+
}
384+
for (var idx: u32 = 0; idx < {{TOC}}; idx++) {
385+
localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx];
386+
}
387+
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
388+
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
389+
threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC];
390+
}
391+
}
392+
}
393+
workgroupBarrier();
394+
}
395+
396+
if (arrayLength(&bias) == 1) {
397+
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
398+
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
399+
out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC];
400+
}
401+
}
402+
} else {
403+
for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) {
404+
for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) {
405+
out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC];
406+
}
407+
}
408+
}
409+
}
312410
)";
313411

314412
static const char *kShaderMatmulBackward = R"(

0 commit comments

Comments
 (0)