diff --git a/server/server.go b/server/server.go index dd1dd85..bf5c59f 100644 --- a/server/server.go +++ b/server/server.go @@ -63,6 +63,16 @@ type server struct { sessions map[string]*session } +type emptyRowIterator struct{} + +func (ri emptyRowIterator) ResultSet() []ResultItem { + return nil +} + +func (ri emptyRowIterator) Do(f func([]interface{}) error) error { + return nil +} + func (s *server) ApplyDDL(ctx context.Context, databaseName string, stmt ast.DDL) error { db, err := s.getOrCreateDatabase(databaseName) if err != nil { @@ -594,7 +604,7 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp Lexer: &parser.Lexer{ File: &token.File{FilePath: "", Buffer: req.Sql}, }, - }).ParseQuery() + }).ParseStatement() if err != nil { return status.Errorf(codes.InvalidArgument, "Syntax error: %q: %v", req.Sql, err) } @@ -618,7 +628,25 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp params[key] = v } - iter, err := session.database.Query(ctx, tx, stmt, params) + stats := queryStats{ + Mode: req.QueryMode, + ReceivedAt: receivedAt, + QueryText: req.Sql, + } + + var iter RowIterator + switch stmt := stmt.(type) { + case *ast.QueryStatement: + iter, err = session.database.Query(ctx, tx, stmt, params) + case ast.DML: + var result *spannerpb.ResultSet + result, err = s.executeParsedDML(ctx, session, tx, stmt, req.GetParams(), req.GetParamTypes()) + stats.RowCount = result.GetStats().GetRowCountExact() + iter = emptyRowIterator{} + default: + return status.Errorf(codes.InvalidArgument, "Unknown query: %q", req.Sql) + } + if err != nil { if !tx.Available() { return checkAvailability() @@ -626,11 +654,6 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp return err } - stats := queryStats{ - Mode: req.QueryMode, - ReceivedAt: receivedAt, - QueryText: req.Sql, - } if err := sendResult(stream, tx, iter, txCreated, stats); err != nil { if !tx.Available() { return checkAvailability() @@ -728,8 +751,11 @@ func (s *server) executeDML(ctx context.Context, session *session, tx *transacti return nil, status.Errorf(codes.InvalidArgument, "%q is not valid DML: %v", stmt.Sql, err) } - fields := stmt.GetParams().GetFields() - paramTypes := stmt.ParamTypes + return s.executeParsedDML(ctx, session, tx, dml, stmt.GetParams(), stmt.GetParamTypes()) +} + +func (s *server) executeParsedDML(ctx context.Context, session *session, tx *transaction, dml ast.DML, paramStruct *structpb.Struct, paramTypes map[string]*spannerpb.Type) (*spannerpb.ResultSet, error) { + fields := paramStruct.GetFields() params := make(map[string]Value, len(fields)) defaultType := &spannerpb.Type{Code: spannerpb.TypeCode_INT64} for key, val := range fields { @@ -883,7 +909,9 @@ func sendResult(stream spannerpb.Spanner_StreamingReadServer, tx *transaction, i return err } - qs.RowCount = rowCount + if qs.RowCount == 0 { + qs.RowCount = rowCount + } stats := &spannerpb.ResultSetStats{ QueryStats: createQueryStats(qs), }