@@ -47,7 +47,7 @@ use crate::metrics::*;
4747
4848/// Try to write multiple output to the writer if possible.
4949pub async fn write_output < W : AsyncWrite + Send + Sync + Unpin > (
50- w : QueryResultWriter < ' _ , W > ,
50+ mut writer : QueryResultWriter < ' _ , W > ,
5151 query_context : QueryContextRef ,
5252 session : SessionRef ,
5353 outputs : Vec < Result < Output > > ,
@@ -56,21 +56,87 @@ pub async fn write_output<W: AsyncWrite + Send + Sync + Unpin>(
5656 session. add_warning ( warning) ;
5757 }
5858
59- let mut writer = Some ( MysqlResultWriter :: new (
60- w,
61- query_context. clone ( ) ,
62- session. clone ( ) ,
63- ) ) ;
59+ enum Response {
60+ ResultSet {
61+ columns : Vec < Column > ,
62+ stream : SendableRecordBatchStream ,
63+ } ,
64+ AffectedRows ( usize ) ,
65+ }
66+
67+ let mut responses = Vec :: with_capacity ( outputs. len ( ) ) ;
6468 for output in outputs {
65- let result_writer = writer. take ( ) . context ( error:: InternalSnafu {
66- err_msg : "Sending multiple result set is unsupported" ,
67- } ) ?;
68- writer = result_writer. try_write_one ( output) . await ?;
69+ match output {
70+ Ok ( x) => {
71+ let output = match x. data {
72+ OutputData :: Stream ( stream) => either:: Left ( stream) ,
73+ OutputData :: RecordBatches ( record_batches) => {
74+ either:: Left ( record_batches. as_stream ( ) )
75+ }
76+ OutputData :: AffectedRows ( rows) => either:: Right ( rows) ,
77+ } ;
78+ responses. push ( match output {
79+ either:: Left ( stream) => {
80+ let schema = stream. schema ( ) ;
81+ let columns = match create_mysql_column_def ( & schema) {
82+ Ok ( columns) => columns,
83+ Err ( e) => {
84+ MysqlResultWriter :: write_query_error (
85+ e,
86+ writer,
87+ query_context. clone ( ) ,
88+ )
89+ . await ?;
90+ return Ok ( ( ) ) ;
91+ }
92+ } ;
93+ Response :: ResultSet { columns, stream }
94+ }
95+ either:: Right ( rows) => Response :: AffectedRows ( rows) ,
96+ } ) ;
97+ }
98+ Err ( e) => {
99+ MysqlResultWriter :: write_query_error ( e, writer, query_context. clone ( ) ) . await ?;
100+ return Ok ( ( ) ) ;
101+ }
102+ }
69103 }
70104
71- if let Some ( result_writer) = writer {
72- result_writer. finish ( ) . await ?;
105+ for response in & mut responses {
106+ writer = match response {
107+ Response :: ResultSet { columns, stream } => {
108+ let mut row_writer = writer. start ( columns) . await ?;
109+ while let Some ( record_batch) = stream. next ( ) . await {
110+ match record_batch {
111+ Ok ( record_batch) => {
112+ if let Err ( e) = MysqlResultWriter :: write_recordbatch (
113+ & mut row_writer,
114+ record_batch,
115+ query_context. clone ( ) ,
116+ )
117+ . await
118+ {
119+ let ( kind, err) = handle_err ( e, query_context) ;
120+ row_writer. finish_error ( kind, & err. as_bytes ( ) ) . await ?;
121+ return Ok ( ( ) ) ;
122+ }
123+ }
124+ Err ( e) => {
125+ let ( kind, err) = handle_err ( e, query_context) ;
126+ row_writer. finish_error ( kind, & err. as_bytes ( ) ) . await ?;
127+ return Ok ( ( ) ) ;
128+ }
129+ }
130+ }
131+ row_writer. finish_one ( ) . await ?
132+ }
133+ Response :: AffectedRows ( rows) => {
134+ MysqlResultWriter :: write_affected_rows ( writer, * rows, & session) . await ?
135+ }
136+ }
73137 }
138+
139+ writer. no_more_results ( ) . await ?;
74140 Ok ( ( ) )
75141}
76142
@@ -97,75 +163,10 @@ pub fn handle_err(e: impl ErrorExt, query_ctx: QueryContextRef) -> (ErrorKind, S
97163 ( kind, err_msg)
98164}
99165
100- struct QueryResult {
101- schema : SchemaRef ,
102- stream : SendableRecordBatchStream ,
103- }
104-
105- pub struct MysqlResultWriter < ' a , W : AsyncWrite + Unpin > {
106- writer : QueryResultWriter < ' a , W > ,
107- query_context : QueryContextRef ,
108- session : SessionRef ,
109- }
110-
111- impl < ' a , W : AsyncWrite + Unpin > MysqlResultWriter < ' a , W > {
112- pub fn new (
113- writer : QueryResultWriter < ' a , W > ,
114- query_context : QueryContextRef ,
115- session : SessionRef ,
116- ) -> MysqlResultWriter < ' a , W > {
117- MysqlResultWriter :: < ' a , W > {
118- writer,
119- query_context,
120- session,
121- }
122- }
123-
124- /// Try to write one result set. If there are more than one result set, return `Some`.
125- pub async fn try_write_one (
126- self ,
127- output : Result < Output > ,
128- ) -> io:: Result < Option < MysqlResultWriter < ' a , W > > > {
129- // We don't support sending multiple query result because the RowWriter's lifetime is bound to
130- // a local variable.
131- match output {
132- Ok ( output) => match output. data {
133- OutputData :: Stream ( stream) => {
134- let query_result = QueryResult {
135- schema : stream. schema ( ) ,
136- stream,
137- } ;
138- Self :: write_query_result ( query_result, self . writer , self . query_context ) . await ?;
139- }
140- OutputData :: RecordBatches ( recordbatches) => {
141- let query_result = QueryResult {
142- schema : recordbatches. schema ( ) ,
143- stream : recordbatches. as_stream ( ) ,
144- } ;
145- Self :: write_query_result ( query_result, self . writer , self . query_context ) . await ?;
146- }
147- OutputData :: AffectedRows ( rows) => {
148- let next_writer =
149- Self :: write_affected_rows ( self . writer , rows, & self . session ) . await ?;
150- return Ok ( Some ( MysqlResultWriter :: new (
151- next_writer,
152- self . query_context ,
153- self . session ,
154- ) ) ) ;
155- }
156- } ,
157- Err ( error) => Self :: write_query_error ( error, self . writer , self . query_context ) . await ?,
158- }
159- Ok ( None )
160- }
166+ struct MysqlResultWriter ;
161167
162- /// Indicate no more result set to write. No need to call this if there is only one result set.
163- pub async fn finish ( self ) -> Result < ( ) > {
164- self . writer . no_more_results ( ) . await ?;
165- Ok ( ( ) )
166- }
167-
168- async fn write_affected_rows (
168+ impl MysqlResultWriter {
169+ async fn write_affected_rows < ' a , W : AsyncWrite + Unpin > (
169170 w : QueryResultWriter < ' a , W > ,
170171 rows : usize ,
171172 session : & SessionRef ,
@@ -182,54 +183,12 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
182183 Ok ( next_writer)
183184 }
184185
185- async fn write_query_result (
186- mut query_result : QueryResult ,
187- writer : QueryResultWriter < ' a , W > ,
188- query_context : QueryContextRef ,
189- ) -> io:: Result < ( ) > {
190- match create_mysql_column_def ( & query_result. schema ) {
191- Ok ( column_def) => {
192- // The RowWriter's lifetime is bound to `column_def` thus we can't use finish_one()
193- // to return a new QueryResultWriter.
194- let mut row_writer = writer. start ( & column_def) . await ?;
195- while let Some ( record_batch) = query_result. stream . next ( ) . await {
196- match record_batch {
197- Ok ( record_batch) => {
198- if let Err ( e) = Self :: write_recordbatch (
199- & mut row_writer,
200- record_batch,
201- query_context. clone ( ) ,
202- & query_result. schema ,
203- )
204- . await
205- {
206- let ( kind, err) = handle_err ( e, query_context) ;
207- row_writer. finish_error ( kind, & err. as_bytes ( ) ) . await ?;
208- return Ok ( ( ) ) ;
209- }
210- }
211- Err ( e) => {
212- let ( kind, err) = handle_err ( e, query_context) ;
213- debug ! ( "Failed to get result, kind: {:?}, err: {}" , kind, err) ;
214- row_writer. finish_error ( kind, & err. as_bytes ( ) ) . await ?;
215-
216- return Ok ( ( ) ) ;
217- }
218- }
219- }
220- row_writer. finish ( ) . await ?;
221- Ok ( ( ) )
222- }
223- Err ( error) => Self :: write_query_error ( error, writer, query_context) . await ,
224- }
225- }
226-
227- async fn write_recordbatch (
228- row_writer : & mut RowWriter < ' _ , W > ,
186+ async fn write_recordbatch < W : AsyncWrite + Unpin > (
187+ row_writer : & mut RowWriter < ' _ , ' _ , W > ,
229188 record_batch : RecordBatch ,
230189 query_context : QueryContextRef ,
231- schema : & SchemaRef ,
232190 ) -> Result < ( ) > {
191+ let schema = record_batch. schema . clone ( ) ;
233192 let record_batch = record_batch. into_df_record_batch ( ) ;
234193 for i in 0 ..record_batch. num_rows ( ) {
235194 for ( j, column) in record_batch. columns ( ) . iter ( ) . enumerate ( ) {
@@ -358,7 +317,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
358317 Ok ( ( ) )
359318 }
360319
361- async fn write_query_error (
320+ async fn write_query_error < ' a , W : AsyncWrite + Unpin > (
362321 error : impl ErrorExt ,
363322 w : QueryResultWriter < ' a , W > ,
364323 query_context : QueryContextRef ,
0 commit comments