Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions rstructor_derive/src/generators/enum_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ fn generate_externally_tagged_enum_schema(
];

let mut schema_obj = ::serde_json::json!({
"oneOf": variant_schemas,
"anyOf": variant_schemas,
"title": stringify!(#name)
});

Expand Down Expand Up @@ -754,7 +754,7 @@ fn generate_internally_tagged_enum_schema(
];

let mut schema_obj = ::serde_json::json!({
"oneOf": variant_schemas,
"anyOf": variant_schemas,
"title": stringify!(#name)
});

Expand Down Expand Up @@ -1045,7 +1045,7 @@ fn generate_adjacently_tagged_enum_schema(
];

let mut schema_obj = ::serde_json::json!({
"oneOf": variant_schemas,
"anyOf": variant_schemas,
"title": stringify!(#name)
});

Expand Down Expand Up @@ -1202,7 +1202,7 @@ fn generate_untagged_enum_schema(
];

let mut schema_obj = ::serde_json::json!({
"oneOf": variant_schemas,
"anyOf": variant_schemas,
"title": stringify!(#name)
});

Expand Down
8 changes: 4 additions & 4 deletions rstructor_derive/tests/enum_with_data_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ fn test_enum_with_data_schema() {
let schema_obj = UserStatus::schema();
let schema = schema_obj.to_json();

// Check that we're using oneOf for complex enums
// Check that we're using anyOf for complex enums
assert!(
schema.get("oneOf").is_some(),
"Schema should use oneOf for enums with associated data"
schema.get("anyOf").is_some(),
"Schema should use anyOf for enums with associated data"
);

if let Some(Value::Array(variants)) = schema.get("oneOf") {
if let Some(Value::Array(variants)) = schema.get("anyOf") {
// Should have 4 variants
assert_eq!(variants.len(), 4, "Should have 4 variants");
}
Expand Down
2 changes: 1 addition & 1 deletion src/backend/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl GeminiClient {
/// # Ok(())
/// # }
/// ```
#[instrument(name = "gemini_client_new", skip(api_key), fields(model = ?Model::Gemini25Flash))]
#[instrument(name = "gemini_client_new", skip(api_key), fields(model = ?Model::Gemini3FlashPreview))]
pub fn new(api_key: impl Into<String>) -> Result<Self> {
let api_key = api_key.into();
if api_key.is_empty() {
Expand Down
282 changes: 189 additions & 93 deletions src/backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,28 +175,12 @@ pub struct AdjacentlyTaggedEnumInfo {
/// Extract adjacently tagged enum info from a schema (before Gemini transformation)
/// Searches recursively through the schema tree
pub fn extract_adjacently_tagged_info(schema: &Value) -> Option<AdjacentlyTaggedEnumInfo> {
// First check if this level has oneOf
if let Some(one_of) = schema.get("oneOf").and_then(|v| v.as_array()) {
let mut tag_key = None;
let mut content_key = None;
let mut tag_values = Vec::new();

for variant in one_of {
if let Some((t, c, v)) = detect_adjacently_tagged_variant(variant) {
if tag_key.is_none() {
tag_key = Some(t);
content_key = Some(c);
}
tag_values.push(v);
}
}

if let (Some(tag), Some(content)) = (tag_key, content_key) {
return Some(AdjacentlyTaggedEnumInfo {
tag_key: tag,
content_key: content,
tag_values,
});
// First check if this level has enum disjunction variants
for key in ["oneOf", "anyOf"] {
if let Some(variants) = schema.get(key).and_then(|v| v.as_array())
&& let Some(info) = extract_adjacently_tagged_info_from_variants(variants)
{
return Some(info);
}
}

Expand All @@ -217,7 +201,7 @@ pub fn extract_adjacently_tagged_info(schema: &Value) -> Option<AdjacentlyTagged
}

// Search in allOf, anyOf, oneOf
for key in &["allOf", "anyOf"] {
for key in &["allOf", "anyOf", "oneOf"] {
if let Some(arr) = schema.get(key).and_then(|v| v.as_array()) {
for item in arr {
if let Some(info) = extract_adjacently_tagged_info(item) {
Expand All @@ -230,6 +214,34 @@ pub fn extract_adjacently_tagged_info(schema: &Value) -> Option<AdjacentlyTagged
None
}

fn extract_adjacently_tagged_info_from_variants(
variants: &[Value],
) -> Option<AdjacentlyTaggedEnumInfo> {
let mut tag_key = None;
let mut content_key = None;
let mut tag_values = Vec::new();

for variant in variants {
if let Some((t, c, v)) = detect_adjacently_tagged_variant(variant) {
if tag_key.is_none() {
tag_key = Some(t);
content_key = Some(c);
}
tag_values.push(v);
}
}

if let (Some(tag), Some(content)) = (tag_key, content_key) {
Some(AdjacentlyTaggedEnumInfo {
tag_key: tag,
content_key: content,
tag_values,
})
} else {
None
}
}

/// Transform internally tagged JSON back to adjacently tagged format
pub fn transform_internally_to_adjacently_tagged(
json: &mut Value,
Expand Down Expand Up @@ -551,6 +563,55 @@ fn transform_adjacently_tagged_to_internally_tagged(
Value::Object(obj)
}

fn normalize_adjacently_tagged_variants(variants: &mut Vec<Value>) {
// First, check if this looks like an adjacently tagged enum.
// All variants should have the same tag/content keys.
let mut adjacently_tagged_info: Option<(String, String)> = None;
let mut all_adjacently_tagged = true;

for variant in variants.iter() {
if let Some((tag_key, content_key, _tag_value)) = detect_adjacently_tagged_variant(variant)
{
if let Some((ref existing_tag, ref existing_content)) = adjacently_tagged_info {
if tag_key != *existing_tag || content_key != *existing_content {
all_adjacently_tagged = false;
break;
}
} else {
adjacently_tagged_info = Some((tag_key, content_key));
}
} else {
// Unit variant (only tag, no content) is still okay.
if let Some(variant_obj) = variant.as_object()
&& let Some(props) = variant_obj.get("properties").and_then(|p| p.as_object())
&& props.len() == 1
&& variant_obj
.get("required")
.and_then(|r| r.as_array())
.map(|a| a.len())
== Some(1)
{
continue;
}
all_adjacently_tagged = false;
break;
}
}

if all_adjacently_tagged && adjacently_tagged_info.is_some() {
*variants = variants
.iter()
.map(|variant| {
if let Some((t, c, v)) = detect_adjacently_tagged_variant(variant) {
transform_adjacently_tagged_to_internally_tagged(variant, &t, &c, &v)
} else {
variant.clone()
}
})
.collect();
}
}

/// Internal function that strips unsupported keywords after refs are resolved.
fn strip_gemini_unsupported_keywords_recursive(schema: &mut Value) {
if let Some(obj) = schema.as_object_mut() {
Expand Down Expand Up @@ -732,78 +793,16 @@ fn strip_gemini_unsupported_keywords_recursive(schema: &mut Value) {
}
}

// Process 'anyOf' array
if let Some(any_of) = obj.get_mut("anyOf")
&& let Some(arr) = any_of.as_array_mut()
{
for item in arr.iter_mut() {
strip_gemini_unsupported_keywords_recursive(item);
}
}

// Process 'oneOf' array
if let Some(one_of) = obj.get_mut("oneOf")
&& let Some(arr) = one_of.as_array_mut()
{
// First, check if this looks like an adjacently tagged enum
// All variants should have the same tag/content keys
let mut adjacently_tagged_info: Option<(String, String)> = None;
let mut all_adjacently_tagged = true;

for item in arr.iter() {
if let Some((tag_key, content_key, _tag_value)) =
detect_adjacently_tagged_variant(item)
{
if let Some((ref existing_tag, ref existing_content)) = adjacently_tagged_info {
// Check if keys match
if tag_key != *existing_tag || content_key != *existing_content {
all_adjacently_tagged = false;
break;
}
} else {
adjacently_tagged_info = Some((tag_key, content_key));
}
} else {
// Unit variant (only tag, no content) is still okay
// Check if it has just one required field with enum
if let Some(variant_obj) = item.as_object()
&& let Some(props) =
variant_obj.get("properties").and_then(|p| p.as_object())
&& props.len() == 1
&& variant_obj
.get("required")
.and_then(|r| r.as_array())
.map(|a| a.len())
== Some(1)
{
// This is a unit variant, keep checking
continue;
}
all_adjacently_tagged = false;
break;
// Process enum disjunction arrays and normalize adjacently tagged variants.
for key in ["anyOf", "oneOf"] {
if let Some(disjunction) = obj.get_mut(key)
&& let Some(variants) = disjunction.as_array_mut()
{
normalize_adjacently_tagged_variants(variants);
for variant in variants.iter_mut() {
strip_gemini_unsupported_keywords_recursive(variant);
}
}

// If all variants are adjacently tagged, transform them
if all_adjacently_tagged && adjacently_tagged_info.is_some() {
// Transform each variant
*arr = arr
.iter()
.map(|item| {
if let Some((t, c, v)) = detect_adjacently_tagged_variant(item) {
transform_adjacently_tagged_to_internally_tagged(item, &t, &c, &v)
} else {
// Unit variant - leave as is
item.clone()
}
})
.collect();
}

// Now recursively process all variants
for item in arr.iter_mut() {
strip_gemini_unsupported_keywords_recursive(item);
}
}

// Handle additionalProperties if it's a schema object (for maps) - recurse into it
Expand Down Expand Up @@ -1902,6 +1901,103 @@ mod tests {
);
}

#[test]
fn test_extract_adjacently_tagged_info_anyof() {
let schema = serde_json::json!({
"anyOf": [
{
"type": "object",
"properties": {
"status": { "type": "string", "enum": ["Success"] },
"data": {
"type": "object",
"properties": {
"output": { "type": "string" }
},
"required": ["output"]
}
},
"required": ["status", "data"]
},
{
"type": "object",
"properties": {
"status": { "type": "string", "enum": ["Failure"] },
"data": {
"type": "object",
"properties": {
"reason": { "type": "string" }
},
"required": ["reason"]
}
},
"required": ["status", "data"]
}
]
});

let info = extract_adjacently_tagged_info(&schema).expect("should detect anyOf enum info");
assert_eq!(info.tag_key, "status");
assert_eq!(info.content_key, "data");
assert_eq!(info.tag_values.len(), 2);
assert!(info.tag_values.contains(&"Success".to_string()));
assert!(info.tag_values.contains(&"Failure".to_string()));
}

#[test]
fn test_gemini_anyof_adjacently_tagged_variants_are_flattened() {
let mut schema = serde_json::json!({
"anyOf": [
{
"type": "object",
"properties": {
"status": { "type": "string", "enum": ["Success"] },
"data": {
"type": "object",
"properties": {
"output": { "type": "string" }
},
"required": ["output"]
}
},
"required": ["status", "data"],
"description": "Success variant"
},
{
"type": "object",
"properties": {
"status": { "type": "string", "enum": ["Failure"] },
"data": {
"type": "object",
"properties": {
"reason": { "type": "string" }
},
"required": ["reason"]
}
},
"required": ["status", "data"],
"description": "Failure variant"
}
]
});

strip_gemini_unsupported_keywords_recursive(&mut schema);

let first_props = schema["anyOf"][0]["properties"]
.as_object()
.expect("properties should be object");
assert!(first_props.contains_key("status"));
assert!(first_props.contains_key("output"));
assert!(!first_props.contains_key("data"));

let first_required = schema["anyOf"][0]["required"]
.as_array()
.expect("required should be array");
assert!(first_required.contains(&serde_json::json!("status")));
assert!(first_required.contains(&serde_json::json!("output")));
assert!(!first_required.contains(&serde_json::json!("data")));
}

#[test]
fn test_gemini_map_with_x_enum_keys() {
let mut schema = serde_json::json!({
Expand Down
Loading
Loading