
Commit 890054c ("Lesson 10" / 第10节课程)

1 parent: 991df7f

10 files changed: +86 additions, −26 deletions

demo/main.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -17,13 +17,13 @@ int32_t generate(const model::LLama2Model& model, const std::string& sentence, i
     pos_tensor.index<int32_t>(0) = pos;
     if (pos < prompt_len - 1) {
       tensor::Tensor input = model.fill_input(pos_tensor, prompt_embedding, is_prompt);
-      next = model.forward(input, pos_tensor, is_prompt, next);
+      model.predict(input, pos_tensor, is_prompt, next);
     } else {
       is_prompt = false;
       tokens = std::vector<int32_t>{next};
       const auto& token_embedding = model.embedding(tokens);
       tensor::Tensor input = model.fill_input(pos_tensor, token_embedding, is_prompt);
-      model.forward(input, pos_tensor, is_prompt, next);
+      model.predict(input, pos_tensor, is_prompt, next);
     }
     if (next == model.get_eos()) {
       break;
```
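The two call-site changes track an API split: the old `forward` both ran the network and picked the next token, and the first call site even assigned its return value to `next`; the new `predict` takes over the full step, writing the sampled token through the `next` reference, while `forward` is now the bare network pass. A condensed sketch of how the loop above drives the new API (names such as `total_steps` are assumptions, not the repo's exact code):

```cpp
// Prefill feeds prompt tokens; decode feeds the previously generated token.
// predict() = forward() (network pass) + post_processing() (token selection).
for (int32_t pos = 0; pos < total_steps; ++pos) {
  pos_tensor.index<int32_t>(0) = pos;
  const bool in_prefill = pos < prompt_len - 1;
  const auto& emb = in_prefill ? prompt_embedding
                               : model.embedding(std::vector<int32_t>{next});
  tensor::Tensor input = model.fill_input(pos_tensor, emb, in_prefill);
  model.predict(input, pos_tensor, /*is_prompt=*/in_prefill, next);
  if (next == model.get_eos()) break;
}
```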

imgs/qa.jpg

1.02 MB

kuiper/include/model/llama2.h

Lines changed: 4 additions & 1 deletion

```diff
@@ -36,9 +36,12 @@ class LLama2Model : public Model {
 
   base::Status init(base::DeviceType device_type) override;
 
-  base::Status forward(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
+  base::Status predict(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
                        bool is_prompt, int& next) const override;
 
+  base::Status forward(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
+                       int& next) const override;
+
   std::vector<int32_t> encode(const std::string& sentence) const override;
 
   int32_t get_eos() const override;
```

kuiper/include/model/model.h

Lines changed: 4 additions & 1 deletion

```diff
@@ -18,9 +18,12 @@ class Model {
 
   virtual base::Status init(base::DeviceType device_type) = 0;
 
-  virtual base::Status forward(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
+  virtual base::Status predict(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
                                bool is_prompt, int& next) const = 0;
 
+  virtual base::Status forward(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
+                               int& next) const = 0;
+
   virtual int32_t get_eos() const = 0;
 
   base::ModelType model_type() const;
```
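Seen from the base class, the commit turns one entry point into two: `forward` is now the pure network pass, and `predict` layers token selection on top. A minimal usage sketch (assuming `model`, `input`, and `pos_tensor` are set up as in `demo/main.cpp`):

```cpp
int next = -1;

// Full decode step: network pass plus the post-processing that writes the
// chosen token id into `next`.
base::Status s1 = model.predict(input, pos_tensor, /*is_prompt=*/false, next);

// Network pass only, e.g. to benchmark the layers without sampling.
base::Status s2 = model.forward(input, pos_tensor, next);
```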

kuiper/source/model/llama2.cpp

Lines changed: 11 additions & 2 deletions

```diff
@@ -131,7 +131,7 @@ base::Status LLama2Model::init(base::DeviceType device_type) {
 }
 
 base::Status LLama2Model::forward(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
-                                  bool is_prompt, int& next) const {
+                                  int& next) const {
   if (input.is_empty()) {
     return base::error::InvalidArgument("The input tensor is empty.");
   }
@@ -149,7 +149,6 @@ base::Status LLama2Model::forward(const tensor::Tensor& input, const tensor::Ten
     feed_forward(layer_idx, input);
   }
   cls_logits(input);
-  next = post_processing(pos_tensor, is_prompt);
   return base::error::Success();
 }
 
@@ -674,6 +673,16 @@ void LLama2Model::attention_qkv(int32_t layer_idx, const tensor::Tensor& pos_ten
   STATUS_CHECK(llama_layers_->rope_layer_->forward(query, key, pos_tensor, tensor::Tensor{}));
 }
 
+base::Status LLama2Model::predict(const tensor::Tensor& input, const tensor::Tensor& pos_tensor,
+                                  bool is_prompt, int& next) const {
+  auto status = forward(input, pos_tensor, next);
+  if (!status) {
+    return status;
+  }
+  next = post_processing(pos_tensor, is_prompt);
+  return base::error::Success();
+}
+
 void LLama2Model::attention_mha(int32_t layer_idx, const tensor::Tensor& pos_tensor) const {
   CHECK(llama_layers_ != nullptr);
   // mha
```
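The new `predict` checks `if (!status)`, which presumes `base::Status` converts to bool and is truthy on success. The repo's real `base::Status` lives in the base module and is not shown here; a minimal sketch of a status type that supports this pattern:

```cpp
#include <string>

// Minimal status type sketch supporting `if (!status) return status;`.
class Status {
 public:
  Status() = default;  // default-constructed means success
  Status(int code, std::string message)
      : code_(code), message_(std::move(message)) {}

  // Truthy exactly when the operation succeeded.
  explicit operator bool() const { return code_ == 0; }

  int code() const { return code_; }
  const std::string& message() const { return message_; }

 private:
  int code_ = 0;
  std::string message_;
};
```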

kuiper/source/op/kernels/cuda/rmsnorm_kernel.cu

Lines changed: 5 additions & 15 deletions

```diff
@@ -64,22 +64,12 @@ void rmsnorm_kernel_cu(const tensor::Tensor& input, const tensor::Tensor& weight
   float* in_ptr = const_cast<float*>(input.ptr<float>());
   float* wei_ptr = const_cast<float*>(weight.ptr<float>());
   float* out_ptr = const_cast<float*>(output.ptr<float>());
-  if (size < 1024) {
-    constexpr int threads_num = 128;
-    if (stream) {
-      cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
-      row_rmsnorm_f32<128><<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);
-    } else {
-      row_rmsnorm_f32<128><<<1, threads_num>>>(in_ptr, wei_ptr, out_ptr, size, eps);
-    }
+  constexpr int threads_num = 128;
+  if (stream) {
+    cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
+    row_rmsnorm_f32<128><<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);
   } else {
-    constexpr int threads_num = 1024;
-    if (stream) {
-      cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
-      row_rmsnorm_f32<1024><<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);
-    } else {
-      row_rmsnorm_f32<1024><<<1, threads_num>>>(in_ptr, wei_ptr, out_ptr, size, eps);
-    }
+    row_rmsnorm_f32<128><<<1, threads_num>>>(in_ptr, wei_ptr, out_ptr, size, eps);
   }
 }
 }  // namespace kernel
```
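Collapsing the `size < 1024` branch into a single 128-thread launch only works if `row_rmsnorm_f32` lets each thread stride across the whole row rather than assuming one element per thread. That kernel is not shown in this diff; the following is a sketch of the strided-reduction shape under that assumption (the repo's version may vectorize with `float4` or otherwise differ), using CUB's block reduction:

```cuda
#include <cub/block/block_reduce.cuh>

// One block normalizes one row: each of the kThreads threads accumulates a
// strided slice of the squared sum, a block reduction combines the partials,
// and every thread then rescales its slice.
template <int kThreads>
__global__ void row_rmsnorm_sketch(const float* in, const float* wei,
                                   float* out, int size, float eps) {
  using BlockReduce = cub::BlockReduce<float, kThreads>;
  __shared__ typename BlockReduce::TempStorage tmp;
  __shared__ float scale;

  float sum = 0.f;
  for (int i = threadIdx.x; i < size; i += kThreads) {
    sum += in[i] * in[i];
  }
  sum = BlockReduce(tmp).Sum(sum);  // aggregate valid on thread 0 only
  if (threadIdx.x == 0) {
    scale = rsqrtf(sum / static_cast<float>(size) + eps);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < size; i += kThreads) {
    out[i] = wei[i] * (in[i] * scale);
  }
}
```

With a strided loop like this, the 128-thread variant handles any `size`; the removed 1024-thread branch presumably bought more parallelism per row but was not needed for correctness.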

kuiper/source/op/kernels/cuda/rope_kernel.cu

Lines changed: 4 additions & 3 deletions

```diff
@@ -21,9 +21,10 @@ __global__ void rope_kernel_cu_fp32(int pos, int dim, int kv_dim, int head_size,
   float val = static_cast<float>(pos) * freq;
   float fcr = cosf(val);
   float fci = sinf(val);
-  bool is_greater = idx >= kv_dim;
-
-  return rope_calc(fcr, fci, const_cast<float*>(input_q), idx) ;
+  rope_calc(fcr, fci, const_cast<float*>(input_q), idx);
+  if (idx >= kv_dim) {
+    return;
+  }
   rope_calc(fcr, fci, const_cast<float*>(input_k), idx);
 }
 
```
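This is a straight bug fix: the old code computed an unused `is_greater` flag and then returned after rotating only the query, so the key was never rotated. The fixed kernel rotates the query at every `idx` and the key only while `idx < kv_dim`, since the key can be narrower than the query under grouped-query attention. `rope_calc` itself is not in the diff; a hypothetical version, for illustration only:

```cuda
// Hypothetical rope_calc (the repo's version may differ): rotate the adjacent
// pair (vec[idx], vec[idx + 1]) by the angle whose cosine is fcr, sine fci.
__device__ void rope_calc(float fcr, float fci, float* vec, int idx) {
  float v0 = vec[idx];
  float v1 = vec[idx + 1];
  vec[idx] = v0 * fcr - v1 * fci;
  vec[idx + 1] = v0 * fci + v1 * fcr;
}
```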

readme.md

Lines changed: 10 additions & 2 deletions

```diff
@@ -1,16 +1,21 @@
 # Build a Large-Model Inference Framework from Scratch
 > Walks you through writing an LLM framework from scratch, with LLama inference and CUDA acceleration
 
-**🙋🙋🙋 The build-your-own inference framework course is in full swing, just 178 RMB; add me on WeChat below to learn more**
+**🙋🙋🙋 The build-your-own inference framework course is in full swing; add me on WeChat below to learn more**
+
+
 
 <img src="./imgs/me.jpg" alt="me" height="360px" width="300px" />
 
+
+
 ## Demo
 > LLama1.1b fp32 model, video not sped up, running on an Nvidia 3060 laptop at 60.34 token/s
+
 ![](./imgs/do.gif)
 
 ## Course Outline
-Just 178 RMB, just 178 RMB, just 178 RMB! Important things deserve saying three times!!!
+
 
 **Part 1: Overall architecture and design**
 > Learn architectural thinking, so you don't end up only optimizing local implementations
@@ -68,6 +73,9 @@
 *This part has several subsections*
 32. Summary
 
+## Course FAQ
+
+<img src="./imgs/qa.jpg" style="zoom: 67%;" />
 
 ## Third-party Dependencies
 1. google glog https://github.com/google/glog
```

test/test_op/test_load.cpp

Lines changed: 46 additions & 0 deletions

```diff
@@ -0,0 +1,46 @@
+#include <cuda_runtime_api.h>
+#include <fcntl.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include <model/config.h>
+#include <sys/mman.h>
+#include "../source/op/kernels/kernels_interface.h"
+#include "base/buffer.h"
+
+TEST(test_load, load_model_config) {
+  std::string model_path = "./tmp/test.bin";
+  int32_t fd = open(model_path.data(), O_RDONLY);
+  ASSERT_NE(fd, -1);
+
+  FILE* file = fopen(model_path.data(), "rb");
+  ASSERT_NE(file, nullptr);
+
+  auto config = model::ModelConfig{};
+  fread(&config, sizeof(model::ModelConfig), 1, file);
+  ASSERT_EQ(config.dim, 16);
+  ASSERT_EQ(config.hidden_dim, 128);
+  ASSERT_EQ(config.layer_num, 256);
+}
+
+TEST(test_load, load_model_weight) {
+  std::string model_path = "./tmp/test.bin";
+  int32_t fd = open(model_path.data(), O_RDONLY);
+  ASSERT_NE(fd, -1);
+
+  FILE* file = fopen(model_path.data(), "rb");
+  ASSERT_NE(file, nullptr);
+
+  auto config = model::ModelConfig{};
+  fread(&config, sizeof(model::ModelConfig), 1, file);
+
+  fseek(file, 0, SEEK_END);
+  auto file_size = ftell(file);
+
+  void* data = mmap(nullptr, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
+  float* weight_data =
+      reinterpret_cast<float*>(static_cast<int8_t*>(data) + sizeof(model::ModelConfig));
+
+  for (int i = 0; i < config.dim * config.hidden_dim; ++i) {
+    ASSERT_EQ(*(weight_data + i), float(i));
+  }
+}
```
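Both tests read `./tmp/test.bin` as a `model::ModelConfig` header followed by `dim * hidden_dim` sequential floats (they never `fclose`/`close`/`munmap`, which a short-lived test binary tolerates). A sketch of a generator that would satisfy the assertions, assuming the config begins with the three `int32_t` fields the test reads; the real struct in `model/config.h` likely carries more fields, given the 8.03 KB checked-in file:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for model::ModelConfig; this layout is an assumption.
struct ModelConfig {
  int32_t dim;
  int32_t hidden_dim;
  int32_t layer_num;
};

int main() {
  ModelConfig config{16, 128, 256};
  FILE* file = std::fopen("./tmp/test.bin", "wb");
  if (!file) return 1;
  std::fwrite(&config, sizeof(config), 1, file);

  // dim * hidden_dim floats with value == index, matching the weight test.
  std::vector<float> weights(config.dim * config.hidden_dim);
  for (size_t i = 0; i < weights.size(); ++i) {
    weights[i] = static_cast<float>(i);
  }
  std::fwrite(weights.data(), sizeof(float), weights.size(), file);
  std::fclose(file);
  return 0;
}
```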

tmp/test.bin

8.03 KB
Binary file not shown.
