From 2de3b75220f66bd238bd15edc4ada7edac2aec46 Mon Sep 17 00:00:00 2001 From: HATATANI Shinta <803393+apstndb@users.noreply.github.com> Date: Sat, 20 May 2023 00:26:25 +0900 Subject: [PATCH 1/3] support DMLs in ExecuteStreamingSql --- server/server.go | 67 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/server/server.go b/server/server.go index dd1dd85..cf9ab36 100644 --- a/server/server.go +++ b/server/server.go @@ -63,6 +63,16 @@ type server struct { sessions map[string]*session } +type emptyRowIterator struct{} + +func (ri emptyRowIterator) ResultSet() []ResultItem { + return nil +} + +func (ri emptyRowIterator) Do(f func([]interface{}) error) error { + return nil +} + func (s *server) ApplyDDL(ctx context.Context, databaseName string, stmt ast.DDL) error { db, err := s.getOrCreateDatabase(databaseName) if err != nil { @@ -594,7 +604,7 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp Lexer: &parser.Lexer{ File: &token.File{FilePath: "", Buffer: req.Sql}, }, - }).ParseQuery() + }).ParseStatement() if err != nil { return status.Errorf(codes.InvalidArgument, "Syntax error: %q: %v", req.Sql, err) } @@ -618,19 +628,47 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp params[key] = v } - iter, err := session.database.Query(ctx, tx, stmt, params) - if err != nil { - if !tx.Available() { - return checkAvailability() + var iter RowIterator + var stats queryStats + switch stmt := stmt.(type) { + case *ast.QueryStatement: + iter, err = session.database.Query(ctx, tx, stmt, params) + if err != nil { + if !tx.Available() { + return checkAvailability() + } + return err + } + stats = queryStats{ + Mode: req.QueryMode, + ReceivedAt: receivedAt, + QueryText: req.Sql, + } + case ast.DML: + result, err := s.executeParsedDML(ctx, session, tx, stmt, req.GetParams(), req.GetParamTypes()) + if err != nil { + if !tx.Available() { + return checkAvailability() + } + return err } - return err - } - stats := queryStats{ - Mode: req.QueryMode, - ReceivedAt: receivedAt, - QueryText: req.Sql, + if txCreated { + result.Metadata = &spannerpb.ResultSetMetadata{ + Transaction: tx.Proto(), + } + } + + stats = queryStats{ + Mode: req.QueryMode, + ReceivedAt: receivedAt, + QueryText: req.Sql, + } + iter = emptyRowIterator{} + default: + return status.Errorf(codes.InvalidArgument, "Unknown query: %q", req.Sql) } + if err := sendResult(stream, tx, iter, txCreated, stats); err != nil { if !tx.Available() { return checkAvailability() @@ -728,8 +766,11 @@ func (s *server) executeDML(ctx context.Context, session *session, tx *transacti return nil, status.Errorf(codes.InvalidArgument, "%q is not valid DML: %v", stmt.Sql, err) } - fields := stmt.GetParams().GetFields() - paramTypes := stmt.ParamTypes + return s.executeParsedDML(ctx, session, tx, dml, stmt.GetParams(), stmt.GetParamTypes()) +} + +func (s *server) executeParsedDML(ctx context.Context, session *session, tx *transaction, dml ast.DML, paramStruct *structpb.Struct, paramTypes map[string]*spannerpb.Type) (*spannerpb.ResultSet, error) { + fields := paramStruct.GetFields() params := make(map[string]Value, len(fields)) defaultType := &spannerpb.Type{Code: spannerpb.TypeCode_INT64} for key, val := range fields { From 54205a1abdbe50abe256845dd50ffd86c3dd3865 Mon Sep 17 00:00:00 2001 From: HATATANI Shinta <803393+apstndb@users.noreply.github.com> Date: Sat, 20 May 2023 00:45:13 +0900 Subject: [PATCH 2/3] refactor ExecuteStreamingSql --- server/server.go | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/server/server.go b/server/server.go index cf9ab36..0860d47 100644 --- a/server/server.go +++ b/server/server.go @@ -633,42 +633,32 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp switch stmt := stmt.(type) { case *ast.QueryStatement: iter, err = session.database.Query(ctx, tx, stmt, params) - if err != nil { - if !tx.Available() { - return checkAvailability() - } - return err - } stats = queryStats{ Mode: req.QueryMode, ReceivedAt: receivedAt, QueryText: req.Sql, } case ast.DML: - result, err := s.executeParsedDML(ctx, session, tx, stmt, req.GetParams(), req.GetParamTypes()) - if err != nil { - if !tx.Available() { - return checkAvailability() - } - return err - } - - if txCreated { - result.Metadata = &spannerpb.ResultSetMetadata{ - Transaction: tx.Proto(), - } - } - + var result *spannerpb.ResultSet + result, err = s.executeParsedDML(ctx, session, tx, stmt, req.GetParams(), req.GetParamTypes()) stats = queryStats{ Mode: req.QueryMode, ReceivedAt: receivedAt, QueryText: req.Sql, + RowCount: result.GetStats().GetRowCountExact(), } iter = emptyRowIterator{} default: return status.Errorf(codes.InvalidArgument, "Unknown query: %q", req.Sql) } + if err != nil { + if !tx.Available() { + return checkAvailability() + } + return err + } + if err := sendResult(stream, tx, iter, txCreated, stats); err != nil { if !tx.Available() { return checkAvailability() @@ -924,7 +914,9 @@ func sendResult(stream spannerpb.Spanner_StreamingReadServer, tx *transaction, i return err } - qs.RowCount = rowCount + if qs.RowCount == 0 { + qs.RowCount = rowCount + } stats := &spannerpb.ResultSetStats{ QueryStats: createQueryStats(qs), } From 30cdc097dfaf23643db17be03877e76f3e9fb18d Mon Sep 17 00:00:00 2001 From: HATATANI Shinta <803393+apstndb@users.noreply.github.com> Date: Sat, 20 May 2023 00:58:57 +0900 Subject: [PATCH 3/3] refactor ExecuteStreamingSql --- server/server.go | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/server/server.go b/server/server.go index 0860d47..bf5c59f 100644 --- a/server/server.go +++ b/server/server.go @@ -628,25 +628,20 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp params[key] = v } + stats := queryStats{ + Mode: req.QueryMode, + ReceivedAt: receivedAt, + QueryText: req.Sql, + } + var iter RowIterator - var stats queryStats switch stmt := stmt.(type) { case *ast.QueryStatement: iter, err = session.database.Query(ctx, tx, stmt, params) - stats = queryStats{ - Mode: req.QueryMode, - ReceivedAt: receivedAt, - QueryText: req.Sql, - } case ast.DML: var result *spannerpb.ResultSet result, err = s.executeParsedDML(ctx, session, tx, stmt, req.GetParams(), req.GetParamTypes()) - stats = queryStats{ - Mode: req.QueryMode, - ReceivedAt: receivedAt, - QueryText: req.Sql, - RowCount: result.GetStats().GetRowCountExact(), - } + stats.RowCount = result.GetStats().GetRowCountExact() iter = emptyRowIterator{} default: return status.Errorf(codes.InvalidArgument, "Unknown query: %q", req.Sql)