Skip to content

Commit c05a5ca

Browse files
MichaelScofieldWenyXu
authored andcommitted
fix: mysql prepare correctly returns error instead of panic (GreptimeTeam#7963)
feat: mysql writer support multiple statement execution Signed-off-by: luofucong <luofc@foxmail.com> Signed-off-by: WenyXu <wenymedia@gmail.com>
1 parent 467cd70 commit c05a5ca

9 files changed

Lines changed: 267 additions & 142 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/servers/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ datafusion-expr.workspace = true
6363
datafusion-pg-catalog.workspace = true
6464
datatypes.workspace = true
6565
derive_builder.workspace = true
66+
either.workspace = true
6667
futures.workspace = true
6768
futures-util.workspace = true
6869
headers = "0.4"
@@ -84,7 +85,8 @@ notify.workspace = true
8485
object-pool = "0.5"
8586
once_cell.workspace = true
8687
openmetrics-parser = "0.4"
87-
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", tag = "v0.10.0" }
88+
# Wait for https://github.com/databendlabs/opensrv/pull/81
89+
opensrv-mysql = { git = "https://github.com/GreptimeTeam/opensrv", rev = "6c5a451544194b7bb60a8318d155d4f892b49f2c" }
8890
opentelemetry-proto.workspace = true
8991
operator.workspace = true
9092
otel-arrow-rust.workspace = true

src/servers/src/mysql/handler.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,17 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
455455
let query_ctx = self.session.new_query_context();
456456
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
457457
let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
458-
let (params, columns) = self
458+
let (params, columns) = match self
459459
.do_prepare(raw_query, query_ctx.clone(), stmt_key)
460-
.await?;
460+
.await
461+
{
462+
Ok(x) => x,
463+
Err(e) => {
464+
let (kind, msg) = handle_err(e, query_ctx.clone());
465+
w.error(kind, msg.as_bytes()).await?;
466+
return Ok(());
467+
}
468+
};
461469
debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
462470
w.reply(stmt_id, &params, &columns).await?;
463471
crate::metrics::METRIC_MYSQL_PREPARED_COUNT

src/servers/src/mysql/writer.rs

Lines changed: 85 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ use crate::metrics::*;
4747

4848
/// Try to write multiple output to the writer if possible.
4949
pub async fn write_output<W: AsyncWrite + Send + Sync + Unpin>(
50-
w: QueryResultWriter<'_, W>,
50+
mut writer: QueryResultWriter<'_, W>,
5151
query_context: QueryContextRef,
5252
session: SessionRef,
5353
outputs: Vec<Result<Output>>,
@@ -56,21 +56,87 @@ pub async fn write_output<W: AsyncWrite + Send + Sync + Unpin>(
5656
session.add_warning(warning);
5757
}
5858

59-
let mut writer = Some(MysqlResultWriter::new(
60-
w,
61-
query_context.clone(),
62-
session.clone(),
63-
));
59+
enum Response {
60+
ResultSet {
61+
columns: Vec<Column>,
62+
stream: SendableRecordBatchStream,
63+
},
64+
AffectedRows(usize),
65+
}
66+
67+
let mut responses = Vec::with_capacity(outputs.len());
6468
for output in outputs {
65-
let result_writer = writer.take().context(error::InternalSnafu {
66-
err_msg: "Sending multiple result set is unsupported",
67-
})?;
68-
writer = result_writer.try_write_one(output).await?;
69+
match output {
70+
Ok(x) => {
71+
let output = match x.data {
72+
OutputData::Stream(stream) => either::Left(stream),
73+
OutputData::RecordBatches(record_batches) => {
74+
either::Left(record_batches.as_stream())
75+
}
76+
OutputData::AffectedRows(rows) => either::Right(rows),
77+
};
78+
responses.push(match output {
79+
either::Left(stream) => {
80+
let schema = stream.schema();
81+
let columns = match create_mysql_column_def(&schema) {
82+
Ok(columns) => columns,
83+
Err(e) => {
84+
MysqlResultWriter::write_query_error(
85+
e,
86+
writer,
87+
query_context.clone(),
88+
)
89+
.await?;
90+
return Ok(());
91+
}
92+
};
93+
Response::ResultSet { columns, stream }
94+
}
95+
either::Right(rows) => Response::AffectedRows(rows),
96+
});
97+
}
98+
Err(e) => {
99+
MysqlResultWriter::write_query_error(e, writer, query_context.clone()).await?;
100+
return Ok(());
101+
}
102+
}
69103
}
70104

71-
if let Some(result_writer) = writer {
72-
result_writer.finish().await?;
105+
for response in &mut responses {
106+
writer = match response {
107+
Response::ResultSet { columns, stream } => {
108+
let mut row_writer = writer.start(columns).await?;
109+
while let Some(record_batch) = stream.next().await {
110+
match record_batch {
111+
Ok(record_batch) => {
112+
if let Err(e) = MysqlResultWriter::write_recordbatch(
113+
&mut row_writer,
114+
record_batch,
115+
query_context.clone(),
116+
)
117+
.await
118+
{
119+
let (kind, err) = handle_err(e, query_context);
120+
row_writer.finish_error(kind, &err.as_bytes()).await?;
121+
return Ok(());
122+
}
123+
}
124+
Err(e) => {
125+
let (kind, err) = handle_err(e, query_context);
126+
row_writer.finish_error(kind, &err.as_bytes()).await?;
127+
return Ok(());
128+
}
129+
}
130+
}
131+
row_writer.finish_one().await?
132+
}
133+
Response::AffectedRows(rows) => {
134+
MysqlResultWriter::write_affected_rows(writer, *rows, &session).await?
135+
}
136+
}
73137
}
138+
139+
writer.no_more_results().await?;
74140
Ok(())
75141
}
76142

@@ -97,75 +163,10 @@ pub fn handle_err(e: impl ErrorExt, query_ctx: QueryContextRef) -> (ErrorKind, S
97163
(kind, err_msg)
98164
}
99165

100-
struct QueryResult {
101-
schema: SchemaRef,
102-
stream: SendableRecordBatchStream,
103-
}
104-
105-
pub struct MysqlResultWriter<'a, W: AsyncWrite + Unpin> {
106-
writer: QueryResultWriter<'a, W>,
107-
query_context: QueryContextRef,
108-
session: SessionRef,
109-
}
110-
111-
impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
112-
pub fn new(
113-
writer: QueryResultWriter<'a, W>,
114-
query_context: QueryContextRef,
115-
session: SessionRef,
116-
) -> MysqlResultWriter<'a, W> {
117-
MysqlResultWriter::<'a, W> {
118-
writer,
119-
query_context,
120-
session,
121-
}
122-
}
123-
124-
/// Try to write one result set. If there are more than one result set, return `Some`.
125-
pub async fn try_write_one(
126-
self,
127-
output: Result<Output>,
128-
) -> io::Result<Option<MysqlResultWriter<'a, W>>> {
129-
// We don't support sending multiple query result because the RowWriter's lifetime is bound to
130-
// a local variable.
131-
match output {
132-
Ok(output) => match output.data {
133-
OutputData::Stream(stream) => {
134-
let query_result = QueryResult {
135-
schema: stream.schema(),
136-
stream,
137-
};
138-
Self::write_query_result(query_result, self.writer, self.query_context).await?;
139-
}
140-
OutputData::RecordBatches(recordbatches) => {
141-
let query_result = QueryResult {
142-
schema: recordbatches.schema(),
143-
stream: recordbatches.as_stream(),
144-
};
145-
Self::write_query_result(query_result, self.writer, self.query_context).await?;
146-
}
147-
OutputData::AffectedRows(rows) => {
148-
let next_writer =
149-
Self::write_affected_rows(self.writer, rows, &self.session).await?;
150-
return Ok(Some(MysqlResultWriter::new(
151-
next_writer,
152-
self.query_context,
153-
self.session,
154-
)));
155-
}
156-
},
157-
Err(error) => Self::write_query_error(error, self.writer, self.query_context).await?,
158-
}
159-
Ok(None)
160-
}
166+
struct MysqlResultWriter;
161167

