Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ datafusion-expr.workspace = true
datafusion-pg-catalog.workspace = true
datatypes.workspace = true
derive_builder.workspace = true
either.workspace = true
futures.workspace = true
futures-util.workspace = true
headers = "0.4"
Expand All @@ -84,7 +85,8 @@ notify.workspace = true
object-pool = "0.5"
once_cell.workspace = true
openmetrics-parser = "0.4"
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", tag = "v0.10.0" }
# Wait for https://github.com/databendlabs/opensrv/pull/81
opensrv-mysql = { git = "https://github.com/GreptimeTeam/opensrv", rev = "6c5a451544194b7bb60a8318d155d4f892b49f2c" }
opentelemetry-proto.workspace = true
operator.workspace = true
otel-arrow-rust.workspace = true
Expand Down
12 changes: 10 additions & 2 deletions src/servers/src/mysql/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,17 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
let query_ctx = self.session.new_query_context();
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
let (params, columns) = self
let (params, columns) = match self
.do_prepare(raw_query, query_ctx.clone(), stmt_key)
.await?;
.await
{
Ok(x) => x,
Err(e) => {
let (kind, msg) = handle_err(e, query_ctx.clone());
w.error(kind, msg.as_bytes()).await?;
return Ok(());
}
};
debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
w.reply(stmt_id, &params, &columns).await?;
crate::metrics::METRIC_MYSQL_PREPARED_COUNT
Expand Down
211 changes: 85 additions & 126 deletions src/servers/src/mysql/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::metrics::*;

/// Try to write multiple output to the writer if possible.
pub async fn write_output<W: AsyncWrite + Send + Sync + Unpin>(
w: QueryResultWriter<'_, W>,
mut writer: QueryResultWriter<'_, W>,
query_context: QueryContextRef,
session: SessionRef,
outputs: Vec<Result<Output>>,
Expand All @@ -56,21 +56,87 @@ pub async fn write_output<W: AsyncWrite + Send + Sync + Unpin>(
session.add_warning(warning);
}

let mut writer = Some(MysqlResultWriter::new(
w,
query_context.clone(),
session.clone(),
));
enum Response {
ResultSet {
columns: Vec<Column>,
stream: SendableRecordBatchStream,
},
AffectedRows(usize),
}

let mut responses = Vec::with_capacity(outputs.len());
for output in outputs {
let result_writer = writer.take().context(error::InternalSnafu {
err_msg: "Sending multiple result set is unsupported",
})?;
writer = result_writer.try_write_one(output).await?;
match output {
Ok(x) => {
let output = match x.data {
OutputData::Stream(stream) => either::Left(stream),
OutputData::RecordBatches(record_batches) => {
either::Left(record_batches.as_stream())
}
OutputData::AffectedRows(rows) => either::Right(rows),
};
responses.push(match output {
either::Left(stream) => {
let schema = stream.schema();
let columns = match create_mysql_column_def(&schema) {
Ok(columns) => columns,
Err(e) => {
MysqlResultWriter::write_query_error(
e,
writer,
query_context.clone(),
)
.await?;
return Ok(());
}
};
Response::ResultSet { columns, stream }
}
either::Right(rows) => Response::AffectedRows(rows),
});
}
Err(e) => {
MysqlResultWriter::write_query_error(e, writer, query_context.clone()).await?;
return Ok(());
Comment thread
MichaelScofield marked this conversation as resolved.
}
}
}

if let Some(result_writer) = writer {
result_writer.finish().await?;
for response in &mut responses {
writer = match response {
Response::ResultSet { columns, stream } => {
let mut row_writer = writer.start(columns).await?;
while let Some(record_batch) = stream.next().await {
match record_batch {
Ok(record_batch) => {
if let Err(e) = MysqlResultWriter::write_recordbatch(
&mut row_writer,
record_batch,
query_context.clone(),
)
.await
{
let (kind, err) = handle_err(e, query_context);
row_writer.finish_error(kind, &err.as_bytes()).await?;
return Ok(());
Comment thread
sunng87 marked this conversation as resolved.
}
}
Err(e) => {
let (kind, err) = handle_err(e, query_context);
row_writer.finish_error(kind, &err.as_bytes()).await?;
return Ok(());
}
}
}
row_writer.finish_one().await?
}
Response::AffectedRows(rows) => {
MysqlResultWriter::write_affected_rows(writer, *rows, &session).await?
}
}
}

writer.no_more_results().await?;
Ok(())
}

Expand All @@ -97,75 +163,10 @@ pub fn handle_err(e: impl ErrorExt, query_ctx: QueryContextRef) -> (ErrorKind, S
(kind, err_msg)
}

struct QueryResult {
schema: SchemaRef,
stream: SendableRecordBatchStream,
}

pub struct MysqlResultWriter<'a, W: AsyncWrite + Unpin> {
writer: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
session: SessionRef,
}

impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
pub fn new(
writer: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
session: SessionRef,
) -> MysqlResultWriter<'a, W> {
MysqlResultWriter::<'a, W> {
writer,
query_context,
session,
}
}

