Skip to content

Commit 3f76969

Browse files
authored
feat: Softmax free sampling (#1035)
1 parent 08440a8 commit 3f76969

File tree

9 files changed

+380
-1
lines changed

9 files changed

+380
-1
lines changed

benchmarks/bench_sampling.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@ def init_seed_sampling(*args, **kwargs):
2727
return flashinfer.sampling.sampling_from_probs(*args, **kwargs)
2828

2929

30+
def init_seed_sampling_from_logits(*args, **kwargs):
31+
torch.manual_seed(42)
32+
return flashinfer.sampling.sampling_from_logits(*args, **kwargs)
33+
34+
35+
def init_seed_sampling_from_softmax_logits(logits, *args, **kwargs):
36+
torch.manual_seed(42)
37+
return flashinfer.sampling.sampling_from_probs(
38+
torch.softmax(logits, dim=-1), *args, **kwargs
39+
)
40+
41+
3042
def init_seed_top_k_sampling(*args, **kwargs):
3143
torch.manual_seed(42)
3244
return flashinfer.sampling.top_k_sampling_from_probs(*args, **kwargs)
@@ -139,6 +151,69 @@ def main():
139151
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
140152
)
141153

154+
print("---")
155+
print("sampling from softmax(logits)")
156+
for vocab_size in [128512]:
157+
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
158+
for distrib in [
159+
normal_distribution(1),
160+
normal_distribution(5),
161+
gumbel_distribution(0.1),
162+
gumbel_distribution(1),
163+
]:
164+
for deterministic in [True, False]:
165+
logits = distrib((batch_size, vocab_size), device="cuda")
166+
samples = torch.zeros(
167+
batch_size, dtype=torch.int32, device=logits.device
168+
)
169+
ms = do_bench(
170+
lambda: init_seed_sampling_from_softmax_logits(
171+
logits, samples, deterministic=deterministic
172+
),
173+
warmup=100,
174+
rep=1000,
175+
)
176+
io = (
177+
logits.numel() * logits.element_size()
178+
+ samples.numel() * samples.element_size()
179+
)
180+
bandwidth = io * 1e-6 / ms
181+
print(
182+
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
183+
)
184+
185+
print("---")
186+
print("sampling from logits")
187+
for vocab_size in [128512]:
188+
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
189+
for distrib in [
190+
normal_distribution(1),
191+
normal_distribution(5),
192+
gumbel_distribution(0.1),
193+
gumbel_distribution(1),
194+
]:
195+
for deterministic in [True, False]:
196+
logits = distrib((batch_size, vocab_size), device="cuda")
197+
samples = torch.zeros(
198+
batch_size, dtype=torch.int32, device=logits.device
199+
)
200+
ms = do_bench(
201+
lambda: init_seed_sampling_from_logits(
202+
logits, samples, deterministic=deterministic
203+
),
204+
warmup=100,
205+
rep=1000,
206+
)
207+
208+
io = (
209+
logits.numel() * logits.element_size()
210+
+ samples.numel() * samples.element_size()
211+
)
212+
bandwidth = io * 1e-6 / ms
213+
print(
214+
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
215+
)
216+
142217

143218
if __name__ == "__main__":
144219
main()

csrc/flashinfer_ops.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ void sampling_from_probs(at::Tensor probs, at::Tensor output,
176176
std::optional<at::Tensor> maybe_indices, bool deterministic,
177177
std::optional<at::Generator> gen);
178178

179+
void sampling_from_logits(at::Tensor logits, at::Tensor output,
180+
std::optional<at::Tensor> maybe_indices, bool deterministic,
181+
std::optional<at::Generator> gen);
182+
179183
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
180184
std::optional<at::Tensor> maybe_indices,
181185
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
@@ -294,6 +298,8 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
294298
// sampling
295299
// Sample from probabilities
296300
m.def("sampling_from_probs", sampling_from_probs);
301+
// Sample from logits
302+
m.def("sampling_from_logits", sampling_from_logits);
297303
// Top-k sampling from probabilities
298304
m.def("top_k_sampling_from_probs", top_k_sampling_from_probs);
299305
// Min-p sampling from probabilities

