174 changes: 116 additions & 58 deletions lib/llm/src/preprocessor.rs
@@ -24,6 +24,7 @@ use tracing;

use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
@@ -151,10 +152,108 @@ impl OpenAIPreprocessor {
&self,
request: &R,
) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new();
let mut builder = self.builder(request)?;
let formatted_prompt = self.apply_template(request)?;
let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;

Ok((builder.build()?, annotations))
}

pub fn builder<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<PreprocessedRequestBuilder> {
let mut builder = PreprocessedRequest::builder();
builder.model(request.model());

let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}

// apply ignore eos if not already set
stop_conditions.apply_ignore_eos();

if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}

builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
}

Ok(builder)
}

pub fn apply_template<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<Option<String>> {
if let PromptInput::Text(_) = request.prompt_input_type()
&& let Some(TextInput::Single(_)) = request.extract_text()
{
let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
Ok(Some(formatted_prompt))
} else {
Ok(None)
}
}

pub fn gather_tokens<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
) -> Result<HashMap<String, String>> {
let mut annotations = HashMap::new();
// match request type before any conversion/processing
match request.prompt_input_type() {
PromptInput::Tokens(_) => {
@@ -177,22 +276,16 @@
PromptInput::Text(_) => {
if let Some(text_input) = request.extract_text() {
match text_input {
TextInput::Single(_) => {
let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
TextInput::Single(raw_prompt) => {
if let Some(f) = formatted_prompt.as_ref()
&& request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
{
annotations
.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
}

// Completions will use raw_prompt, no template
let prompt = formatted_prompt.unwrap_or(raw_prompt);

// Check if backend_instance_id is present and token_data is provided
let has_backend_instance_id = request
@@ -215,22 +308,15 @@
tracing::warn!(
"backend_instance_id provided but no token_data; tokenizing prompt"
);
let encoding = self.tokenizer.encode(&formatted_prompt)?;
let encoding = self.tokenizer.encode(&prompt)?;
(encoding.token_ids().to_vec(), false)
}
} else {
// No backend_instance_id provided, continue the normal flow.
let encoding = self.tokenizer.encode(&formatted_prompt)?;
let encoding = self.tokenizer.encode(&prompt)?;
(encoding.token_ids().to_vec(), false)
};

if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert(
ANNOTATION_FORMATTED_PROMPT.to_string(),
formatted_prompt,
);
}

if request.has_annotation(ANNOTATION_TOKEN_IDS)
&& !skip_token_annotation
{
@@ -258,37 +344,7 @@
}
}
}

let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}

// apply ignore eos if not already set
stop_conditions.apply_ignore_eos();

if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}

builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
}

Ok((builder.build()?, annotations))
Ok(annotations)
}

/// Preprocess an embedding request, handling both text and token ID inputs.
@@ -581,7 +637,9 @@ impl
let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator);
// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
let common_request = builder.build()?;

// update isl
response_generator.update_isl(common_request.token_ids.len() as u32);
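For orientation, a minimal sketch of how a caller might compose the three helpers that preprocess_request is split into by this diff. The builder, apply_template, and gather_tokens signatures are taken from the hunks above; the wrapper function, its name, and the anyhow::Result return type are assumptions for illustration, not part of the change.

// Sketch only: composing the split-out helpers, assuming the signatures shown above.
fn preprocess_sketch<R>(pre: &OpenAIPreprocessor, request: &R) -> anyhow::Result<()>
where
    R: OAIChatLikeRequest
        + AnnotationsProvider
        + SamplingOptionsProvider
        + StopConditionsProvider
        + OutputOptionsProvider
        + NvExtProvider,
{
    // Chat path: render the chat template, then tokenize the rendered prompt.
    let mut builder = pre.builder(request)?;
    let formatted_prompt = pre.apply_template(request)?;
    let _annotations = pre.gather_tokens(request, &mut builder, formatted_prompt)?;
    let _chat_request = builder.build()?;

    // Completions path: pass None so gather_tokens tokenizes the raw prompt with no template.
    let mut builder = pre.builder(request)?;
    let _annotations = pre.gather_tokens(request, &mut builder, None)?;
    let _completion_request = builder.build()?;
    Ok(())
}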
29 changes: 0 additions & 29 deletions lib/parsers/src/reasoning/base_parser.rs
@@ -1,6 +1,5 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use tracing as log;

use crate::{ParserResult, ReasoningParser};

@@ -34,13 +33,8 @@ impl BasicReasoningParser {

impl ReasoningParser for BasicReasoningParser {
fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);

let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
log::debug!("in_reasoning: {}", in_reasoning);

if !in_reasoning {
log::debug!("No reasoning detected, returning normal text.");
return ParserResult {
normal_text: text.to_string(),
reasoning_text: String::new(),
@@ -49,15 +43,8 @@ impl ReasoningParser for BasicReasoningParser {

// The text is considered to be in a reasoning block.
let processed_text = text.replace(&self.think_start_token, "").trim().to_string();
log::debug!(
"Processed text after removing think_start_token: {:?}",
processed_text
);

if !processed_text.contains(&self.think_end_token) {
log::debug!(
"Reasoning truncated, think_end_token not found. Returning reasoning text."
);
// Assume reasoning was truncated before `think_end_token`
return ParserResult {
normal_text: String::new(),
@@ -73,9 +60,6 @@ impl ReasoningParser for BasicReasoningParser {
.map(|s| s.trim().to_string())
.unwrap_or_default();

log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
log::debug!("Extracted normal_text: {:?}", normal_text);

ParserResult {
normal_text,
reasoning_text,
@@ -92,19 +76,6 @@ impl ReasoningParser for BasicReasoningParser {
let mut current_text = self._buffer.to_string();
// If the current text is a prefix of the think token, keep buffering

log::debug!(
"parse_reasoning_streaming_incremental called with text: {:?}",
text
);
log::debug!("current buffer: {:?}", self._buffer);
log::debug!("current_text: {:?}", current_text);
log::debug!(
"in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
self._in_reasoning,
self.stripped_think_start,
self.stream_reasoning
);

if self.think_start_token.starts_with(&current_text)
&& self.think_start_token.as_str() != current_text.as_str()
{
6 changes: 3 additions & 3 deletions lib/parsers/src/reasoning/mod.rs
@@ -144,7 +144,7 @@ impl ReasoningParserType {
}

pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
tracing::debug!("Selected reasoning parser: {}", name);
tracing::debug!(parser_name = name, "Selected reasoning parser");
match name.to_lowercase().as_str() {
"deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(),
"basic" => Self::Basic.get_reasoning_parser(),
@@ -156,8 +156,8 @@ impl ReasoningParserType {
"mistral" => Self::Mistral.get_reasoning_parser(),
_ => {
tracing::warn!(
"Unknown reasoning parser type '{}', falling back to Basic Reasoning Parser",
name
parser_name = name,
"Unknown reasoning parser type, falling back to Basic Reasoning Parser",
);
Self::Basic.get_reasoning_parser()
}