Skip to content

Commit

Permalink
[CUDA] Fix beam search of num_beams > 32 (#23599)
Browse files Browse the repository at this point in the history
### Description
* Pass topk_scores to beam scorer in slow topk path.
* Add an env variable `ORT_BEAM_SEARCH_USE_FAST_TOPK` to enable/disable fast topk.
* Add a test case for slow topk path.

### Motivation and Context

This bug was introduced in
#16272

Beam search uses a fast CUDA kernel when the number of beams is <= 32. When the beam
size is larger than that threshold, we use another code path (a slower
CUDA kernel) to get the top-k results. In this `slow topk path`, topk_scores should be
passed to the beam scorer, but it was not.

This bug causes incorrect results when num_beams > 32. It was not
found previously since such a large beam size is rarely used.
  • Loading branch information
tianleiwu authored Feb 7, 2025
1 parent 82840f6 commit 09e5724
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "core/platform/env_var_utils.h"

namespace onnxruntime {
namespace contrib {
Expand Down Expand Up @@ -136,7 +137,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
temperature = 1.0f;
}
}

// The following parameter is read from environment variable for testing purpose.
use_fast_topk = ParseEnvironmentVariableWithDefault<bool>(kBeamSearchUseFastTopK, true);
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,14 @@ struct IGenerationParameters {
int extra_decoding_ids_input_id = -1;
int cross_qk_output_id = -1;
int no_speech_probs_output_id = -1;

// Parameter for testing slow topk path. It can be updated by the below environment variable.
bool use_fast_topk = true;
};

// Environment variable to enable/disable fast topk kernel on GPU. Default is 1 (enabled).
constexpr const char* kBeamSearchUseFastTopK = "ORT_BEAM_SEARCH_USE_FAST_TOPK";

} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ Status ProcessLogits(const OrtValue& logits, //
beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
}

if (num_beams <= 32) {
gsl::span<float> scores_to_process = beam_state->next_scores;
if (parameters->use_fast_topk && num_beams <= 32) {
constexpr size_t max_parts_of_vocab = 128;
size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
float* topk_tmp_buffer = beam_state->topk_buffer.data();
Expand All @@ -546,13 +547,6 @@ Status ProcessLogits(const OrtValue& logits, //
beam_state->next_tokens.data(),
beam_state->next_indices.data(),
cuda_stream);

// Select [batch_size, 2 * num_beams] from [batch_size * num_beams, 2 * num_beams]
#ifdef DEBUG_GENERATION
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, 2 * num_beams);
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
#endif
} else {
// Apply top-k selection like the following:
// next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
Expand Down Expand Up @@ -588,18 +582,20 @@ Status ProcessLogits(const OrtValue& logits, //
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
batch_size, top_k, vocab_size, cuda_stream);

#ifdef DEBUG_GENERATION
dumper->Print("next_scores before scorer", topk_scores->Data<float>(), batch_size, top_k);
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
#endif
scores_to_process = gsl::span<float>(topk_scores->MutableData<float>(), batch_size * top_k);
}

// gsl::span doesn't convert from non const to const, so all we're doing here is making each const.
gsl::span<const float> next_scores(beam_state->next_scores.data(), beam_state->next_scores.size());
gsl::span<const float> next_scores(scores_to_process.data(), scores_to_process.size());
gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());

#ifdef DEBUG_GENERATION
dumper->Print("next_scores before scorer", next_scores.data(), batch_size, 2 * num_beams);
dumper->Print("next_tokens before scorer", next_tokens.data(), batch_size, 2 * num_beams);
dumper->Print("next_indices before scorer", next_indices.data(), batch_size, 2 * num_beams);
#endif

beam_scorer->Process(
*sequences,
next_scores,
Expand Down Expand Up @@ -735,6 +731,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
next_tokens,
next_indices,
stream_);

CUDA_CALL_THROW(cudaEventRecord(event_process_complete_.Get(), stream_));

cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_,
Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/test/contrib_ops/beam_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/model_tester.h"
#include "test/util/include/current_test_name.h"
#include "test/util/include/scoped_env_vars.h"
#include "contrib_ops/cpu/transformers/generation_shared.h"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
Expand All @@ -19,7 +21,7 @@ extern std::unique_ptr<Ort::Env> ort_env;
namespace onnxruntime {
namespace test {

TEST(BeamSearchTest, GptBeamSearchFp32) {
void RunGptBeamSearchFp32() {
std::vector<int64_t> input_ids_shape{3, 12};
std::vector<int32_t> input_ids{
0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
Expand Down Expand Up @@ -107,6 +109,16 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {
ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
}

// Default path: ORT_BEAM_SEARCH_USE_FAST_TOPK is not set, so use_fast_topk
// keeps its default of true and the fast top-k kernel is used on CUDA.
TEST(BeamSearchTest, GptBeamSearchFp32) {
RunGptBeamSearchFp32();
}

// Same model and inputs as GptBeamSearchFp32, but sets the
// ORT_BEAM_SEARCH_USE_FAST_TOPK environment variable to "0" so that
// use_fast_topk is false and the slow top-k code path is exercised.
// ScopedEnvironmentVariables restores the prior value when the test ends.
TEST(BeamSearchTest, GptBeamSearchFp32_DisableFastTopK) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{{onnxruntime::contrib::transformers::kBeamSearchUseFastTopK, "0"}}};
RunGptBeamSearchFp32();
}

TEST(BeamSearchTest, GptBeamSearchFp16) {
std::vector<int64_t> input_ids_shape{3, 12};
std::vector<int32_t> input_ids{
Expand Down

0 comments on commit 09e5724

Please sign in to comment.