diff --git a/Cargo.toml b/Cargo.toml index b71b1d9..bced67f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ tracing-subscriber = { version = "0.3", features = [ tracing-futures = { version = "0.2", optional = true } rstructor_derive = { version = "0.2.8", path = "./rstructor_derive", optional = true } chrono = "0.4" # For date/time validation in examples -base64 = { version = "0.22", optional = true } +base64 = "0.22" # Feature flags [features] @@ -68,7 +68,7 @@ default = ["openai", "anthropic", "grok", "gemini", "derive", "logging"] openai = ["reqwest", "tokio"] anthropic = ["reqwest", "tokio"] grok = ["reqwest", "tokio"] -gemini = ["reqwest", "tokio", "base64"] +gemini = ["reqwest", "tokio"] derive = ["rstructor_derive"] logging = ["tracing-subscriber", "tracing-futures"] diff --git a/README.md b/README.md index f813d37..e717b67 100644 --- a/README.md +++ b/README.md @@ -201,29 +201,28 @@ struct Event { ## Multimodal (Image Input) -Analyze images with structured extraction using Gemini's inline data support: +Analyze images with structured extraction across all major providers using `materialize_with_media`: ```rust -use rstructor::{Instructor, LLMClient, GeminiClient, MediaFile}; +use rstructor::{Instructor, LLMClient, OpenAIClient, MediaFile}; #[derive(Instructor, Serialize, Deserialize, Debug)] struct ImageAnalysis { subject: String, - colors: Vec, - is_logo: bool, - description: String, + summary: String, } #[tokio::main] async fn main() -> Result<(), Box> { - // Download or load image bytes + // Download or load image bytes (real-world fixture) let image_bytes = reqwest::get("https://example.com/image.png") .await?.bytes().await?; - // Create inline media from bytes (base64-encoded automatically) + // Inline media is base64-encoded automatically let media = MediaFile::from_bytes(&image_bytes, "image/png"); - let client = GeminiClient::from_env()?; + // Works with OpenAI, Anthropic, Grok, and Gemini clients + let client = OpenAIClient::from_env()?; let analysis: ImageAnalysis = client .materialize_with_media("Describe this image", &[media]) .await?; @@ -232,7 +231,13 @@ async fn main() -> Result<(), Box> { } ``` -`MediaFile::new(uri, mime_type)` is also available for Gemini Files API / GCS URIs. +`MediaFile::new(uri, mime_type)` is also available for URL/URI-based media input. + +Provider examples: +- `cargo run --example openai_multimodal_example --features openai` +- `cargo run --example anthropic_multimodal_example --features anthropic` +- `cargo run --example grok_multimodal_example --features grok` +- `cargo run --example gemini_multimodal_example --features gemini` ## Extended Thinking diff --git a/examples/anthropic_multimodal_example.rs b/examples/anthropic_multimodal_example.rs new file mode 100644 index 0000000..5cce9d8 --- /dev/null +++ b/examples/anthropic_multimodal_example.rs @@ -0,0 +1,38 @@ +//! Anthropic Multimodal Structured Extraction Example +//! +//! Run with: +//! ```bash +//! export ANTHROPIC_API_KEY=your_key_here +//! cargo run --example anthropic_multimodal_example --features anthropic +//! ``` + +use rstructor::{AnthropicClient, AnthropicModel, Instructor, LLMClient, MediaFile}; +use serde::{Deserialize, Serialize}; +use std::env; + +#[derive(Instructor, Serialize, Deserialize, Debug)] +struct ImageAnalysis { + subject: String, + summary: String, + colors: Vec, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + env::var("ANTHROPIC_API_KEY").expect("Please set ANTHROPIC_API_KEY environment variable"); + + let image_url = "https://www.rust-lang.org/logos/rust-logo-512x512.png"; + let image_bytes = reqwest::get(image_url).await?.bytes().await?; + let media = MediaFile::from_bytes(&image_bytes, "image/png"); + + let client = AnthropicClient::from_env()? + .model(AnthropicModel::ClaudeOpus46) + .temperature(0.0); + + let analysis: ImageAnalysis = client + .materialize_with_media("Describe this image and list dominant colors.", &[media]) + .await?; + + println!("{:#?}", analysis); + Ok(()) +} diff --git a/examples/grok_multimodal_example.rs b/examples/grok_multimodal_example.rs new file mode 100644 index 0000000..244ed0b --- /dev/null +++ b/examples/grok_multimodal_example.rs @@ -0,0 +1,38 @@ +//! Grok Multimodal Structured Extraction Example +//! +//! Run with: +//! ```bash +//! export XAI_API_KEY=your_key_here +//! cargo run --example grok_multimodal_example --features grok +//! ``` + +use rstructor::{GrokClient, GrokModel, Instructor, LLMClient, MediaFile}; +use serde::{Deserialize, Serialize}; +use std::env; + +#[derive(Instructor, Serialize, Deserialize, Debug)] +struct ImageAnalysis { + subject: String, + summary: String, + colors: Vec, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + env::var("XAI_API_KEY").expect("Please set XAI_API_KEY environment variable"); + + let image_url = "https://www.rust-lang.org/logos/rust-logo-512x512.png"; + let image_bytes = reqwest::get(image_url).await?.bytes().await?; + let media = MediaFile::from_bytes(&image_bytes, "image/png"); + + let client = GrokClient::from_env()? + .model(GrokModel::Grok41FastNonReasoning) + .temperature(0.0); + + let analysis: ImageAnalysis = client + .materialize_with_media("Describe this image and list dominant colors.", &[media]) + .await?; + + println!("{:#?}", analysis); + Ok(()) +} diff --git a/examples/openai_multimodal_example.rs b/examples/openai_multimodal_example.rs new file mode 100644 index 0000000..2d7737b --- /dev/null +++ b/examples/openai_multimodal_example.rs @@ -0,0 +1,38 @@ +//! OpenAI Multimodal Structured Extraction Example +//! +//! Run with: +//! ```bash +//! export OPENAI_API_KEY=your_key_here +//! cargo run --example openai_multimodal_example --features openai +//! ``` + +use rstructor::{Instructor, LLMClient, MediaFile, OpenAIClient, OpenAIModel}; +use serde::{Deserialize, Serialize}; +use std::env; + +#[derive(Instructor, Serialize, Deserialize, Debug)] +struct ImageAnalysis { + subject: String, + summary: String, + colors: Vec, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + env::var("OPENAI_API_KEY").expect("Please set OPENAI_API_KEY environment variable"); + + let image_url = "https://www.rust-lang.org/logos/rust-logo-512x512.png"; + let image_bytes = reqwest::get(image_url).await?.bytes().await?; + let media = MediaFile::from_bytes(&image_bytes, "image/png"); + + let client = OpenAIClient::from_env()? + .model(OpenAIModel::Gpt52) + .temperature(0.0); + + let analysis: ImageAnalysis = client + .materialize_with_media("Describe this image and list dominant colors.", &[media]) + .await?; + + println!("{:#?}", analysis); + Ok(()) +} diff --git a/src/backend/anthropic.rs b/src/backend/anthropic.rs index aa66d81..c54cc2d 100644 --- a/src/backend/anthropic.rs +++ b/src/backend/anthropic.rs @@ -7,9 +7,10 @@ use std::time::Duration; use tracing::{debug, error, info, instrument, trace, warn}; use crate::backend::{ - ChatMessage, GenerateResult, LLMClient, MaterializeInternalOutput, MaterializeResult, - ModelInfo, ThinkingLevel, TokenUsage, ValidationFailureContext, check_response_status, - generate_with_retry_with_history, handle_http_error, parse_validate_and_create_output, + AnthropicMessageContent, ChatMessage, GenerateResult, LLMClient, MaterializeInternalOutput, + MaterializeResult, ModelInfo, ThinkingLevel, TokenUsage, ValidationFailureContext, + build_anthropic_message_content, check_response_status, generate_with_retry_with_history, + handle_http_error, materialize_with_media_with_retry, parse_validate_and_create_output, prepare_strict_schema, }; use crate::error::{ApiErrorKind, RStructorError, Result}; @@ -153,7 +154,7 @@ pub struct AnthropicClient { #[derive(Debug, Serialize)] struct AnthropicMessage { role: String, - content: String, + content: AnthropicMessageContent, } /// Output format for structured outputs (native Anthropic structured outputs) @@ -351,11 +352,14 @@ impl AnthropicClient { // With native structured outputs, we don't need to include schema instructions in the prompt let api_messages: Vec = messages .iter() - .map(|msg| AnthropicMessage { - role: msg.role.as_str().to_string(), - content: msg.content.clone(), + .map(|msg| { + Ok(AnthropicMessage { + role: msg.role.as_str().to_string(), + content: build_anthropic_message_content(msg)?, + }) }) - .collect(); + .collect::>>() + .map_err(|e| (e, None))?; // Build thinking config for Claude 4.x models let is_thinking_model = self.config.model.as_str().contains("sonnet-4") @@ -574,6 +578,32 @@ impl LLMClient for AnthropicClient { Ok(output.data) } + #[instrument( + name = "anthropic_materialize_with_media", + skip(self, prompt, media), + fields( + type_name = std::any::type_name::(), + model = %self.config.model.as_str(), + prompt_len = prompt.len(), + media_len = media.len() + ) + )] + async fn materialize_with_media(&self, prompt: &str, media: &[super::MediaFile]) -> Result + where + T: Instructor + DeserializeOwned + Send + 'static, + { + materialize_with_media_with_retry( + |messages: Vec| { + let this = self; + async move { this.materialize_internal::(&messages).await } + }, + prompt, + media, + self.config.max_retries, + ) + .await + } + #[instrument( name = "anthropic_materialize_with_metadata", skip(self, prompt), @@ -650,7 +680,7 @@ impl LLMClient for AnthropicClient { model: self.config.model.as_str().to_string(), messages: vec![AnthropicMessage { role: "user".to_string(), - content: prompt.to_string(), + content: AnthropicMessageContent::Text(prompt.to_string()), }], temperature: effective_temp, max_tokens: effective_max_tokens(self.config.max_tokens, thinking_config.as_ref()), diff --git a/src/backend/gemini.rs b/src/backend/gemini.rs index fd8bf5d..96e3276 100644 --- a/src/backend/gemini.rs +++ b/src/backend/gemini.rs @@ -9,7 +9,8 @@ use tracing::{debug, error, info, instrument, trace, warn}; use crate::backend::{ ChatMessage, GenerateResult, LLMClient, MaterializeInternalOutput, MaterializeResult, ModelInfo, ThinkingLevel, TokenUsage, ValidationFailureContext, check_response_status, - generate_with_retry_with_history, handle_http_error, parse_validate_and_create_output, + generate_with_retry_with_history, handle_http_error, materialize_with_media_with_retry, + parse_validate_and_create_output, }; use crate::error::{ApiErrorKind, RStructorError, Result}; use crate::model::Instructor; @@ -50,12 +51,16 @@ pub enum Model { Gemini25Flash, /// Gemini 2.5 Flash Lite (smaller, faster variant) Gemini25FlashLite, + /// Gemini 2.5 Flash Image (image generation/analysis tuned variant) + Gemini25FlashImage, /// Gemini 2.0 Flash (stable 2.0 Flash model) Gemini20Flash, /// Gemini 2.0 Flash 001 (specific version of 2.0 Flash) Gemini20Flash001, /// Gemini 2.0 Flash Lite (smaller 2.0 Flash variant) Gemini20FlashLite, + /// Gemini 2.0 Flash Lite 001 (specific version of 2.0 Flash Lite) + Gemini20FlashLite001, /// Gemini Pro Latest (alias for latest Pro model) GeminiProLatest, /// Gemini Flash Latest (alias for latest Flash model) @@ -74,9 +79,11 @@ impl Model { Model::Gemini25Pro => "gemini-2.5-pro", Model::Gemini25Flash => "gemini-2.5-flash", Model::Gemini25FlashLite => "gemini-2.5-flash-lite", + Model::Gemini25FlashImage => "gemini-2.5-flash-image", Model::Gemini20Flash => "gemini-2.0-flash", Model::Gemini20Flash001 => "gemini-2.0-flash-001", Model::Gemini20FlashLite => "gemini-2.0-flash-lite", + Model::Gemini20FlashLite001 => "gemini-2.0-flash-lite-001", Model::GeminiProLatest => "gemini-pro-latest", Model::GeminiFlashLatest => "gemini-flash-latest", Model::GeminiFlashLiteLatest => "gemini-flash-lite-latest", @@ -96,9 +103,11 @@ impl Model { "gemini-2.5-pro" => Model::Gemini25Pro, "gemini-2.5-flash" => Model::Gemini25Flash, "gemini-2.5-flash-lite" => Model::Gemini25FlashLite, + "gemini-2.5-flash-image" => Model::Gemini25FlashImage, "gemini-2.0-flash" => Model::Gemini20Flash, "gemini-2.0-flash-001" => Model::Gemini20Flash001, "gemini-2.0-flash-lite" => Model::Gemini20FlashLite, + "gemini-2.0-flash-lite-001" => Model::Gemini20FlashLite001, "gemini-pro-latest" => Model::GeminiProLatest, "gemini-flash-latest" => Model::GeminiFlashLatest, "gemini-flash-lite-latest" => Model::GeminiFlashLiteLatest, @@ -652,14 +661,16 @@ impl LLMClient for GeminiClient { where T: Instructor + DeserializeOwned + Send + 'static, { - // For media support, we need to create a ChatMessage with media and pass it directly - // We can't use generate_with_retry_with_history since it only takes a string prompt - let initial_message = ChatMessage::user_with_media(prompt, media.to_vec()); - let output = self - .materialize_internal::(&[initial_message]) - .await - .map_err(|(err, _)| err)?; - Ok(output.data) + materialize_with_media_with_retry( + |messages: Vec| { + let this = self; + async move { this.materialize_internal::(&messages).await } + }, + prompt, + media, + self.config.max_retries, + ) + .await } #[instrument( diff --git a/src/backend/grok.rs b/src/backend/grok.rs index 21bc36c..dbd355e 100644 --- a/src/backend/grok.rs +++ b/src/backend/grok.rs @@ -1,15 +1,16 @@ use async_trait::async_trait; use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; use std::str::FromStr; use std::time::Duration; use tracing::{debug, error, info, instrument, trace, warn}; use crate::backend::{ ChatMessage, GenerateResult, LLMClient, MaterializeInternalOutput, MaterializeResult, - ModelInfo, ResponseFormat, TokenUsage, ValidationFailureContext, check_response_status, - generate_with_retry_with_history, handle_http_error, parse_validate_and_create_output, - prepare_strict_schema, + ModelInfo, OpenAICompatibleChatCompletionRequest, OpenAICompatibleChatCompletionResponse, + OpenAICompatibleChatMessage, OpenAICompatibleMessageContent, ResponseFormat, TokenUsage, + ValidationFailureContext, check_response_status, convert_openai_compatible_chat_messages, + generate_with_retry_with_history, handle_http_error, materialize_with_media_with_retry, + parse_validate_and_create_output, prepare_strict_schema, }; use crate::error::{ApiErrorKind, RStructorError, Result}; use crate::model::Instructor; @@ -138,56 +139,7 @@ pub struct GrokClient { client: reqwest::Client, } -// Grok API request and response structures (OpenAI-compatible) -#[derive(Debug, Serialize)] -struct GrokChatMessage { - role: String, - content: String, -} - -// ResponseFormat and JsonSchemaFormat are now imported from utils - -#[derive(Debug, Serialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - response_format: Option, - temperature: f32, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct ResponseMessage { - role: String, - content: Option, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct ChatCompletionChoice { - message: ResponseMessage, - finish_reason: String, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct UsageInfo { - prompt_tokens: u64, - completion_tokens: u64, - #[serde(default)] - total_tokens: u64, -} - -#[derive(Debug, Deserialize)] -struct ChatCompletionResponse { - choices: Vec, - #[serde(default)] - usage: Option, - model: Option, -} +// Grok uses shared OpenAI-compatible chat completion request/response types. impl GrokClient { /// Create a new Grok client with the provided API key. @@ -310,13 +262,8 @@ impl GrokClient { // Build API messages from conversation history // With native structured outputs, we don't need to include schema instructions in the prompt - let api_messages: Vec = messages - .iter() - .map(|msg| GrokChatMessage { - role: msg.role.as_str().to_string(), - content: msg.content.clone(), - }) - .collect(); + let api_messages = + convert_openai_compatible_chat_messages(messages, "Grok").map_err(|e| (e, None))?; // Create response format for native structured outputs let response_format = ResponseFormat::json_schema(schema_name.clone(), schema_json, None); @@ -325,12 +272,13 @@ impl GrokClient { "Building Grok API request with structured outputs (history_len={})", api_messages.len() ); - let request = ChatCompletionRequest { + let request = OpenAICompatibleChatCompletionRequest { model: self.config.model.as_str().to_string(), messages: api_messages, response_format: Some(response_format), temperature: self.config.temperature, max_tokens: self.config.max_tokens, + reasoning_effort: None, }; let base_url = self @@ -355,10 +303,11 @@ impl GrokClient { .map_err(|e| (e, None))?; debug!("Successfully received response from Grok API"); - let completion: ChatCompletionResponse = response.json().await.map_err(|e| { - error!(error = %e, "Failed to parse JSON response from Grok API"); - (RStructorError::from(e), None) - })?; + let completion: OpenAICompatibleChatCompletionResponse = + response.json().await.map_err(|e| { + error!(error = %e, "Failed to parse JSON response from Grok API"); + (RStructorError::from(e), None) + })?; if completion.choices.is_empty() { error!("Grok API returned empty choices array"); @@ -470,6 +419,32 @@ impl LLMClient for GrokClient { Ok(output.data) } + #[instrument( + name = "grok_materialize_with_media", + skip(self, prompt, media), + fields( + type_name = std::any::type_name::(), + model = %self.config.model.as_str(), + prompt_len = prompt.len(), + media_len = media.len() + ) + )] + async fn materialize_with_media(&self, prompt: &str, media: &[super::MediaFile]) -> Result + where + T: Instructor + DeserializeOwned + Send + 'static, + { + materialize_with_media_with_retry( + |messages: Vec| { + let this = self; + async move { this.materialize_internal::(&messages).await } + }, + prompt, + media, + self.config.max_retries, + ) + .await + } + #[instrument( name = "grok_materialize_with_metadata", skip(self, prompt), @@ -521,15 +496,16 @@ impl LLMClient for GrokClient { // Build the request without structured outputs debug!("Building Grok API request for text generation"); - let request = ChatCompletionRequest { + let request = OpenAICompatibleChatCompletionRequest { model: self.config.model.as_str().to_string(), - messages: vec![GrokChatMessage { + messages: vec![OpenAICompatibleChatMessage { role: "user".to_string(), - content: prompt.to_string(), + content: OpenAICompatibleMessageContent::Text(prompt.to_string()), }], response_format: None, temperature: self.config.temperature, max_tokens: self.config.max_tokens, + reasoning_effort: None, }; // Send the request to Grok/xAI API @@ -554,10 +530,11 @@ impl LLMClient for GrokClient { let response = check_response_status(response, "Grok").await?; debug!("Successfully received response from Grok API"); - let completion: ChatCompletionResponse = response.json().await.map_err(|e| { - error!(error = %e, "Failed to parse JSON response from Grok API"); - e - })?; + let completion: OpenAICompatibleChatCompletionResponse = + response.json().await.map_err(|e| { + error!(error = %e, "Failed to parse JSON response from Grok API"); + e + })?; if completion.choices.is_empty() { error!("Grok API returned empty choices array"); diff --git a/src/backend/media.rs b/src/backend/media.rs new file mode 100644 index 0000000..dba0ad0 --- /dev/null +++ b/src/backend/media.rs @@ -0,0 +1,214 @@ +use serde::Serialize; + +use crate::backend::ChatMessage; +use crate::error::{ApiErrorKind, RStructorError, Result}; + +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub(crate) enum OpenAICompatibleMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(crate) enum OpenAICompatibleMessagePart { + Text { text: String }, + ImageUrl { image_url: OpenAICompatibleImageUrl }, +} + +#[derive(Debug, Serialize)] +pub(crate) struct OpenAICompatibleImageUrl { + pub(crate) url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) detail: Option, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub(crate) enum AnthropicMessageContent { + Text(String), + Blocks(Vec), +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(crate) enum AnthropicContentBlock { + Text { text: String }, + Image { source: AnthropicImageSource }, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(crate) enum AnthropicImageSource { + Base64 { media_type: String, data: String }, + Url { url: String }, +} + +pub(crate) fn build_openai_compatible_message_content( + msg: &ChatMessage, + provider_name: &str, +) -> Result { + if msg.media.is_empty() { + return Ok(OpenAICompatibleMessageContent::Text(msg.content.clone())); + } + + let mut parts = Vec::new(); + if !msg.content.is_empty() { + parts.push(OpenAICompatibleMessagePart::Text { + text: msg.content.clone(), + }); + } + + for media in &msg.media { + let url = media_to_url(media, provider_name)?; + parts.push(OpenAICompatibleMessagePart::ImageUrl { + image_url: OpenAICompatibleImageUrl { + url, + detail: Some("auto".to_string()), + }, + }); + } + + Ok(OpenAICompatibleMessageContent::Parts(parts)) +} + +pub(crate) fn build_anthropic_message_content( + msg: &ChatMessage, +) -> Result { + if msg.media.is_empty() { + return Ok(AnthropicMessageContent::Text(msg.content.clone())); + } + + let mut blocks = Vec::new(); + if !msg.content.is_empty() { + blocks.push(AnthropicContentBlock::Text { + text: msg.content.clone(), + }); + } + + for media in &msg.media { + if let Some(data) = media.data.as_ref() { + if data.is_empty() { + return Err(RStructorError::api_error( + "Anthropic", + ApiErrorKind::BadRequest { + details: "MediaFile inline data cannot be empty".to_string(), + }, + )); + } + if media.mime_type.is_empty() { + return Err(RStructorError::api_error( + "Anthropic", + ApiErrorKind::BadRequest { + details: "MediaFile mime_type cannot be empty".to_string(), + }, + )); + } + blocks.push(AnthropicContentBlock::Image { + source: AnthropicImageSource::Base64 { + media_type: media.mime_type.clone(), + data: data.clone(), + }, + }); + } else if !media.uri.is_empty() { + blocks.push(AnthropicContentBlock::Image { + source: AnthropicImageSource::Url { + url: media.uri.clone(), + }, + }); + } else { + return Err(RStructorError::api_error( + "Anthropic", + ApiErrorKind::BadRequest { + details: "MediaFile must include either inline data or uri".to_string(), + }, + )); + } + } + + Ok(AnthropicMessageContent::Blocks(blocks)) +} + +fn media_to_url(media: &crate::backend::client::MediaFile, provider_name: &str) -> Result { + if let Some(data) = media.data.as_ref() { + if data.is_empty() { + return Err(RStructorError::api_error( + provider_name, + ApiErrorKind::BadRequest { + details: "MediaFile inline data cannot be empty".to_string(), + }, + )); + } + if media.mime_type.is_empty() { + return Err(RStructorError::api_error( + provider_name, + ApiErrorKind::BadRequest { + details: "MediaFile mime_type cannot be empty".to_string(), + }, + )); + } + Ok(format!("data:{};base64,{}", media.mime_type, data)) + } else if !media.uri.is_empty() { + Ok(media.uri.clone()) + } else { + Err(RStructorError::api_error( + provider_name, + ApiErrorKind::BadRequest { + details: "MediaFile must include either inline data or uri".to_string(), + }, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::MediaFile; + + #[test] + fn test_openai_compatible_content_text_only() { + let msg = ChatMessage::user("hello"); + let content = + build_openai_compatible_message_content(&msg, "OpenAI").expect("content should build"); + let json = serde_json::to_value(&content).expect("content should serialize"); + assert_eq!(json, serde_json::json!("hello")); + } + + #[test] + fn test_openai_compatible_content_with_media() { + let msg = ChatMessage::user_with_media( + "describe image", + vec![MediaFile::from_bytes(b"abc", "image/png")], + ); + let content = + build_openai_compatible_message_content(&msg, "OpenAI").expect("content should build"); + let json = serde_json::to_value(&content).expect("content should serialize"); + assert_eq!(json[0]["type"], "text"); + assert_eq!(json[1]["type"], "image_url"); + assert_eq!(json[1]["image_url"]["url"], "data:image/png;base64,YWJj"); + } + + #[test] + fn test_anthropic_content_text_only() { + let msg = ChatMessage::user("hello"); + let content = build_anthropic_message_content(&msg).expect("content should build"); + let json = serde_json::to_value(&content).expect("content should serialize"); + assert_eq!(json, serde_json::json!("hello")); + } + + #[test] + fn test_anthropic_content_with_inline_media() { + let msg = ChatMessage::user_with_media( + "describe image", + vec![MediaFile::from_bytes(b"abc", "image/png")], + ); + let content = build_anthropic_message_content(&msg).expect("content should build"); + let json = serde_json::to_value(&content).expect("content should serialize"); + assert_eq!(json[0]["type"], "text"); + assert_eq!(json[1]["type"], "image"); + assert_eq!(json[1]["source"]["type"], "base64"); + assert_eq!(json[1]["source"]["media_type"], "image/png"); + assert_eq!(json[1]["source"]["data"], "YWJj"); + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index f906d96..79cfaf0 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,5 +1,7 @@ pub mod client; +mod media; mod messages; +mod openai_compatible; pub mod usage; mod utils; @@ -44,9 +46,17 @@ pub struct ModelInfo { /// Description of the model's capabilities pub description: Option, } +pub(crate) use media::{ + AnthropicMessageContent, OpenAICompatibleMessageContent, build_anthropic_message_content, + build_openai_compatible_message_content, +}; +pub(crate) use openai_compatible::{ + OpenAICompatibleChatCompletionRequest, OpenAICompatibleChatCompletionResponse, + OpenAICompatibleChatMessage, convert_openai_compatible_chat_messages, +}; pub(crate) use utils::{ ResponseFormat, check_response_status, generate_with_retry_with_history, handle_http_error, - parse_validate_and_create_output, prepare_strict_schema, + materialize_with_media_with_retry, parse_validate_and_create_output, prepare_strict_schema, }; /// Thinking level configuration for models that support extended reasoning. diff --git a/src/backend/openai.rs b/src/backend/openai.rs index adfbdc4..35e62da 100644 --- a/src/backend/openai.rs +++ b/src/backend/openai.rs @@ -1,15 +1,16 @@ use async_trait::async_trait; use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; use std::str::FromStr; use std::time::Duration; use tracing::{debug, error, info, instrument, trace, warn}; use crate::backend::{ ChatMessage, GenerateResult, LLMClient, MaterializeInternalOutput, MaterializeResult, - ModelInfo, ResponseFormat, ThinkingLevel, TokenUsage, ValidationFailureContext, - check_response_status, generate_with_retry_with_history, handle_http_error, - parse_validate_and_create_output, prepare_strict_schema, + ModelInfo, OpenAICompatibleChatCompletionRequest, OpenAICompatibleChatCompletionResponse, + OpenAICompatibleChatMessage, OpenAICompatibleMessageContent, ResponseFormat, ThinkingLevel, + TokenUsage, ValidationFailureContext, check_response_status, + convert_openai_compatible_chat_messages, generate_with_retry_with_history, handle_http_error, + materialize_with_media_with_retry, parse_validate_and_create_output, prepare_strict_schema, }; use crate::error::{ApiErrorKind, RStructorError, Result}; use crate::model::Instructor; @@ -42,6 +43,10 @@ pub enum Model { Gpt52Pro, /// GPT-5.2 (latest GPT-5 model) Gpt52, + /// GPT-5.2 Chat Latest (rolling latest chat-optimized GPT-5.2 model) + Gpt52ChatLatest, + /// GPT-5.2 Codex (latest coding-focused GPT-5.2 model) + Gpt52Codex, /// GPT-5.1 (GPT-5.1 model) Gpt51, /// GPT-5 Chat Latest (latest GPT-5 model for chat) @@ -91,6 +96,8 @@ impl Model { match self { Model::Gpt52Pro => "gpt-5.2-pro", Model::Gpt52 => "gpt-5.2", + Model::Gpt52ChatLatest => "gpt-5.2-chat-latest", + Model::Gpt52Codex => "gpt-5.2-codex", Model::Gpt51 => "gpt-5.1", Model::Gpt5ChatLatest => "gpt-5-chat-latest", Model::Gpt5Pro => "gpt-5-pro", @@ -124,6 +131,8 @@ impl Model { match name.as_str() { "gpt-5.2-pro" => Model::Gpt52Pro, "gpt-5.2" => Model::Gpt52, + "gpt-5.2-chat-latest" => Model::Gpt52ChatLatest, + "gpt-5.2-codex" => Model::Gpt52Codex, "gpt-5.1" => Model::Gpt51, "gpt-5-chat-latest" => Model::Gpt5ChatLatest, "gpt-5-pro" => Model::Gpt5Pro, @@ -192,59 +201,8 @@ pub struct OpenAIClient { client: reqwest::Client, } -// OpenAI API request and response structures -#[derive(Debug, Serialize)] -struct OpenAIChatMessage { - role: String, - content: String, -} - -// ResponseFormat and JsonSchemaFormat are now imported from utils - -#[derive(Debug, Serialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - response_format: Option, - temperature: f32, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - /// Reasoning effort for GPT-5.x models - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_effort: Option, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct ResponseMessage { - role: String, - content: Option, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct ChatCompletionChoice { - message: ResponseMessage, - finish_reason: String, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct UsageInfo { - prompt_tokens: u64, - completion_tokens: u64, - #[serde(default)] - total_tokens: u64, -} - -#[derive(Debug, Deserialize)] -struct ChatCompletionResponse { - choices: Vec, - #[serde(default)] - usage: Option, - model: Option, -} +// ResponseFormat and JsonSchemaFormat are imported from utils and shared +// OpenAI-compatible chat completion request/response types are in openai_compatible.rs. impl OpenAIClient { /// Create a new OpenAI client with the provided API key. @@ -468,20 +426,15 @@ impl OpenAIClient { }; // Convert ChatMessage to OpenAI's format - let api_messages: Vec = messages - .iter() - .map(|msg| OpenAIChatMessage { - role: msg.role.as_str().to_string(), - content: msg.content.clone(), - }) - .collect(); + let api_messages = + convert_openai_compatible_chat_messages(messages, "OpenAI").map_err(|e| (e, None))?; // Build the request with native structured outputs debug!( "Building OpenAI API request with structured outputs (history_len={})", api_messages.len() ); - let request = ChatCompletionRequest { + let request = OpenAICompatibleChatCompletionRequest { model: self.config.model.as_str().to_string(), messages: api_messages, response_format: Some(response_format), @@ -514,10 +467,11 @@ impl OpenAIClient { .map_err(|e| (e, None))?; debug!("Successfully received response from OpenAI"); - let completion: ChatCompletionResponse = response.json().await.map_err(|e| { - error!(error = %e, "Failed to parse JSON response from OpenAI"); - (RStructorError::from(e), None) - })?; + let completion: OpenAICompatibleChatCompletionResponse = + response.json().await.map_err(|e| { + error!(error = %e, "Failed to parse JSON response from OpenAI"); + (RStructorError::from(e), None) + })?; if completion.choices.is_empty() { error!("OpenAI returned empty choices array"); @@ -601,6 +555,32 @@ impl LLMClient for OpenAIClient { Ok(output.data) } + #[instrument( + name = "openai_materialize_with_media", + skip(self, prompt, media), + fields( + type_name = std::any::type_name::(), + model = %self.config.model.as_str(), + prompt_len = prompt.len(), + media_len = media.len() + ) + )] + async fn materialize_with_media(&self, prompt: &str, media: &[super::MediaFile]) -> Result + where + T: Instructor + DeserializeOwned + Send + 'static, + { + materialize_with_media_with_retry( + |messages: Vec| { + let this = self; + async move { this.materialize_internal::(&messages).await } + }, + prompt, + media, + self.config.max_retries, + ) + .await + } + #[instrument( name = "openai_materialize_with_metadata", skip(self, prompt), @@ -669,11 +649,11 @@ impl LLMClient for OpenAIClient { // Build the request for text generation (no structured output) debug!("Building OpenAI API request for text generation"); - let request = ChatCompletionRequest { + let request = OpenAICompatibleChatCompletionRequest { model: self.config.model.as_str().to_string(), - messages: vec![OpenAIChatMessage { + messages: vec![OpenAICompatibleChatMessage { role: "user".to_string(), - content: prompt.to_string(), + content: OpenAICompatibleMessageContent::Text(prompt.to_string()), }], response_format: None, temperature: effective_temp, @@ -703,10 +683,11 @@ impl LLMClient for OpenAIClient { let response = check_response_status(response, "OpenAI").await?; debug!("Successfully received response from OpenAI"); - let completion: ChatCompletionResponse = response.json().await.map_err(|e| { - error!(error = %e, "Failed to parse JSON response from OpenAI"); - e - })?; + let completion: OpenAICompatibleChatCompletionResponse = + response.json().await.map_err(|e| { + error!(error = %e, "Failed to parse JSON response from OpenAI"); + e + })?; if completion.choices.is_empty() { error!("OpenAI returned empty choices array"); diff --git a/src/backend/openai_compatible.rs b/src/backend/openai_compatible.rs new file mode 100644 index 0000000..96c56ae --- /dev/null +++ b/src/backend/openai_compatible.rs @@ -0,0 +1,107 @@ +use serde::{Deserialize, Serialize}; + +use crate::backend::{ + ChatMessage, OpenAICompatibleMessageContent, ResponseFormat, + build_openai_compatible_message_content, +}; +use crate::error::Result; + +#[derive(Debug, Serialize)] +pub(crate) struct OpenAICompatibleChatMessage { + pub role: String, + pub content: OpenAICompatibleMessageContent, +} + +pub(crate) fn convert_openai_compatible_chat_messages( + messages: &[ChatMessage], + provider_name: &str, +) -> Result> { + messages + .iter() + .map(|msg| { + Ok(OpenAICompatibleChatMessage { + role: msg.role.as_str().to_string(), + content: build_openai_compatible_message_content(msg, provider_name)?, + }) + }) + .collect() +} + +#[derive(Debug, Serialize)] +pub(crate) struct OpenAICompatibleChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + pub temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + /// Reasoning effort for GPT-5.x models (OpenAI only). + /// Omitted for providers that don't support it. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +pub(crate) struct OpenAICompatibleResponseMessage { + pub role: String, + pub content: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +pub(crate) struct OpenAICompatibleChatCompletionChoice { + pub message: OpenAICompatibleResponseMessage, + pub finish_reason: String, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +pub(crate) struct OpenAICompatibleUsageInfo { + pub prompt_tokens: u64, + pub completion_tokens: u64, + #[serde(default)] + pub total_tokens: u64, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct OpenAICompatibleChatCompletionResponse { + pub choices: Vec, + #[serde(default)] + pub usage: Option, + pub model: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::MediaFile; + + #[test] + fn test_convert_openai_compatible_chat_messages_text_only() { + let messages = vec![ChatMessage::user("hello")]; + let converted = convert_openai_compatible_chat_messages(&messages, "OpenAI") + .expect("conversion should succeed"); + + assert_eq!(converted.len(), 1); + assert_eq!(converted[0].role, "user"); + let json = serde_json::to_value(&converted[0]).expect("serialization should succeed"); + assert_eq!(json["content"], serde_json::json!("hello")); + } + + #[test] + fn test_convert_openai_compatible_chat_messages_with_media() { + let messages = vec![ChatMessage::user_with_media( + "describe image", + vec![MediaFile::from_bytes(b"abc", "image/png")], + )]; + let converted = convert_openai_compatible_chat_messages(&messages, "OpenAI") + .expect("conversion should succeed"); + + assert_eq!(converted.len(), 1); + let json = serde_json::to_value(&converted[0]).expect("serialization should succeed"); + assert_eq!(json["content"][0]["type"], "text"); + assert_eq!(json["content"][1]["type"], "image_url"); + } +} diff --git a/src/backend/utils.rs b/src/backend/utils.rs index 96cbab9..88d7ab9 100644 --- a/src/backend/utils.rs +++ b/src/backend/utils.rs @@ -1137,10 +1137,37 @@ pub async fn check_response_status(response: Response, provider_name: &str) -> R /// * `prompt` - The initial user prompt /// * `max_retries` - Maximum number of retry attempts (None or 0 means no retries) pub async fn generate_with_retry_with_history( - mut generate_fn: F, + generate_fn: F, prompt: &str, max_retries: Option, ) -> Result> +where + F: FnMut(Vec) -> Fut, + Fut: std::future::Future< + Output = std::result::Result< + MaterializeInternalOutput, + (RStructorError, Option), + >, + >, +{ + generate_with_retry_with_initial_messages( + generate_fn, + vec![ChatMessage::user(prompt)], + max_retries, + ) + .await +} + +/// Helper function to execute generation with retry logic using a custom initial +/// conversation history. +/// +/// This is primarily used for multimodal prompts where the initial user message +/// may contain attached media in addition to text. +pub async fn generate_with_retry_with_initial_messages( + mut generate_fn: F, + initial_messages: Vec, + max_retries: Option, +) -> Result> where F: FnMut(Vec) -> Fut, Fut: std::future::Future< @@ -1151,15 +1178,14 @@ where >, { let Some(max_retries) = max_retries.filter(|&n| n > 0) else { - // No retries configured - just run once with a single user message - let messages = vec![ChatMessage::user(prompt)]; - return generate_fn(messages).await.map_err(|(err, _)| err); + // No retries configured - just run once with the provided initial messages + return generate_fn(initial_messages).await.map_err(|(err, _)| err); }; let max_attempts = max_retries + 1; // +1 for initial attempt - // Initialize conversation history with the original user prompt - let mut messages = vec![ChatMessage::user(prompt)]; + // Initialize conversation history with the provided starting messages. + let mut messages = initial_messages; trace!( "Starting structured generation with conversation history: max_attempts={}", @@ -1278,6 +1304,31 @@ where unreachable!() } +/// Helper for provider implementations of `materialize_with_media`. +/// +/// Builds an initial media-bearing user message and runs the shared retry/history flow. +pub async fn materialize_with_media_with_retry( + generate_fn: F, + prompt: &str, + media: &[crate::backend::client::MediaFile], + max_retries: Option, +) -> Result +where + F: FnMut(Vec) -> Fut, + Fut: std::future::Future< + Output = std::result::Result< + MaterializeInternalOutput, + (RStructorError, Option), + >, + >, +{ + let initial_messages = vec![ChatMessage::user_with_media(prompt, media.to_vec())]; + let output = + generate_with_retry_with_initial_messages(generate_fn, initial_messages, max_retries) + .await?; + Ok(output.data) +} + /// Macro to generate standard builder methods for LLM clients. /// /// This macro generates `model()`, `temperature()`, `max_tokens()`, and `timeout()` methods @@ -2071,4 +2122,79 @@ mod tests { "x-enum-keys should be stripped from non-map schemas" ); } + + #[tokio::test] + async fn test_generate_with_retry_with_initial_messages_preserves_media() { + let initial = vec![ChatMessage::user_with_media( + "describe image", + vec![crate::backend::client::MediaFile::from_bytes( + b"hello-image", + "image/png", + )], + )]; + + let output = generate_with_retry_with_initial_messages( + |messages: Vec| async move { + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].media.len(), 1); + Ok(MaterializeInternalOutput::new( + "ok".to_string(), + "{\"ok\":true}".to_string(), + None, + )) + }, + initial, + Some(0), + ) + .await + .expect("generation should succeed"); + + assert_eq!(output.data, "ok"); + } + + #[tokio::test] + async fn test_generate_with_retry_with_initial_messages_adds_feedback_history() { + let initial = vec![ChatMessage::user_with_media( + "describe image", + vec![crate::backend::client::MediaFile::from_bytes( + b"hello-image", + "image/png", + )], + )]; + + let mut attempts = 0usize; + let output = generate_with_retry_with_initial_messages( + |messages: Vec| { + attempts += 1; + async move { + if attempts == 1 { + Err(( + RStructorError::ValidationError("schema validation failed".to_string()), + Some(ValidationFailureContext::new( + "missing required field: summary", + "{\"subject\":\"rust\"}", + )), + )) + } else { + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].media.len(), 1); + assert_eq!(messages[1].role, crate::backend::ChatRole::Assistant); + assert_eq!(messages[2].role, crate::backend::ChatRole::User); + Ok(MaterializeInternalOutput::new( + "ok".to_string(), + "{\"ok\":true}".to_string(), + None, + )) + } + } + }, + initial, + Some(1), + ) + .await + .expect("generation should succeed after retry"); + + assert_eq!(attempts, 2); + assert_eq!(output.data, "ok"); + } } diff --git a/tests/anthropic_multimodal_tests.rs b/tests/anthropic_multimodal_tests.rs new file mode 100644 index 0000000..e9f30c0 --- /dev/null +++ b/tests/anthropic_multimodal_tests.rs @@ -0,0 +1,98 @@ +//! Integration tests for Anthropic multimodal structured extraction. +//! +//! Requires: +//! - `ANTHROPIC_API_KEY` +//! - `--features anthropic` + +#[path = "common/mod.rs"] +mod common; + +#[cfg(test)] +mod anthropic_multimodal_tests { + #[cfg(feature = "anthropic")] + use rstructor::AnthropicClient; + use rstructor::{AnthropicModel, Instructor, LLMClient}; + use serde::{Deserialize, Serialize}; + + use crate::common::{RUST_LOGO_MIME, RUST_LOGO_URL, download_media, media_url}; + + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct ImageSummary { + subject: String, + summary: String, + } + + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct MultiImageSummary { + image_count: u8, + summary: String, + } + + #[cfg(feature = "anthropic")] + #[tokio::test] + async fn test_anthropic_multimodal_inline_image() { + let media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let client = AnthropicClient::from_env() + .expect("ANTHROPIC_API_KEY must be set for this test") + .model(AnthropicModel::ClaudeOpus46) + .temperature(0.0); + + let result: ImageSummary = client + .materialize_with_media( + "Identify the main subject in this image and summarize it briefly.", + &[media], + ) + .await + .expect("Anthropic multimodal inline request failed"); + + assert!(!result.subject.is_empty(), "subject should not be empty"); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } + + #[cfg(feature = "anthropic")] + #[tokio::test] + async fn test_anthropic_multimodal_url_image() { + let media = media_url(RUST_LOGO_URL, RUST_LOGO_MIME); + let client = AnthropicClient::from_env() + .expect("ANTHROPIC_API_KEY must be set for this test") + .model(AnthropicModel::ClaudeOpus46) + .temperature(0.0); + + let result: ImageSummary = client + .materialize_with_media( + "Describe this image in one concise sentence, focusing on scene type.", + &[media], + ) + .await + .expect("Anthropic multimodal URL request failed"); + + assert!(!result.subject.is_empty(), "subject should not be empty"); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } + + #[cfg(feature = "anthropic")] + #[tokio::test] + async fn test_anthropic_multimodal_multiple_images() { + let rust_media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let lake_media = media_url(RUST_LOGO_URL, RUST_LOGO_MIME); + let client = AnthropicClient::from_env() + .expect("ANTHROPIC_API_KEY must be set for this test") + .model(AnthropicModel::ClaudeOpus46) + .temperature(0.0); + + let result: MultiImageSummary = client + .materialize_with_media( + "You are given two images. Return the exact count in image_count and summarize both images.", + &[rust_media, lake_media], + ) + .await + .expect("Anthropic multimodal multi-image request failed"); + + assert!( + result.image_count >= 2, + "expected at least 2 images, got {}", + result.image_count + ); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..4ec088e --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,24 @@ +use rstructor::MediaFile; + +pub const RUST_LOGO_URL: &str = "https://www.rust-lang.org/logos/rust-logo-512x512.png"; +pub const RUST_LOGO_MIME: &str = "image/png"; + +#[allow(dead_code)] +pub const RUST_SOCIAL_URL: &str = "https://www.rust-lang.org/static/images/rust-social-wide.jpg"; +#[allow(dead_code)] +pub const RUST_SOCIAL_MIME: &str = "image/jpeg"; + +pub async fn download_media(url: &str, mime: &str) -> MediaFile { + let bytes = reqwest::get(url) + .await + .expect("Failed to download media fixture") + .bytes() + .await + .expect("Failed to read media fixture bytes"); + MediaFile::from_bytes(&bytes, mime) +} + +#[allow(dead_code)] +pub fn media_url(url: &str, mime: &str) -> MediaFile { + MediaFile::new(url, mime) +} diff --git a/tests/gemini_multimodal_tests.rs b/tests/gemini_multimodal_tests.rs index 37b25a6..3991397 100644 --- a/tests/gemini_multimodal_tests.rs +++ b/tests/gemini_multimodal_tests.rs @@ -7,13 +7,20 @@ //! cargo test --test gemini_multimodal_tests --features gemini //! ``` +#[path = "common/mod.rs"] +mod common; + #[cfg(test)] mod gemini_multimodal_tests { #[cfg(feature = "gemini")] - use rstructor::{GeminiClient, GeminiModel}; - use rstructor::{Instructor, LLMClient, MediaFile}; + use rstructor::GeminiClient; + use rstructor::{GeminiModel, Instructor, LLMClient}; use serde::{Deserialize, Serialize}; + use crate::common::{ + RUST_LOGO_MIME, RUST_LOGO_URL, RUST_SOCIAL_MIME, RUST_SOCIAL_URL, download_media, + }; + #[derive(Instructor, Serialize, Deserialize, Debug)] #[llm(description = "Description of an image")] struct ImageDescription { @@ -27,19 +34,16 @@ mod gemini_multimodal_tests { description: String, } + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct MultiImageSummary { + image_count: u8, + summary: String, + } + #[cfg(feature = "gemini")] #[tokio::test] async fn test_gemini_multimodal_image_analysis() { - // Download a small, stable public image (Rust logo) - let image_url = "https://www.rust-lang.org/logos/rust-logo-512x512.png"; - let image_bytes = reqwest::get(image_url) - .await - .expect("Failed to download test image") - .bytes() - .await - .expect("Failed to read image bytes"); - - let media = MediaFile::from_bytes(&image_bytes, "image/png"); + let media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; let client = GeminiClient::from_env() .expect("GEMINI_API_KEY must be set for this test") @@ -58,22 +62,58 @@ mod gemini_multimodal_tests { !result.description.is_empty(), "Description should not be empty" ); + } + + #[cfg(feature = "gemini")] + #[tokio::test] + async fn test_gemini_multimodal_second_real_world_image() { + let media = download_media(RUST_SOCIAL_URL, RUST_SOCIAL_MIME).await; + + let client = GeminiClient::from_env() + .expect("GEMINI_API_KEY must be set for this test") + .model(GeminiModel::Gemini3FlashPreview) + .temperature(0.0); + + let result: ImageDescription = client + .materialize_with_media( + "Describe the environment in this image and list the dominant colors.", + &[media], + ) + .await + .expect("Failed to materialize secondary real-world image description"); + + assert!(!result.subject.is_empty(), "Subject should not be empty"); + assert!(!result.colors.is_empty(), "Colors should not be empty"); + assert!( + !result.description.is_empty(), + "Description should not be empty" + ); + } + + #[cfg(feature = "gemini")] + #[tokio::test] + async fn test_gemini_multimodal_multiple_images() { + let rust_media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let social_media = download_media(RUST_SOCIAL_URL, RUST_SOCIAL_MIME).await; + + let client = GeminiClient::from_env() + .expect("GEMINI_API_KEY must be set for this test") + .model(GeminiModel::Gemini3FlashPreview) + .temperature(0.0); + + let result: MultiImageSummary = client + .materialize_with_media( + "You are given two images. Return the exact count in image_count and summarize both images.", + &[rust_media, social_media], + ) + .await + .expect("Failed to materialize multi-image summary"); - // The Rust logo should be recognized as related to Rust or as a logo/gear - let subject_lower = result.subject.to_lowercase(); - let desc_lower = result.description.to_lowercase(); - let mentions_rust_or_logo = subject_lower.contains("rust") - || subject_lower.contains("logo") - || subject_lower.contains("gear") - || subject_lower.contains("cog") - || desc_lower.contains("rust") - || desc_lower.contains("logo") - || desc_lower.contains("gear") - || desc_lower.contains("cog"); assert!( - mentions_rust_or_logo, - "Expected the image to be recognized as the Rust logo/gear, got subject='{}', description='{}'", - result.subject, result.description + result.image_count >= 2, + "expected at least 2 images, got {}", + result.image_count ); + assert!(!result.summary.is_empty(), "summary should not be empty"); } } diff --git a/tests/grok_multimodal_tests.rs b/tests/grok_multimodal_tests.rs new file mode 100644 index 0000000..90bf485 --- /dev/null +++ b/tests/grok_multimodal_tests.rs @@ -0,0 +1,98 @@ +//! Integration tests for Grok multimodal structured extraction. +//! +//! Requires: +//! - `XAI_API_KEY` +//! - `--features grok` + +#[path = "common/mod.rs"] +mod common; + +#[cfg(test)] +mod grok_multimodal_tests { + #[cfg(feature = "grok")] + use rstructor::GrokClient; + use rstructor::{GrokModel, Instructor, LLMClient}; + use serde::{Deserialize, Serialize}; + + use crate::common::{RUST_LOGO_MIME, RUST_LOGO_URL, download_media, media_url}; + + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct ImageSummary { + subject: String, + summary: String, + } + + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct MultiImageSummary { + image_count: u8, + summary: String, + } + + #[cfg(feature = "grok")] + #[tokio::test] + async fn test_grok_multimodal_inline_image() { + let media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let client = GrokClient::from_env() + .expect("XAI_API_KEY must be set for this test") + .model(GrokModel::Grok41FastNonReasoning) + .temperature(0.0); + + let result: ImageSummary = client + .materialize_with_media( + "Identify the main subject in this image and summarize it briefly.", + &[media], + ) + .await + .expect("Grok multimodal inline request failed"); + + assert!(!result.subject.is_empty(), "subject should not be empty"); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } + + #[cfg(feature = "grok")] + #[tokio::test] + async fn test_grok_multimodal_url_image() { + let media = media_url(RUST_LOGO_URL, RUST_LOGO_MIME); + let client = GrokClient::from_env() + .expect("XAI_API_KEY must be set for this test") + .model(GrokModel::Grok41FastNonReasoning) + .temperature(0.0); + + let result: ImageSummary = client + .materialize_with_media( + "Describe this image in one concise sentence, focusing on scene type.", + &[media], + ) + .await + .expect("Grok multimodal URL request failed"); + + assert!(!result.subject.is_empty(), "subject should not be empty"); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } + + #[cfg(feature = "grok")] + #[tokio::test] + async fn test_grok_multimodal_multiple_images() { + let rust_media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let lake_media = media_url(RUST_LOGO_URL, RUST_LOGO_MIME); + let client = GrokClient::from_env() + .expect("XAI_API_KEY must be set for this test") + .model(GrokModel::Grok41FastNonReasoning) + .temperature(0.0); + + let result: MultiImageSummary = client + .materialize_with_media( + "You are given two images. Return the exact count in image_count and summarize both images.", + &[rust_media, lake_media], + ) + .await + .expect("Grok multimodal multi-image request failed"); + + assert!( + result.image_count >= 2, + "expected at least 2 images, got {}", + result.image_count + ); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } +} diff --git a/tests/model_string_test.rs b/tests/model_string_test.rs index 5d30fc9..f738d15 100644 --- a/tests/model_string_test.rs +++ b/tests/model_string_test.rs @@ -20,6 +20,12 @@ mod tests { let model = OpenAIModel::from_str("gpt-4o-mini").unwrap(); assert_eq!(model, OpenAIModel::Gpt4OMini); + let model = OpenAIModel::from_str("gpt-5.2-chat-latest").unwrap(); + assert_eq!(model, OpenAIModel::Gpt52ChatLatest); + + let model = OpenAIModel::from_str("gpt-5.2-codex").unwrap(); + assert_eq!(model, OpenAIModel::Gpt52Codex); + // Test From<&str> let model: OpenAIModel = "gpt-3.5-turbo".into(); assert_eq!(model, OpenAIModel::Gpt35Turbo); @@ -95,6 +101,9 @@ mod tests { let model = GeminiModel::from_string("gemini-2.5-flash"); assert_eq!(model, GeminiModel::Gemini25Flash); + let model = GeminiModel::from_string("gemini-2.5-flash-image"); + assert_eq!(model, GeminiModel::Gemini25FlashImage); + // Test custom model let model = GeminiModel::from_string("gemini-custom"); match model { diff --git a/tests/openai_multimodal_tests.rs b/tests/openai_multimodal_tests.rs new file mode 100644 index 0000000..11e395a --- /dev/null +++ b/tests/openai_multimodal_tests.rs @@ -0,0 +1,98 @@ +//! Integration tests for OpenAI multimodal structured extraction. +//! +//! Requires: +//! - `OPENAI_API_KEY` +//! - `--features openai` + +#[path = "common/mod.rs"] +mod common; + +#[cfg(test)] +mod openai_multimodal_tests { + #[cfg(feature = "openai")] + use rstructor::OpenAIClient; + use rstructor::{Instructor, LLMClient, OpenAIModel}; + use serde::{Deserialize, Serialize}; + + use crate::common::{RUST_LOGO_MIME, RUST_LOGO_URL, download_media, media_url}; + + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct ImageSummary { + subject: String, + summary: String, + } + + #[derive(Instructor, Serialize, Deserialize, Debug)] + struct MultiImageSummary { + image_count: u8, + summary: String, + } + + #[cfg(feature = "openai")] + #[tokio::test] + async fn test_openai_multimodal_inline_image() { + let media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let client = OpenAIClient::from_env() + .expect("OPENAI_API_KEY must be set for this test") + .model(OpenAIModel::Gpt52) + .temperature(0.0); + + let result: ImageSummary = client + .materialize_with_media( + "Identify the main subject in this image and summarize it briefly.", + &[media], + ) + .await + .expect("OpenAI multimodal inline request failed"); + + assert!(!result.subject.is_empty(), "subject should not be empty"); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } + + #[cfg(feature = "openai")] + #[tokio::test] + async fn test_openai_multimodal_url_image() { + let media = media_url(RUST_LOGO_URL, RUST_LOGO_MIME); + let client = OpenAIClient::from_env() + .expect("OPENAI_API_KEY must be set for this test") + .model(OpenAIModel::Gpt52) + .temperature(0.0); + + let result: ImageSummary = client + .materialize_with_media( + "Describe this image in one concise sentence, focusing on scene type.", + &[media], + ) + .await + .expect("OpenAI multimodal URL request failed"); + + assert!(!result.subject.is_empty(), "subject should not be empty"); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } + + #[cfg(feature = "openai")] + #[tokio::test] + async fn test_openai_multimodal_multiple_images() { + let rust_media = download_media(RUST_LOGO_URL, RUST_LOGO_MIME).await; + let lake_media = media_url(RUST_LOGO_URL, RUST_LOGO_MIME); + let client = OpenAIClient::from_env() + .expect("OPENAI_API_KEY must be set for this test") + .model(OpenAIModel::Gpt52) + .temperature(0.0); + + let result: MultiImageSummary = client + .materialize_with_media( + "You are given two images. Return the exact count in image_count and summarize both images.", + &[rust_media, lake_media], + ) + .await + .expect("OpenAI multimodal multi-image request failed"); + + assert!( + result.image_count >= 2, + "expected at least 2 images, got {}", + result.image_count + ); + assert!(!result.summary.is_empty(), "summary should not be empty"); + } +}