csrc/flashinfer_sampling_ops.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ void sampling_from_probs(at::Tensor probs, at::Tensor output,
1919
std::optional<at::Tensor> maybe_indices, bool deterministic,
2020
std::optional<at::Generator> gen);
2121

22+
void sampling_from_logits(at::Tensor logits, at::Tensor output,
23+
std::optional<at::Tensor> maybe_indices, bool deterministic,
24+
std::optional<at::Generator> gen);
25+
2226
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
2327
std::optional<at::Tensor> maybe_indices,
2428
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
@@ -58,6 +62,8 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i
5862
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
5963
// Sample from probabilities
6064
m.def("sampling_from_probs", sampling_from_probs);
65+
// Sample from logits
66+
m.def("sampling_from_logits", sampling_from_logits);
6167
// Top-k sampling from probabilities
6268
m.def("top_k_sampling_from_probs", top_k_sampling_from_probs);
6369
// Min-p sampling from probabilities

csrc/sampling.cu

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,33 @@
2525

2626
using namespace flashinfer;
2727

28+
void sampling_from_logits(at::Tensor logits, at::Tensor output,
29+
std::optional<at::Tensor> maybe_indices, bool deterministic,
30+
std::optional<at::Generator> gen_) {
31+
CHECK_INPUT(logits);
32+
auto device = logits.device();
33+
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
34+
unsigned int batch_size = output.size(0);
35+
unsigned int vocab_size = logits.size(1);
36+
37+
uint64_t philox_seed, philox_offset;
38+
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
39+
gen_, at::cuda::detail::getDefaultCUDAGenerator());
40+
std::lock_guard<std::mutex> lock(gen->mutex_);
41+
at::PhiloxCudaState rng_engine_inputs = gen->philox_cuda_state(batch_size * vocab_size);
42+
philox_seed = rng_engine_inputs.seed_.val;
43+
philox_offset = rng_engine_inputs.offset_.val;
44+
45+
const c10::cuda::OptionalCUDAGuard device_guard(device);
46+
auto stream = at::cuda::getCurrentCUDAStream();
47+
cudaError_t status = sampling::SamplingFromLogits(
48+
static_cast<float*>(logits.data_ptr()), static_cast<int*>(output.data_ptr()),
49+
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
50+
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
51+
TORCH_CHECK(status == cudaSuccess, "SamplingFromLogits failed with error code " +
52+
std::string(cudaGetErrorString(status)));
53+
}
54+
2855
void sampling_from_probs(at::Tensor probs, at::Tensor output,
2956
std::optional<at::Tensor> maybe_indices, bool deterministic,
3057
std::optional<at::Generator> gen_) {

flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
)
8181
from .sampling import chain_speculative_sampling as chain_speculative_sampling
8282
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
83+
from .sampling import sampling_from_logits as sampling_from_logits
8384
from .sampling import sampling_from_probs as sampling_from_probs
8485
from .sampling import top_k_mask_logits as top_k_mask_logits
8586
from .sampling import top_k_renorm_probs as top_k_renorm_probs

flashinfer/sampling.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,39 @@ def get_sampling_module():
4242
],
4343
)
4444

45+
# torch library for sampling_from_logits
46+
@register_custom_op("flashinfer::sampling_from_logits", mutates_args=())
47+
def sampling_from_logits(
48+
logits: torch.Tensor,
49+
indices: Optional[torch.Tensor],
50+
deterministic: bool,
51+
generator: Optional[torch.Generator],
52+
) -> torch.Tensor:
53+
device = logits.device
54+
# TODO: support more data types in logits to avoid conversion
55+
# to float32
56+
logits = logits.float()
57+
batch_size = indices.size(0) if indices is not None else logits.size(0)
58+
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
59+
module.sampling_from_logits.default(
60+
logits,
61+
samples,
62+
indices,
63+
deterministic,
64+
generator,
65+
)
66+
return samples
67+
68+
@register_fake_op("flashinfer::sampling_from_logits")
69+
def _fake_sampling_from_logits(
70+
logits: torch.Tensor,
71+
indices: Optional[torch.Tensor],
72+
deterministic: bool,
73+
generator: Optional[torch.Generator],
74+
) -> torch.Tensor:
75+
batch_size = indices.size(0) if indices is not None else logits.size(0)
76+
return torch.empty(batch_size, dtype=torch.int32, device=logits.device)
77+
4578
# torch library for sampling_from_probs
4679

