diff --git a/v3/Cargo.lock b/v3/Cargo.lock
index b9daefc9e7567..51298fe1e9122 100644
--- a/v3/Cargo.lock
+++ b/v3/Cargo.lock
@@ -4505,6 +4505,7 @@ dependencies = [
  "metadata-resolve",
  "ndc-models 0.2.0",
  "open-dds",
+ "reqwest",
  "schema",
  "serde",
  "serde_json",
diff --git a/v3/crates/engine/bin/engine/main.rs b/v3/crates/engine/bin/engine/main.rs
index cd84291a14d42..98af6f9f854c6 100644
--- a/v3/crates/engine/bin/engine/main.rs
+++ b/v3/crates/engine/bin/engine/main.rs
@@ -667,6 +667,7 @@ async fn pre_execution_plugins_middleware<'a>(

 /// Handle a SQL request and execute it.
 async fn handle_sql_request(
+    headers: axum::http::header::HeaderMap,
     State(state): State<Arc<EngineState>>,
     Extension(session): Extension<Session>,
     Json(request): Json<SqlRequest>,
@@ -680,6 +681,7 @@ async fn handle_sql_request(
         || {
             Box::pin(async {
                 sql::execute::execute_sql(
+                    Arc::new(headers),
                     state.sql_context.clone(),
                     Arc::new(session),
                     Arc::new(state.http_context.clone()),
diff --git a/v3/crates/engine/tests/common.rs b/v3/crates/engine/tests/common.rs
index fdeb715b8fc69..1ceba55e399c2 100644
--- a/v3/crates/engine/tests/common.rs
+++ b/v3/crates/engine/tests/common.rs
@@ -563,6 +563,7 @@ pub(crate) fn test_sql(test_path_string: &str) -> anyhow::Result<()> {
     let request_path = test_path.join("query.sql");
     let request_path_json = test_path.join("query.json");
+    let headers_path_json = test_path.join("headers.json");
     let response_path = test_path_string.to_string() + "/expected.json";
     let explain_path = test_path_string.to_string() + "/plan.json";
     let metadata_path = root_test_dir.join("sql/metadata.json");
@@ -590,6 +591,13 @@ pub(crate) fn test_sql(test_path_string: &str) -> anyhow::Result<()> {
         serde_json::from_str(&json_content)?
     };

+    let header_map = if let Ok(content) = fs::read_to_string(headers_path_json) {
+        let header_map: HashMap<String, String> = serde_json::from_str(&content)?;
+        Arc::new(reqwest::header::HeaderMap::try_from(&header_map)?)
+    } else {
+        Arc::new(reqwest::header::HeaderMap::new())
+    };
+
     let session = Arc::new({
         let session_vars_path = &test_path.join("session_variables.json");
         let session_variables: HashMap<SessionVariable, SessionVariableValue> =
@@ -608,6 +616,7 @@ pub(crate) fn test_sql(test_path_string: &str) -> anyhow::Result<()> {
             &http_context,
             &mut test_ctx.mint,
             explain_path,
+            &header_map,
             &SqlRequest::new(format!("EXPLAIN {}", request.sql)),
         )
         .await?;
@@ -618,6 +627,7 @@ pub(crate) fn test_sql(test_path_string: &str) -> anyhow::Result<()> {
             &http_context,
             &mut test_ctx.mint,
             response_path,
+            &header_map,
             &request,
         )
         .await?;
@@ -632,9 +642,11 @@ async fn snapshot_sql(
     http_context: &Arc<execute::HttpContext>,
     mint: &mut Mint,
     response_path: String,
+    request_headers: &Arc<reqwest::header::HeaderMap>,
     request: &SqlRequest,
 ) -> Result<(), anyhow::Error> {
     let response = sql::execute::execute_sql(
+        request_headers.clone(),
         catalog.clone(),
         session.clone(),
         http_context.clone(),
diff --git a/v3/crates/engine/tests/sql.rs b/v3/crates/engine/tests/sql.rs
index 5d4e7c2048fc2..98c7f72469afb 100644
--- a/v3/crates/engine/tests/sql.rs
+++ b/v3/crates/engine/tests/sql.rs
@@ -31,6 +31,11 @@ fn test_commands_functions() -> anyhow::Result<()> {
     test_sql("sql/commands/functions")
 }

+#[test]
+fn test_commands_functions_forward_headers() -> anyhow::Result<()> {
+    test_sql("sql/commands/functions_forward_headers")
+}
+
 #[test]
 fn test_commands_functions_empty_args() -> anyhow::Result<()> {
     test_sql("sql/commands/functions_empty_args")
diff --git a/v3/crates/engine/tests/sql/commands/functions_forward_headers/expected.json b/v3/crates/engine/tests/sql/commands/functions_forward_headers/expected.json
new file mode 100644
index 0000000000000..13b44a240187f
--- /dev/null
+++ b/v3/crates/engine/tests/sql/commands/functions_forward_headers/expected.json
@@ -0,0 +1,6 @@
+[
+  {
+    "token": "64a6c518-4a5b-4067-a99f-3abc11eeeacf",
+    "expiry": "2025-12-12T05:48:33+0000"
+  }
+]
diff --git a/v3/crates/engine/tests/sql/commands/functions_forward_headers/headers.json b/v3/crates/engine/tests/sql/commands/functions_forward_headers/headers.json
new file mode 100644
index 0000000000000..8e40fc85df1ce
--- /dev/null
+++ b/v3/crates/engine/tests/sql/commands/functions_forward_headers/headers.json
@@ -0,0 +1,3 @@
+{
+  "authorization": "foo"
+}
diff --git a/v3/crates/engine/tests/sql/commands/functions_forward_headers/plan.json b/v3/crates/engine/tests/sql/commands/functions_forward_headers/plan.json
new file mode 100644
index 0000000000000..ec77e2dae7300
--- /dev/null
+++ b/v3/crates/engine/tests/sql/commands/functions_forward_headers/plan.json
@@ -0,0 +1,10 @@
+[
+  {
+    "plan_type": "logical_plan",
+    "plan": "TableScan: tmp_table projection=[token, expiry]"
+  },
+  {
+    "plan_type": "physical_plan",
+    "plan": "NDCFunctionPushDown\n"
+  }
+]
diff --git a/v3/crates/engine/tests/sql/commands/functions_forward_headers/query.sql b/v3/crates/engine/tests/sql/commands/functions_forward_headers/query.sql
new file mode 100644
index 0000000000000..93d5be5b9baf9
--- /dev/null
+++ b/v3/crates/engine/tests/sql/commands/functions_forward_headers/query.sql
@@ -0,0 +1,5 @@
+SELECT
+  *
+FROM
+  get_session_info(STRUCT(1 as "userId"));
+
diff --git a/v3/crates/engine/tests/sql/commands/functions_forward_headers/session_variables.json b/v3/crates/engine/tests/sql/commands/functions_forward_headers/session_variables.json
new file mode 100644
index 0000000000000..939ffd9a413db
--- /dev/null
+++ b/v3/crates/engine/tests/sql/commands/functions_forward_headers/session_variables.json
@@ -0,0 +1,3 @@
+{
+  "x-hasura-role": "admin"
+}
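
Aside: the harness change above treats headers.json as an optional fixture and validates it into a typed header map. A minimal, self-contained sketch of that conversion (the helper name `headers_from_json` is hypothetical; `reqwest::header::HeaderMap` is a re-export of `http::HeaderMap`, whose `TryFrom<&HashMap<String, String>>` impl rejects invalid header names or values):

use std::collections::HashMap;

// Hypothetical helper mirroring the test-harness logic: parse a JSON object
// of header name/value strings and validate it into a typed HeaderMap.
fn headers_from_json(content: &str) -> anyhow::Result<reqwest::header::HeaderMap> {
    let raw: HashMap<String, String> = serde_json::from_str(content)?;
    // `TryFrom` checks that every key parses as a HeaderName and every
    // value as a HeaderValue, so malformed fixtures fail the test early.
    Ok(reqwest::header::HeaderMap::try_from(&raw)?)
}

fn main() -> anyhow::Result<()> {
    let headers = headers_from_json(r#"{ "authorization": "foo" }"#)?;
    assert_eq!(headers.get("authorization").unwrap(), "foo");
    Ok(())
}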
diff --git a/v3/crates/engine/tests/sql/introspection/functions/expected.json b/v3/crates/engine/tests/sql/introspection/functions/expected.json
index e9ad4d2d08933..5951edc610b62 100644
--- a/v3/crates/engine/tests/sql/introspection/functions/expected.json
+++ b/v3/crates/engine/tests/sql/introspection/functions/expected.json
@@ -65,6 +65,21 @@
       }
     ]
   },
+  {
+    "function_name": "get_session_info",
+    "return_type": "SessionInfo",
+    "description": null,
+    "arguments": [
+      {
+        "name": "userId",
+        "position": 0,
+        "argument_type": "INT32",
+        "argument_type_normalized": "INT32",
+        "is_nullable": false,
+        "description": null
+      }
+    ]
+  },
   {
     "function_name": "uppercase_actor_name_by_id",
     "return_type": "actor",
diff --git a/v3/crates/engine/tests/sql/introspection/struct_types/expected.json b/v3/crates/engine/tests/sql/introspection/struct_types/expected.json
index f6f6a50c9e3ee..ae53ece62a693 100644
--- a/v3/crates/engine/tests/sql/introspection/struct_types/expected.json
+++ b/v3/crates/engine/tests/sql/introspection/struct_types/expected.json
@@ -109,6 +109,26 @@
       }
     ]
   },
+  {
+    "name": "SessionInfo",
+    "description": null,
+    "fields": [
+      {
+        "field_name": "expiry",
+        "field_type": "STRING",
+        "field_type_normalized": "STRING",
+        "is_nullable": false,
+        "description": null
+      },
+      {
+        "field_name": "token",
+        "field_type": "STRING",
+        "field_type_normalized": "STRING",
+        "is_nullable": false,
+        "description": null
+      }
+    ]
+  },
   {
     "name": "actor",
     "description": null,
diff --git a/v3/crates/engine/tests/sql/metadata.json b/v3/crates/engine/tests/sql/metadata.json
index 77f1ee331e1cc..7b10249ab94e5 100644
--- a/v3/crates/engine/tests/sql/metadata.json
+++ b/v3/crates/engine/tests/sql/metadata.json
@@ -13192,6 +13192,109 @@
         "rootFieldKind": "Mutation"
       }
     }
+  },
+  {
+    "kind": "ObjectType",
+    "version": "v1",
+    "definition": {
+      "name": "SessionInfo",
+      "fields": [
+        {
+          "name": "token",
+          "type": "String!"
+        },
+        {
+          "name": "expiry",
+          "type": "String!"
+        }
+      ],
+      "graphql": {
+        "typeName": "SessionInfo"
+      },
+      "dataConnectorTypeMapping": [
+        {
+          "dataConnectorName": "custom",
+          "dataConnectorObjectType": "session_info",
+          "fieldMapping": {
+            "token": {
+              "column": {
+                "name": "token"
+              }
+            },
+            "expiry": {
+              "column": {
+                "name": "expiry"
+              }
+            }
+          }
+        }
+      ]
+    }
+  },
+  {
+    "kind": "TypePermissions",
+    "version": "v1",
+    "definition": {
+      "typeName": "SessionInfo",
+      "permissions": [
+        {
+          "role": "admin",
+          "output": {
+            "allowedFields": ["token", "expiry"]
+          }
+        },
+        {
+          "role": "user",
+          "output": {
+            "allowedFields": ["token", "expiry"]
+          }
+        }
+      ]
+    }
+  },
+  {
+    "kind": "Command",
+    "version": "v1",
+    "definition": {
+      "name": "get_session_info",
+      "arguments": [
+        {
+          "name": "userId",
+          "type": "Int!"
+        }
+      ],
+      "outputType": "SessionInfo",
+      "source": {
+        "dataConnectorName": "custom",
+        "dataConnectorCommand": {
+          "function": "get_session_details"
+        },
+        "argumentMapping": {
+          "userId": "user_id"
+        }
+      },
+      "graphql": {
+        "rootFieldName": "getSessionInfo",
+        "rootFieldKind": "Query"
+      }
+    }
+  },
+  {
+    "kind": "CommandPermissions",
+    "version": "v1",
+    "definition": {
+      "commandName": "get_session_info",
+      "permissions": [
+        {
+          "role": "admin",
+          "allowExecution": true
+        },
+        {
+          "role": "user",
+          "allowExecution": true
+        }
+      ]
+    }
   }
 ]
diff --git a/v3/crates/ir/src/arguments.rs b/v3/crates/ir/src/arguments.rs
index 8dc770aaeb470..c4c1e73c57ecc 100644
--- a/v3/crates/ir/src/arguments.rs
+++ b/v3/crates/ir/src/arguments.rs
@@ -161,6 +161,25 @@ where
         }
     }

+    // preset arguments from `DataConnectorLink` argument presets
+    for (argument_name, value) in process_connector_link_presets(
+        data_connector_link_argument_presets,
+        session_variables,
+        request_headers,
+    )? {
+        arguments.insert(argument_name, Argument::Literal { value });
+    }
+
+    Ok(arguments)
+}
+
+/// Builds arguments for a command that come from a connector link's argument presets
+pub fn process_connector_link_presets(
+    data_connector_link_argument_presets: &BTreeMap<DataConnectorArgumentName, ArgumentPresetValue>,
+    session_variables: &SessionVariables,
+    request_headers: &reqwest::header::HeaderMap,
+) -> Result<BTreeMap<DataConnectorArgumentName, serde_json::Value>, error::Error> {
+    let mut arguments = BTreeMap::new();
     // preset arguments from `DataConnectorLink` argument presets
     for (dc_argument_preset_name, dc_argument_preset_value) in data_connector_link_argument_presets
     {
@@ -199,12 +218,9 @@ where

         arguments.insert(
             dc_argument_preset_name.clone(),
-            Argument::Literal {
-                value: serde_json::to_value(SerializableHeaderMap(headers_argument))?,
-            },
+            serde_json::to_value(SerializableHeaderMap(headers_argument))?,
         );
     }
-
     Ok(arguments)
 }
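
This refactor is what lets the SQL layer reuse header forwarding: `process_connector_link_presets` now returns plain `serde_json::Value`s, so the GraphQL path wraps them in `Argument::Literal` while the SQL planner (in physical.rs below) inserts them straight into the NDC argument map. A rough, self-contained sketch of the kind of value an `httpHeaders`-style preset evaluates to — the real shape comes from `SerializableHeaderMap`, so treat the JSON layout as illustrative:

use std::collections::BTreeMap;

// Illustrative only: collect the forwardable request headers into a JSON
// object, roughly what a connector-link header preset produces.
fn forwarded_headers_value(
    forward: &[&str],
    request_headers: &reqwest::header::HeaderMap,
) -> serde_json::Value {
    let mut headers = BTreeMap::new();
    for name in forward {
        // Skip headers that are absent or not valid UTF-8.
        if let Some(value) = request_headers.get(*name).and_then(|v| v.to_str().ok()) {
            headers.insert((*name).to_string(), value.to_string());
        }
    }
    serde_json::to_value(headers).expect("a map of strings always serializes")
}

fn main() {
    let mut request_headers = reqwest::header::HeaderMap::new();
    request_headers.insert("authorization", "foo".parse().unwrap());
    let value = forwarded_headers_value(&["authorization"], &request_headers);
    assert_eq!(value, serde_json::json!({ "authorization": "foo" }));
}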
diff --git a/v3/crates/ir/src/lib.rs b/v3/crates/ir/src/lib.rs
index 654d9b9213325..410d69c23a457 100644
--- a/v3/crates/ir/src/lib.rs
+++ b/v3/crates/ir/src/lib.rs
@@ -25,7 +25,7 @@ pub use remote_joins::VariableName;
 pub use aggregates::{
     mk_alias_from_graphql_field_path, AggregateFieldSelection, AggregateSelectionSet,
 };
-pub use arguments::Argument;
+pub use arguments::{process_connector_link_presets, Argument};
 pub use commands::{CommandInfo, FunctionBasedCommand, ProcedureBasedCommand};
 pub use filter::expression::{
     ComparisonTarget, ComparisonValue, Expression, LocalFieldComparison, RelationshipColumnMapping,
diff --git a/v3/crates/sql/Cargo.toml b/v3/crates/sql/Cargo.toml
index f9f63527b195e..edce5d0f11536 100644
--- a/v3/crates/sql/Cargo.toml
+++ b/v3/crates/sql/Cargo.toml
@@ -23,6 +23,7 @@ serde = { workspace = true, features = ["rc"] }
 serde_json = { workspace = true }
 thiserror = { workspace = true }
 chrono = { workspace = true }
+reqwest = { workspace = true }

 [dev-dependencies]
 tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
diff --git a/v3/crates/sql/src/catalog.rs b/v3/crates/sql/src/catalog.rs
index c03f73430522a..0535547868c08 100644
--- a/v3/crates/sql/src/catalog.rs
+++ b/v3/crates/sql/src/catalog.rs
@@ -137,6 +137,7 @@ impl datafusion::CatalogProvider for model::WithSession<Catalog> {
 impl Catalog {
     pub fn create_session_context(
         self: Arc<Self>,
+        request_headers: &Arc<reqwest::header::HeaderMap>,
         session: &Arc<Session>,
         http_context: &Arc<execute::HttpContext>,
     ) -> datafusion::SessionContext {
@@ -174,6 +175,7 @@ impl Catalog {
             catalog: self.clone(),
             session: session.clone(),
             http_context: http_context.clone(),
+            request_headers: request_headers.clone(),
         });
         let session_state = datafusion::SessionStateBuilder::new()
             .with_config(session_config)
diff --git a/v3/crates/sql/src/execute.rs b/v3/crates/sql/src/execute.rs
index 37f72dc45ff7e..e31b1c682d44c 100644
--- a/v3/crates/sql/src/execute.rs
+++ b/v3/crates/sql/src/execute.rs
@@ -73,6 +73,7 @@ impl TraceableError for SqlExecutionError {

 /// Executes an SQL Request using the Apache DataFusion query engine.
 pub async fn execute_sql(
+    request_headers: Arc<reqwest::header::HeaderMap>,
     catalog: Arc<crate::catalog::Catalog>,
     session: Arc<Session>,
     http_context: Arc<execute::HttpContext>,
@@ -85,7 +86,8 @@ pub async fn execute_sql(
         "Create a datafusion SessionContext",
         SpanVisibility::Internal,
         || {
-            let session = catalog.create_session_context(&session, &http_context);
+            let session =
+                catalog.create_session_context(&request_headers, &session, &http_context);
             Successful::new(session)
         },
     )
diff --git a/v3/crates/sql/src/execute/planner.rs b/v3/crates/sql/src/execute/planner.rs
index b85c97a73ca33..d6b05cd11f7ed 100644
--- a/v3/crates/sql/src/execute/planner.rs
+++ b/v3/crates/sql/src/execute/planner.rs
@@ -20,6 +20,7 @@ use datafusion::{
 use async_trait::async_trait;

 pub(crate) struct OpenDDQueryPlanner {
+    pub(crate) request_headers: Arc<reqwest::header::HeaderMap>,
     pub(crate) session: Arc<Session>,
     pub(crate) http_context: Arc<HttpContext>,
     pub(crate) catalog: Arc<crate::catalog::Catalog>,
@@ -37,6 +38,7 @@ impl QueryPlanner for OpenDDQueryPlanner {
         // Teach the default physical planner how to plan TopK nodes.
         let physical_planner =
             DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(NDCPushDownPlanner {
+                request_headers: self.request_headers.clone(),
                 session: self.session.clone(),
                 http_context: self.http_context.clone(),
                 catalog: self.catalog.clone(),
@@ -49,6 +51,7 @@ impl QueryPlanner for OpenDDQueryPlanner {
 }

 pub(crate) struct NDCPushDownPlanner {
+    pub(crate) request_headers: Arc<reqwest::header::HeaderMap>,
     pub(crate) session: Arc<Session>,
     pub(crate) http_context: Arc<HttpContext>,
     pub(crate) catalog: Arc<crate::catalog::Catalog>,
@@ -78,6 +81,7 @@ impl ExtensionPlanner for NDCPushDownPlanner {
         assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs");
         assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs");
         build_execution_plan(
+            &self.request_headers,
             &self.catalog.metadata,
             &self.http_context,
             &self.session,
diff --git a/v3/crates/sql/src/execute/planner/command/physical.rs b/v3/crates/sql/src/execute/planner/command/physical.rs
index f071c96b08b9c..91c98e82afac8 100644
--- a/v3/crates/sql/src/execute/planner/command/physical.rs
+++ b/v3/crates/sql/src/execute/planner/command/physical.rs
@@ -19,6 +19,7 @@ use open_dds::commands::DataConnectorCommand;
 pub(crate) use procedure::NDCProcedurePushDown;

 pub fn build_execution_plan(
+    request_headers: &reqwest::header::HeaderMap,
     metadata: &metadata_resolve::Metadata,
     http_context: &Arc<HttpContext>,
     session: &Arc<Session>,
@@ -138,6 +139,26 @@ pub fn build_execution_plan(
         ndc_fields.insert(NdcFieldAlias::from(field_alias.as_str()), ndc_field);
     }

+    let (ndc_fields, extract_response_from) = match &command_source.data_connector.response_config {
+        // if the data connector has 'responseHeaders' configured, we'll need to wrap the ndc fields
+        // under the 'result' field if the command's response at the OpenDD layer refers to the
+        // 'result' field's type. Note that we aren't requesting the 'headers' field, as we don't
+        // forward the response headers in the SQL layer yet
+        Some(response_config) if !command_source.ndc_type_opendd_type_same => {
+            let result_field_name = NdcFieldAlias::from(response_config.result_field.as_str());
+            let result_field = ResolvedField::Column {
+                column: response_config.result_field.clone(),
+                fields: Some(execute::plan::field::NestedField::Object(
+                    execute::plan::field::NestedObject { fields: ndc_fields },
+                )),
+                arguments: BTreeMap::new(),
+            };
+            let fields = IndexMap::from_iter([(result_field_name, result_field)]);
+            (fields, Some(response_config.result_field.clone()))
+        }
+        _ => (ndc_fields, None),
+    };
+
     if !command
         .permissions
         .get(&session.role)
@@ -161,6 +182,17 @@ pub fn build_execution_plan(
         ndc_arguments.insert(ndc_argument_name.clone(), ndc_argument_value.clone());
     }

+    // preset arguments from `DataConnectorLink` argument presets
+    for (argument_name, value) in ir::process_connector_link_presets(
+        &command_source.data_connector_link_argument_presets,
+        &session.variables,
+        request_headers,
+    )
+    .map_err(|e| DataFusionError::External(Box::new(e)))?
+    {
+        ndc_arguments.insert(argument_name, value);
+    }
+
     match &command_source.source {
         DataConnectorCommand::Function(function_name) => {
             let ndc_pushdown = NDCFunctionPushDown::new(
@@ -171,6 +203,7 @@ pub fn build_execution_plan(
                 ndc_fields,
                 schema,
                 output.clone(),
+                extract_response_from,
             );
             Ok(Arc::new(ndc_pushdown))
         }
@@ -183,6 +216,7 @@ pub fn build_execution_plan(
                 ndc_fields,
                 schema,
                 output.clone(),
+                extract_response_from,
             );
             Ok(Arc::new(ndc_pushdown))
         }
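
To make the wrapping above concrete: with `responseHeaders` configured on the connector link, the NDC response nests the command output under the configured result field. A self-contained illustration (the `set-cookie` value is invented; the row values come from the test fixture above):

// With `responseHeaders` configured, each command result arrives in an
// envelope; the SQL planner only plans the nested `result` column (response
// headers aren't forwarded in the SQL layer yet) and records the field name
// as `extract_response_from` so the decoder can strip the envelope later.
fn main() {
    let enveloped = serde_json::json!({
        "headers": { "set-cookie": "session=abc" },
        "result": {
            "token": "64a6c518-4a5b-4067-a99f-3abc11eeeacf",
            "expiry": "2025-12-12T05:48:33+0000"
        }
    });
    // After extraction, both configurations decode to the same row shape.
    let row = &enveloped["result"];
    assert_eq!(row["token"], "64a6c518-4a5b-4067-a99f-3abc11eeeacf");
}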
diff --git a/v3/crates/sql/src/execute/planner/command/physical/function.rs b/v3/crates/sql/src/execute/planner/command/physical/function.rs
index 5b89498c68261..9f7df56156e55 100644
--- a/v3/crates/sql/src/execute/planner/command/physical/function.rs
+++ b/v3/crates/sql/src/execute/planner/command/physical/function.rs
@@ -28,7 +28,7 @@ use execute::{
 use ir::{NdcFieldAlias, NdcRelationshipName};
 use open_dds::{
     commands::FunctionName,
-    data_connector::CollectionName,
+    data_connector::{CollectionName, DataConnectorColumnName},
     types::{CustomTypeName, DataConnectorArgumentName},
 };
 use tracing_util::{FutureExt, SpanVisibility, TraceableError};
@@ -71,6 +71,9 @@ pub(crate) struct NDCFunctionPushDown {
     collection_relationships: BTreeMap<NdcRelationshipName, Relationship>,
     // used to post process a command's output
     output: CommandOutput,
+    // The key from which the response has to be extracted if the command output is not the
+    // same as the ndc type. This happens with response config in a data connector link
+    extract_response_from: Option<DataConnectorColumnName>,
     data_connector: Arc<metadata_resolve::DataConnectorLink>,
     // the schema of the node's output
     projected_schema: SchemaRef,
@@ -117,6 +120,7 @@ impl NDCFunctionPushDown {
         // schema of the output of the command selection
         schema: &DFSchemaRef,
         output: CommandOutput,
+        extract_response_from: Option<DataConnectorColumnName>,
     ) -> NDCFunctionPushDown {
         let metrics = ExecutionPlanMetricsSet::new();
         Self {
@@ -126,6 +130,7 @@ impl NDCFunctionPushDown {
             fields: wrap_ndc_fields(&output, ndc_fields),
             collection_relationships: BTreeMap::new(),
             output,
+            extract_response_from,
             data_connector,
             projected_schema: schema.inner().clone(),
             cache: Self::compute_properties(schema.inner().clone()),
@@ -232,6 +237,7 @@ impl ExecutionPlan for NDCFunctionPushDown {
                 query_request,
                 self.data_connector.clone(),
                 self.output.clone(),
+                self.extract_response_from.clone(),
                 baseline_metrics,
             )
             .with_context((*otel_cx).clone())
@@ -244,12 +250,13 @@ impl ExecutionPlan for NDCFunctionPushDown {
     }
 }

-pub async fn fetch_from_data_connector(
+async fn fetch_from_data_connector(
     schema: SchemaRef,
     http_context: Arc<HttpContext>,
     query_request: execute::ndc::NdcQueryRequest,
     data_connector: Arc<metadata_resolve::DataConnectorLink>,
     output: CommandOutput,
+    extract_response_from: Option<DataConnectorColumnName>,
     baseline_metrics: BaselineMetrics,
 ) -> Result<RecordBatch, ExecutionPlanError> {
     let tracer = tracing_util::global_tracer();
@@ -261,15 +268,38 @@ pub async fn fetch_from_data_connector(
         "ndc_response_to_record_batch",
         "Converts NDC Response into datafusion's RecordBatch",
         SpanVisibility::Internal,
-        || ndc_response_to_record_batch(schema, ndc_response, &output, &baseline_metrics),
+        || {
+            ndc_response_to_record_batch(
+                schema,
+                ndc_response,
+                &output,
+                extract_response_from.as_ref(),
+                &baseline_metrics,
+            )
+        },
     )?;
     Ok(batch)
 }

-pub fn ndc_response_to_record_batch(
+pub(super) fn extract_result_field(
+    mut rows: serde_json::Value,
+    extract_response_from: Option<&DataConnectorColumnName>,
+) -> Result<serde_json::Value, String> {
+    match extract_response_from {
+        Some(result_field) => rows
+            .as_object_mut()
+            .ok_or_else(|| "expecting an object to extract result field".to_string())?
+            .remove(result_field.as_str())
+            .ok_or_else(|| format!("missing result field in ndc response: {result_field}")),
+        None => Ok(rows),
+    }
+}
+
+fn ndc_response_to_record_batch(
     schema: SchemaRef,
     ndc_response: NdcQueryResponse,
     output: &CommandOutput,
+    extract_response_from: Option<&DataConnectorColumnName>,
     baseline_metrics: &BaselineMetrics,
 ) -> Result<RecordBatch, ExecutionPlanError> {
     let rows = ndc_response
@@ -292,6 +322,9 @@ fn ndc_response_to_record_batch(
         })?
         .0;

+    let rows = extract_result_field(rows, extract_response_from)
+        .map_err(ExecutionPlanError::NDCResponseFormat)?;
+
     let mut decoder =
         datafusion::arrow::json::reader::ReaderBuilder::new(schema.clone()).build_decoder()?;
     match output {
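
`extract_result_field` is the runtime counterpart of the planning-time wrapping: it peels the configured result field back off each response. A standalone re-implementation of its contract, with `&str` standing in for `DataConnectorColumnName`:

// Mirrors the diff's extract_result_field: with a result field configured,
// pull that key out of the response object; otherwise pass rows through.
fn extract(
    mut rows: serde_json::Value,
    extract_response_from: Option<&str>,
) -> Result<serde_json::Value, String> {
    match extract_response_from {
        Some(result_field) => rows
            .as_object_mut()
            .ok_or_else(|| "expecting an object to extract result field".to_string())?
            .remove(result_field)
            .ok_or_else(|| format!("missing result field in ndc response: {result_field}")),
        None => Ok(rows),
    }
}

fn main() {
    let response = serde_json::json!({ "result": [{ "token": "t" }], "headers": {} });
    assert_eq!(
        extract(response, Some("result")).unwrap(),
        serde_json::json!([{ "token": "t" }])
    );
    // Without a configured result field, the value is returned unchanged.
    assert!(extract(serde_json::json!([1, 2]), None).is_ok());
}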
diff --git a/v3/crates/sql/src/execute/planner/command/physical/procedure.rs b/v3/crates/sql/src/execute/planner/command/physical/procedure.rs
index de9ed41411fa9..1666ea01d9a33 100644
--- a/v3/crates/sql/src/execute/planner/command/physical/procedure.rs
+++ b/v3/crates/sql/src/execute/planner/command/physical/procedure.rs
@@ -24,12 +24,15 @@ use execute::{
     HttpContext,
 };
 use ir::{NdcFieldAlias, NdcRelationshipName};
-use open_dds::{commands::ProcedureName, types::DataConnectorArgumentName};
+use open_dds::{
+    commands::ProcedureName, data_connector::DataConnectorColumnName,
+    types::DataConnectorArgumentName,
+};
 use tracing_util::{FutureExt, SpanVisibility, TraceableError};

 use crate::execute::planner::common::PhysicalPlanOptions;

-use super::CommandOutput;
+use super::{function::extract_result_field, CommandOutput};

 #[derive(Debug, thiserror::Error)]
 pub enum ExecutionPlanError {
@@ -69,6 +72,9 @@ pub(crate) struct NDCProcedurePushDown {
     collection_relationships: BTreeMap<NdcRelationshipName, Relationship>,
     // used to post process a command's output
     output: CommandOutput,
+    // The key from which the response has to be extracted if the command output is not the
+    // same as the ndc type. This happens with response config in a data connector link
+    extract_response_from: Option<DataConnectorColumnName>,
     data_connector: Arc<metadata_resolve::DataConnectorLink>,
     // the schema of the node's output
     projected_schema: SchemaRef,
@@ -105,6 +111,7 @@ impl NDCProcedurePushDown {
         // schema of the output of the command selection
         schema: &DFSchemaRef,
         output: CommandOutput,
+        extract_response_from: Option<DataConnectorColumnName>,
     ) -> NDCProcedurePushDown {
         let metrics = ExecutionPlanMetricsSet::new();
         Self {
@@ -114,6 +121,7 @@ impl NDCProcedurePushDown {
             fields: Some(wrap_ndc_fields(&output, ndc_fields)),
             collection_relationships: BTreeMap::new(),
             output,
+            extract_response_from,
             data_connector,
             projected_schema: schema.inner().clone(),
             cache: Self::compute_properties(schema.inner().clone()),
@@ -220,6 +228,7 @@ impl ExecutionPlan for NDCProcedurePushDown {
                 query_request,
                 self.data_connector.clone(),
                 self.output.clone(),
+                self.extract_response_from.clone(),
                 baseline_metrics,
             )
             .with_context((*otel_cx).clone())
@@ -238,6 +247,7 @@ pub async fn fetch_from_data_connector(
     request: execute::ndc::NdcMutationRequest,
     data_connector: Arc<metadata_resolve::DataConnectorLink>,
     output: CommandOutput,
+    extract_response_from: Option<DataConnectorColumnName>,
     baseline_metrics: BaselineMetrics,
 ) -> Result<RecordBatch, ExecutionPlanError> {
     let tracer = tracing_util::global_tracer();
@@ -253,7 +263,15 @@ pub async fn fetch_from_data_connector(
         "ndc_response_to_record_batch",
         "Converts NDC Response into datafusion's RecordBatch",
         SpanVisibility::Internal,
-        || ndc_response_to_record_batch(schema, ndc_response, &output, &baseline_metrics),
+        || {
+            ndc_response_to_record_batch(
+                schema,
+                ndc_response,
+                &output,
+                extract_response_from.as_ref(),
+                &baseline_metrics,
+            )
+        },
     )?;
     Ok(batch)
 }
@@ -262,6 +280,7 @@ pub fn ndc_response_to_record_batch(
     schema: SchemaRef,
     ndc_response: NdcMutationResponse,
     output: &CommandOutput,
+    extract_response_from: Option<&DataConnectorColumnName>,
     baseline_metrics: &BaselineMetrics,
 ) -> Result<RecordBatch, ExecutionPlanError> {
     let ndc_models::MutationOperationResults::Procedure { result } = ndc_response
@@ -272,6 +291,9 @@ pub fn ndc_response_to_record_batch(
         ExecutionPlanError::NDCResponseFormat("no operation_results found".to_string())
     })?;

+    let result = extract_result_field(result, extract_response_from)
+        .map_err(ExecutionPlanError::NDCResponseFormat)?;
+
     let mut decoder =
         datafusion::arrow::json::reader::ReaderBuilder::new(schema.clone()).build_decoder()?;
     match output {
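
Both push-downs then hand the (possibly unwrapped) rows to arrow-json, as in the `ReaderBuilder::new(schema.clone()).build_decoder()?` lines above. A sketch of that final decoding step, with a schema and rows made up to match the test fixture:

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;

// Decode JSON rows into a RecordBatch in the node's projected schema, the
// way the push-down nodes do after extracting the result field.
fn rows_to_record_batch(rows: &serde_json::Value) -> datafusion::error::Result<RecordBatch> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("token", DataType::Utf8, false),
        Field::new("expiry", DataType::Utf8, false),
    ]));
    let mut decoder =
        datafusion::arrow::json::reader::ReaderBuilder::new(schema).build_decoder()?;
    decoder.serialize(rows.as_array().expect("rows must be an array"))?;
    Ok(decoder.flush()?.expect("at least one batch"))
}

fn main() -> datafusion::error::Result<()> {
    let rows = serde_json::json!([
        {
            "token": "64a6c518-4a5b-4067-a99f-3abc11eeeacf",
            "expiry": "2025-12-12T05:48:33+0000"
        }
    ]);
    let batch = rows_to_record_batch(&rows)?;
    assert_eq!(batch.num_rows(), 1);
    Ok(())
}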