162-
/// Indicate no more result set to write. No need to call this if there is only one result set.
163-
pub async fn finish(self) -> Result<()> {
164-
self.writer.no_more_results().await?;
165-
Ok(())
166-
}
167-
168-
async fn write_affected_rows(
168+
impl MysqlResultWriter {
169+
async fn write_affected_rows<'a, W: AsyncWrite + Unpin>(
169170
w: QueryResultWriter<'a, W>,
170171
rows: usize,
171172
session: &SessionRef,
@@ -182,54 +183,12 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
182183
Ok(next_writer)
183184
}
184185

185-
async fn write_query_result(
186-
mut query_result: QueryResult,
187-
writer: QueryResultWriter<'a, W>,
188-
query_context: QueryContextRef,
189-
) -> io::Result<()> {
190-
match create_mysql_column_def(&query_result.schema) {
191-
Ok(column_def) => {
192-
// The RowWriter's lifetime is bound to `column_def` thus we can't use finish_one()
193-
// to return a new QueryResultWriter.
194-
let mut row_writer = writer.start(&column_def).await?;
195-
while let Some(record_batch) = query_result.stream.next().await {
196-
match record_batch {
197-
Ok(record_batch) => {
198-
if let Err(e) = Self::write_recordbatch(
199-
&mut row_writer,
200-
record_batch,
201-
query_context.clone(),
202-
&query_result.schema,
203-
)
204-
.await
205-
{
206-
let (kind, err) = handle_err(e, query_context);
207-
row_writer.finish_error(kind, &err.as_bytes()).await?;
208-
return Ok(());
209-
}
210-
}
211-
Err(e) => {
212-
let (kind, err) = handle_err(e, query_context);
213-
debug!("Failed to get result, kind: {:?}, err: {}", kind, err);
214-
row_writer.finish_error(kind, &err.as_bytes()).await?;
215-
216-
return Ok(());
217-
}
218-
}
219-
}
220-
row_writer.finish().await?;
221-
Ok(())
222-
}
223-
Err(error) => Self::write_query_error(error, writer, query_context).await,
224-
}
225-
}
226-
227-
async fn write_recordbatch(
228-
row_writer: &mut RowWriter<'_, W>,
186+
async fn write_recordbatch<W: AsyncWrite + Unpin>(
187+
row_writer: &mut RowWriter<'_, '_, W>,
229188
record_batch: RecordBatch,
230189
query_context: QueryContextRef,
231-
schema: &SchemaRef,
232190
) -> Result<()> {
191+
let schema = record_batch.schema.clone();
233192
let record_batch = record_batch.into_df_record_batch();
234193
for i in 0..record_batch.num_rows() {
235194
for (j, column) in record_batch.columns().iter().enumerate() {
@@ -358,7 +317,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
358317
Ok(())
359318
}
360319

361-
async fn write_query_error(
320+
async fn write_query_error<'a, W: AsyncWrite + Unpin>(
362321
error: impl ErrorExt,
363322
w: QueryResultWriter<'a, W>,
364323
query_context: QueryContextRef,

src/servers/tests/mod.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,21 @@ use api::v1::query_request::Query;
1919
use async_trait::async_trait;
2020
use catalog::memory::MemoryCatalogManager;
2121
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
22+
use common_error::ext::BoxedError;
2223
use common_grpc::flight::do_put::DoPutResponse;
2324
use common_query::Output;
2425
use datafusion_expr::LogicalPlan;
26+
use futures_util::TryFutureExt;
2527
use query::options::QueryOptions;
26-
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
28+
use query::parser::{PromQuery, QueryStatement};
2729
use query::query_engine::DescribeResult;
2830
use query::{QueryEngineFactory, QueryEngineRef};
29-
use servers::error::{NotSupportedSnafu, Result};
31+
use servers::error::{ExecuteQuerySnafu, NotSupportedSnafu, Result};
3032
use servers::query_handler::grpc::GrpcQueryHandler;
3133
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
3234
use session::context::QueryContextRef;
33-
use snafu::ensure;
35+
use snafu::{ResultExt, ensure};
36+
use sql::parser::{ParseOptions, ParserContext};
3437
use sql::statements::statement::Statement;
3538
use table::TableRef;
3639

@@ -52,15 +55,29 @@ impl DummyInstance {
5255
#[async_trait]
5356
impl SqlQueryHandler for DummyInstance {
5457
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
55-
let stmt = QueryLanguageParser::parse_sql(query, &query_ctx).unwrap();
56-
let plan = self
57-
.query_engine
58-
.planner()
59-
.plan(&stmt, query_ctx.clone())
60-
.await
61-
.unwrap();
62-
let output = self.query_engine.execute(plan, query_ctx).await.unwrap();
63-
vec![Ok(output)]
58+
let mut results = vec![];
59+
60+
let statements = ParserContext::create_with_dialect(
61+
query,
62+
query_ctx.sql_dialect(),
63+
ParseOptions::default(),
64+
)
65+
.map(|x| x.into_iter().map(QueryStatement::Sql).collect::<Vec<_>>())
66+
.unwrap();
67+
68+
for statement in &statements {
69+
let result = self
70+
.query_engine
71+
.planner()
72+
.plan(statement, query_ctx.clone())
73+
.and_then(|plan| self.query_engine.execute(plan, query_ctx.clone()))
74+
.await
75+
.map_err(BoxedError::new)
76+
.context(ExecuteQuerySnafu);
77+
results.push(result);
78+
}
79+
80+
results
6481
}
6582

6683
async fn do_exec_plan(

src/servers/tests/mysql/mysql_server_test.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,17 @@ async fn test_query_prepared() -> Result<()> {
505505

506506
test_prepare_all_type(column_schemas, columns, &mut connection).await;
507507

508+
match connection
509+
.prep("SELECT `timestamp` FROM t WHERE `timestamp` > NOW() - INTERVAL '1 hour'")
510+
.await
511+
{
512+
Err(mysql_async::Error::Server(e)) => assert_eq!(
513+
"ERROR HY000 (1210): (InvalidArguments): Invalid prepare statement: Invalid SQL syntax: sql parser error: INTERVAL requires a unit after the literal value",
514+
e.to_string()
515+
),
516+
_ => unreachable!(),
517+
}
518+
508519
Ok(())
509520
}
510521

0 commit comments

Comments
 (0)