4780
@register_custom_op("flashinfer::sampling_from_probs", mutates_args=())
@@ -64,6 +97,8 @@ def sampling_from_probs(
6497
)
6598
return samples
6699

100+
# torch library for sampling_from_probs
101+
67102
@register_fake_op("flashinfer::sampling_from_probs")
68103
def _fake_sampling_from_probs(
69104
probs: torch.Tensor,
@@ -384,6 +419,7 @@ def _fake_chain_speculative_sampling(
384419
# Register the module
385420
_sampling_module = SimpleNamespace(
386421
sampling_from_probs=sampling_from_probs,
422+
sampling_from_logits=sampling_from_logits,
387423
top_p_sampling_from_probs=top_p_sampling_from_probs,
388424
top_k_sampling_from_probs=top_k_sampling_from_probs,
389425
min_p_sampling_from_probs=min_p_sampling_from_probs,
@@ -404,6 +440,64 @@ def _to_tensor_scalar_tuple(x):
404440
return (None, x)
405441

406442

443+
def sampling_from_logits(
444+
logits: torch.Tensor,
445+
indices: Optional[torch.Tensor] = None,
446+
deterministic: bool = True,
447+
generator: Optional[torch.Generator] = None,
448+
check_nan: bool = False,
449+
) -> torch.Tensor:
450+
r"""Fused GPU kernel for category sampling from logits. It's equivalent to sampling
451+
from :attr:`logits` after applying softmax.
452+
Parameters
453+
----------
454+
logits: torch.Tensor
455+
Logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
456+
and the i-th output will be sampled from the i-th row of logits. When indices is provided,
457+
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
458+
probability distributions.
459+
indices: Optional[torch.Tensor]
460+
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in logits.
461+
For example, if indices[i] = j, then the i-th output will be sampled from logits[j].
462+
This allows reusing the same probability distribution for multiple outputs.
463+
If indices is not provided, the i-th output will be sampled from the i-th row of logits.
464+
deterministic: bool
465+
Since the sampling doesn't use cub's BlockScan, the sampling is deterministic. We keep this
466+
argument for compatibility with other sampling functions.
467+
generator: Optional[torch.Generator]
468+
A random number generator for the operation.
469+
check_nan: bool
470+
Whether to check nan in :attr:`logits`, default is ``False``.
471+
Returns
472+
-------
473+
samples: torch.Tensor
474+
Sampled categories, shape (batch_size,). It's equivalent to sampling from
475+
:attr:`logits` after applying softmax.
476+
Examples
477+
--------
478+
>>> import torch
479+
>>> import flashinfer
480+
>>> torch.manual_seed(42)
481+
>>> batch_size = 4
482+
>>> vocab_size = 5
483+
>>> logits = torch.rand(batch_size, vocab_size).to(0)
484+
>>> logits
485+
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
486+
[0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
487+
[0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
488+
[0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0')
489+
>>> samples = flashinfer.sampling.sampling_from_logits(logits)
490+
>>> samples
491+
tensor([0, 1, 1, 1], device='cuda:0', dtype=torch.int32)
492+
"""
493+
if check_nan:
494+
if torch.any(torch.isnan(logits)):
495+
raise ValueError("Input logits contains NaN.")
496+
return get_sampling_module().sampling_from_logits(
497+
logits, indices, deterministic, generator
498+
)
499+
500+
407501
def sampling_from_probs(
408502
probs: torch.Tensor,
409503
indices: Optional[torch.Tensor] = None,

0 commit comments

Comments
 (0)