Fix bug for asort kernel & faster sampler with GPU sorting #2730
This PR addresses an issue with the `asort` kernel, which could lead to a CUDA invalid-argument error when the last dimension to be sorted exceeded the number of threads allowed per CUDA block.
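A minimal repro sketch of the failure mode, assuming candle-core's `Tensor::arg_sort_last_dim` entry point (which dispatches to the `asort` CUDA kernel); the vocabulary size of 128,256 (LLaMa 3.1's) is far above the usual 1024-thread-per-block limit:

```rust
// Hedged repro sketch: sorting a vocabulary-sized last dimension on CUDA.
// Assumes candle-core's `arg_sort_last_dim`, backed by the `asort` kernel.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Logits-shaped tensor; 128_256 (LLaMa 3.1's vocabulary size) is far
    // larger than the maximum threads per CUDA block (1024 on most GPUs).
    let logits = Tensor::randn(0f32, 1f32, (1, 128_256), &device)?;
    // Before this fix, this call could fail with a CUDA invalid-argument
    // error on such shapes; with the revised kernel it succeeds.
    let indices = logits.arg_sort_last_dim(false)?;
    println!("sorted index shape: {:?}", indices.dims());
    Ok(())
}
```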
With the revised `asort` kernel, we can now perform part of the logit sampling on the GPU (specifically, the sorting step). This optimization improves text-generation performance by up to 25%, depending on the model and the number of tokens generated.
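As a rough illustration of the idea (not the PR's exact code), the expensive O(V log V) sort over the vocabulary can run on the GPU via something like candle-core's `sort_last_dim`, with only the already-sorted values and indices copied to the host for the cheap top-p truncation and final draw; `sample_top_p_sorted` below is a hypothetical helper:

```rust
// A hedged sketch of GPU-assisted top-p sampling; the PR's actual sampler
// integration may differ in details.
use candle_core::{Result, Tensor};
use rand::distributions::Distribution;

// Sample from a descending-sorted distribution, keeping the smallest prefix
// whose cumulative mass reaches `top_p`.
fn sample_top_p_sorted(
    probs: &[f32],
    indices: &[u32],
    top_p: f32,
    rng: &mut rand::rngs::StdRng,
) -> Result<u32> {
    let mut cumsum = 0f32;
    let mut cutoff = probs.len();
    for (i, p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= top_p {
            cutoff = i + 1;
            break;
        }
    }
    // WeightedIndex renormalizes the kept prefix for us.
    let dist = rand::distributions::WeightedIndex::new(&probs[..cutoff])
        .map_err(candle_core::Error::wrap)?;
    Ok(indices[dist.sample(rng)])
}

fn sample(logits: &Tensor, temperature: f64, top_p: f32, rng: &mut rand::rngs::StdRng) -> Result<u32> {
    let probs = candle_nn::ops::softmax_last_dim(&(logits / temperature)?)?;
    // The O(V log V) sort over the whole vocabulary stays on the GPU...
    let (sorted_probs, sorted_indices) = probs.sort_last_dim(false)?;
    // ...and only the sorted results cross to the host for the cheap part.
    let sorted_probs = sorted_probs.flatten_all()?.to_vec1::<f32>()?;
    let sorted_indices = sorted_indices.flatten_all()?.to_vec1::<u32>()?;
    sample_top_p_sorted(&sorted_probs, &sorted_indices, top_p, rng)
}
```

The host still receives the full sorted vectors, but it no longer performs the vocabulary-sized sort itself, which is where the reported speedup comes from.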
## Testing Cases

I conducted tests using the following models and commands:
**GLM4-9B**
```bash
cargo run --release --example glm4 --features cuda -- \
    --weight-path /home/glm-4-9b-chat/ \
    --prompt "Please talk about deep learning." \
    --sample-len 2000
```
Note: By default, the GLM4 model uses sampling parameters of `temperature=0.8` and `top_p=0.8`.

**LLaMa3.1 8B**
```bash
cargo run --release --example llama --features cuda -- \
    --weight-path /home/Meta-Llama-3.1-8B-Instruct/ \
    --prompt "Please talk about deep learning." \
    --temperature 0.7 \
    --top-p 0.8 \
    --sample-len 2000
```
## Performance Results
### Original Implementation (Sampling Entirely on CPU)
**GLM4-9B:** (results attached in the original PR)

**LLaMa3.1 8B:** (results attached in the original PR)
### After `asort` Kernel Fix & Faster Sampler with GPU Sort

**GLM4-9B:** (results attached in the original PR)
**LLaMa3.1 8B:** (results attached in the original PR)
Note: You can adjust the `temperature` parameter to influence the model's behavior and potentially generate more than 2000 tokens for the above comparisons. While `sample-len` sets the maximum number of tokens that can be generated, the actual output length may vary depending on the model, the sampling process, its parameters, etc.