/// Try to write one result set. If there are more than one result set, return `Some`.
pub async fn try_write_one(
self,
output: Result<Output>,
) -> io::Result<Option<MysqlResultWriter<'a, W>>> {
// We don't support sending multiple query result because the RowWriter's lifetime is bound to
// a local variable.
match output {
Ok(output) => match output.data {
OutputData::Stream(stream) => {
let query_result = QueryResult {
schema: stream.schema(),
stream,
};
Self::write_query_result(query_result, self.writer, self.query_context).await?;
}
OutputData::RecordBatches(recordbatches) => {
let query_result = QueryResult {
schema: recordbatches.schema(),
stream: recordbatches.as_stream(),
};
Self::write_query_result(query_result, self.writer, self.query_context).await?;
}
OutputData::AffectedRows(rows) => {
let next_writer =
Self::write_affected_rows(self.writer, rows, &self.session).await?;
return Ok(Some(MysqlResultWriter::new(
next_writer,
self.query_context,
self.session,
)));
}
},
Err(error) => Self::write_query_error(error, self.writer, self.query_context).await?,
}
Ok(None)
}
struct MysqlResultWriter;

/// Indicate no more result set to write. No need to call this if there is only one result set.
pub async fn finish(self) -> Result<()> {
self.writer.no_more_results().await?;
Ok(())
}

async fn write_affected_rows(
impl MysqlResultWriter {
async fn write_affected_rows<'a, W: AsyncWrite + Unpin>(
w: QueryResultWriter<'a, W>,
rows: usize,
session: &SessionRef,
Expand All @@ -182,54 +183,12 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
Ok(next_writer)
}

async fn write_query_result(
mut query_result: QueryResult,
writer: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
) -> io::Result<()> {
match create_mysql_column_def(&query_result.schema) {
Ok(column_def) => {
// The RowWriter's lifetime is bound to `column_def` thus we can't use finish_one()
// to return a new QueryResultWriter.
let mut row_writer = writer.start(&column_def).await?;
while let Some(record_batch) = query_result.stream.next().await {
match record_batch {
Ok(record_batch) => {
if let Err(e) = Self::write_recordbatch(
&mut row_writer,
record_batch,
query_context.clone(),
&query_result.schema,
)
.await
{
let (kind, err) = handle_err(e, query_context);
row_writer.finish_error(kind, &err.as_bytes()).await?;
return Ok(());
}
}
Err(e) => {
let (kind, err) = handle_err(e, query_context);
debug!("Failed to get result, kind: {:?}, err: {}", kind, err);
row_writer.finish_error(kind, &err.as_bytes()).await?;

return Ok(());
}
}
}
row_writer.finish().await?;
Ok(())
}
Err(error) => Self::write_query_error(error, writer, query_context).await,
}
}

async fn write_recordbatch(
row_writer: &mut RowWriter<'_, W>,
async fn write_recordbatch<W: AsyncWrite + Unpin>(
row_writer: &mut RowWriter<'_, '_, W>,
record_batch: RecordBatch,
query_context: QueryContextRef,
schema: &SchemaRef,
) -> Result<()> {
let schema = record_batch.schema.clone();
let record_batch = record_batch.into_df_record_batch();
for i in 0..record_batch.num_rows() {
for (j, column) in record_batch.columns().iter().enumerate() {
Expand Down Expand Up @@ -358,7 +317,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
Ok(())
}

async fn write_query_error(
async fn write_query_error<'a, W: AsyncWrite + Unpin>(
error: impl ErrorExt,
w: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
Expand Down
41 changes: 29 additions & 12 deletions src/servers/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@ use api::v1::query_request::Query;
use async_trait::async_trait;
use catalog::memory::MemoryCatalogManager;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::ext::BoxedError;
use common_grpc::flight::do_put::DoPutResponse;
use common_query::Output;
use datafusion_expr::LogicalPlan;
use futures_util::TryFutureExt;
use query::options::QueryOptions;
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
use query::parser::{PromQuery, QueryStatement};
use query::query_engine::DescribeResult;
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error::{NotSupportedSnafu, Result};
use servers::error::{ExecuteQuerySnafu, NotSupportedSnafu, Result};
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
use session::context::QueryContextRef;
use snafu::ensure;
use snafu::{ResultExt, ensure};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use table::TableRef;

Expand All @@ -52,15 +55,29 @@ impl DummyInstance {
#[async_trait]
impl SqlQueryHandler for DummyInstance {
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
let stmt = QueryLanguageParser::parse_sql(query, &query_ctx).unwrap();
let plan = self
.query_engine
.planner()
.plan(&stmt, query_ctx.clone())
.await
.unwrap();
let output = self.query_engine.execute(plan, query_ctx).await.unwrap();
vec![Ok(output)]
let mut results = vec![];

let statements = ParserContext::create_with_dialect(
query,
query_ctx.sql_dialect(),
ParseOptions::default(),
)
.map(|x| x.into_iter().map(QueryStatement::Sql).collect::<Vec<_>>())
.unwrap();

for statement in &statements {
let result = self
.query_engine
.planner()
.plan(statement, query_ctx.clone())
.and_then(|plan| self.query_engine.execute(plan, query_ctx.clone()))
.await
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu);
results.push(result);
}

results
}

async fn do_exec_plan(
Expand Down
11 changes: 11 additions & 0 deletions src/servers/tests/mysql/mysql_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,17 @@ async fn test_query_prepared() -> Result<()> {

test_prepare_all_type(column_schemas, columns, &mut connection).await;

match connection
.prep("SELECT `timestamp` FROM t WHERE `timestamp` > NOW() - INTERVAL '1 hour'")
.await
{
Err(mysql_async::Error::Server(e)) => assert_eq!(
"ERROR HY000 (1210): (InvalidArguments): Invalid prepare statement: Invalid SQL syntax: sql parser error: INTERVAL requires a unit after the literal value",
e.to_string()
),
_ => unreachable!(),
}

Ok(())
}

Expand Down
Loading
Loading