diff --git a/example_with_targets/src/tests.rs b/example_with_targets/src/tests.rs index 4de5b0a..950319c 100644 --- a/example_with_targets/src/tests.rs +++ b/example_with_targets/src/tests.rs @@ -42,3 +42,34 @@ fn test_target_b() -> Result<()> { assert_eq!(result, expected); Ok(()) } + +#[test] +fn test_null_field_skipping() -> Result<()> { + let test_function_none = + |_input: crate::schema::target_a::Input| -> Result { + Ok(crate::schema::FunctionTargetAResult { + status: None, // This should not appear in serialized output + }) + }; + + let test_function_some = + |_input: crate::schema::target_a::Input| -> Result { + Ok(crate::schema::FunctionTargetAResult { + status: Some(200), // This should appear in serialized output + }) + }; + + let test_input = r#"{ + "id": "gid://shopify/Order/1234567890", + "num": 123, + "name": "test" + }"#; + + let result_none = run_function_with_input(test_function_none, test_input)?; + let result_some = run_function_with_input(test_function_some, test_input)?; + + assert_eq!(result_none.status, None); + assert_eq!(result_some.status, Some(200)); + + Ok(()) +} diff --git a/integration_tests/src/lib.rs b/integration_tests/src/lib.rs index c7beb98..f0ecd60 100644 --- a/integration_tests/src/lib.rs +++ b/integration_tests/src/lib.rs @@ -33,7 +33,7 @@ fn build_example(name: &str) -> Result<()> { } static FUNCTION_RUNNER_PATH: LazyLock> = LazyLock::new(|| { - let path = workspace_root().join(format!("tmp/function-runner-{}", FUNCTION_RUNNER_VERSION)); + let path = workspace_root().join(format!("tmp/function-runner-{FUNCTION_RUNNER_VERSION}")); if !path.exists() { std::fs::create_dir_all(workspace_root().join("tmp"))?; @@ -44,7 +44,7 @@ static FUNCTION_RUNNER_PATH: LazyLock> = LazyLock::new(| }); static TRAMPOLINE_PATH: LazyLock> = LazyLock::new(|| { - let path = workspace_root().join(format!("tmp/trampoline-{}", TRAMPOLINE_VERSION)); + let path = workspace_root().join(format!("tmp/trampoline-{TRAMPOLINE_VERSION}")); if !path.exists() { std::fs::create_dir_all(workspace_root().join("tmp"))?; download_trampoline(&path)?; @@ -56,8 +56,7 @@ fn download_function_runner(destination: &PathBuf) -> Result<()> { download_from_github( |target_arch, target_os| { format!( - "https://github.com/Shopify/function-runner/releases/download/v{}/function-runner-{}-{}-v{}.gz", - FUNCTION_RUNNER_VERSION, target_arch, target_os, FUNCTION_RUNNER_VERSION, + "https://github.com/Shopify/function-runner/releases/download/v{FUNCTION_RUNNER_VERSION}/function-runner-{target_arch}-{target_os}-v{FUNCTION_RUNNER_VERSION}.gz" ) }, destination, @@ -68,8 +67,7 @@ fn download_trampoline(destination: &PathBuf) -> Result<()> { download_from_github( |target_arch, target_os| { format!( - "https://github.com/Shopify/shopify-function-wasm-api/releases/download/shopify_function_trampoline/v{}/shopify-function-trampoline-{}-{}-v{}.gz", - TRAMPOLINE_VERSION, target_arch, target_os, TRAMPOLINE_VERSION, + "https://github.com/Shopify/shopify-function-wasm-api/releases/download/shopify_function_trampoline/v{TRAMPOLINE_VERSION}/shopify-function-trampoline-{target_arch}-{target_os}-v{TRAMPOLINE_VERSION}.gz" ) }, destination, @@ -127,10 +125,10 @@ pub fn prepare_example(name: &str) -> Result { build_example(name)?; let wasm_path = workspace_root() .join("target/wasm32-wasip1/release") - .join(format!("{}.wasm", name)); + .join(format!("{name}.wasm")); let trampolined_path = workspace_root() .join("target/wasm32-wasip1/release") - .join(format!("{}-trampolined.wasm", name)); + .join(format!("{name}-trampolined.wasm")); let trampoline_path = TRAMPOLINE_PATH .as_ref() .map_err(|e| anyhow::anyhow!("Failed to download trampoline: {}", e))?; diff --git a/shopify_function_macro/src/lib.rs b/shopify_function_macro/src/lib.rs index c382691..a36c8bc 100644 --- a/shopify_function_macro/src/lib.rs +++ b/shopify_function_macro/src/lib.rs @@ -272,6 +272,18 @@ pub fn typegen( module.to_token_stream().into() } +/// Helper function to determine if a GraphQL input field type is nullable +/// Uses conservative detection to identify Optional fields from GraphQL schema +fn is_input_field_nullable(ivd: &impl InputValueDefinition) -> bool { + // Use std::any::type_name to get type information as a string + let type_name = std::any::type_name_of_val(&ivd.r#type()); + + // only treat fields that are explicitly nullable as Option types + // This prevents incorrectly wrapping required fields in Option + type_name.contains("Option") + || (type_name.contains("Nullable") && !type_name.contains("NonNull")) +} + struct ShopifyFunctionCodeGenerator; impl CodeGenerator for ShopifyFunctionCodeGenerator { @@ -496,6 +508,10 @@ impl CodeGenerator for ShopifyFunctionCodeGenerator { ) -> Vec { let name_ident = names::type_ident(input_object_type_definition.name()); + // Conditionally serialize fields based on GraphQL schema nullability + // Nullable fields (Option) are only serialized if Some(_) + // Required fields are always serialized + let field_statements: Vec = input_object_type_definition .input_field_definitions() .iter() @@ -503,28 +519,65 @@ impl CodeGenerator for ShopifyFunctionCodeGenerator { let field_name_ident = names::field_ident(ivd.name()); let field_name_lit_str = syn::LitStr::new(ivd.name(), Span::mixed_site()); - vec![ + // Check if this field is nullable in the GraphQL schema + if is_input_field_nullable(ivd) { + // For nullable fields, only serialize if Some(_) + vec![parse_quote! { + if let ::std::option::Option::Some(ref value) = self.#field_name_ident { + context.write_utf8_str(#field_name_lit_str)?; + value.serialize(context)?; + } + }] + } else { + // For required fields, always serialize + vec![ + parse_quote! { + context.write_utf8_str(#field_name_lit_str)?; + }, + parse_quote! { + self.#field_name_ident.serialize(context)?; + }, + ] + } + }) + .collect(); + + // Generate field counting statements for dynamic field count calculation + let field_count_statements: Vec = input_object_type_definition + .input_field_definitions() + .iter() + .map(|ivd| { + let field_name_ident = names::field_ident(ivd.name()); + + if is_input_field_nullable(ivd) { + // For nullable fields, count only if Some(_) parse_quote! { - context.write_utf8_str(#field_name_lit_str)?; - }, + if let ::std::option::Option::Some(_) = self.#field_name_ident { + field_count += 1; + } + } + } else { + // For required fields, always count parse_quote! { - self.#field_name_ident.serialize(context)?; - }, - ] + field_count += 1; + } + } }) .collect(); - let num_fields = input_object_type_definition.input_field_definitions().len(); - let serialize_impl = parse_quote! { impl shopify_function::wasm_api::Serialize for #name_ident { fn serialize(&self, context: &mut shopify_function::wasm_api::Context) -> ::std::result::Result<(), shopify_function::wasm_api::write::Error> { + // Calculate dynamic field count based on non-null fields + let mut field_count = 0usize; + #(#field_count_statements)* + context.write_object( |context| { #(#field_statements)* ::std::result::Result::Ok(()) }, - #num_fields, + field_count, ) } }