Skip to content

Commit b9675a4

Browse files
authored
Fix issues with batch and command parsing (#201)
* fix issues with batch parsing * fix bracketed identifier parsing * fix exit regex
1 parent 90ba89e commit b9675a4

File tree

8 files changed

+89
-36
lines changed

8 files changed

+89
-36
lines changed

pkg/sqlcmd/batch.go

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ parse:
109109
i, ok = readMultilineComment(b.raw, i, b.rawlen)
110110
b.comment = !ok
111111
// start of a string
112-
case c == '\'' || c == '"':
112+
case c == '\'' || c == '"' || c == '[':
113113
b.quote = c
114114
// inline sql comment, skip to end of line
115115
case c == '-' && next == '-':
@@ -145,25 +145,24 @@ parse:
145145
}
146146
}
147147
if err == nil {
148-
i = min(i, b.rawlen)
149-
empty := isEmptyLine(b.raw, 0, i)
150-
appendLine := true
151-
if !b.comment && command != nil && empty {
152-
appendLine = false
153-
}
154-
if appendLine {
155-
// any variables on the line need to be added to the global map
156-
inc := 0
157-
if b.Length > 0 {
158-
inc = len(lineend)
159-
}
160-
if b.linevarmap != nil {
161-
for v := range b.linevarmap {
162-
b.varmap[v+b.Length+inc] = b.linevarmap[v]
148+
if command == nil {
149+
i = min(i, b.rawlen)
150+
empty := i == 0
151+
appendLine := !empty || b.comment || b.quote != 0
152+
if appendLine {
153+
// any variables on the line need to be added to the global map
154+
inc := 0
155+
if b.Length > 0 {
156+
inc = len(lineend)
157+
}
158+
if b.linevarmap != nil {
159+
for v := range b.linevarmap {
160+
b.varmap[v+b.Length+inc] = b.linevarmap[v]
161+
}
163162
}
163+
// log.Printf(">> appending: `%s`", string(r[st:i]))
164+
b.append(b.raw[:i], lineend)
164165
}
165-
// log.Printf(">> appending: `%s`", string(r[st:i]))
166-
b.append(b.raw[:i], lineend)
167166
b.batchline++
168167
}
169168
b.raw = b.raw[i:]
@@ -242,11 +241,13 @@ func (b *Batch) readString(r []rune, i, end int, quote rune, line uint) (int, bo
242241
} else {
243242
return i, false, syntaxError(line)
244243
}
245-
case quote == '\'' && c == '\'' && next == '\'':
244+
case quote == '\'' && c == '\'' && next == '\'',
245+
quote == '[' && c == ']' && next == ']':
246246
i++
247247
continue
248248
case quote == '\'' && c == '\'' && prev != '\'',
249-
quote == '"' && c == '"':
249+
quote == '"' && c == '"',
250+
quote == '[' && c == ']':
250251
return i, true, nil
251252
}
252253
prev = c

pkg/sqlcmd/batch_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ func TestBatchNext(t *testing.T) {
3333
{"select 1\n:exit()", []string{"select 1"}, []string{"EXIT"}, "-"},
3434
{"select 1\n:exit (select 10)", []string{"select 1"}, []string{"EXIT"}, "-"},
3535
{"select 1\n:exit", []string{"select 1"}, []string{"EXIT"}, "-"},
36+
{"select [a'b] = 'c'", []string{"select [a'b] = 'c'"}, nil, "-"},
37+
{"select [bracket", []string{"select [bracket"}, nil, "["},
38+
{"select [bracket]]a]", []string{"select [bracket]]a]"}, nil, "-"},
39+
{"exit_1", []string{"exit_1"}, nil, "-"},
3640
}
3741
for _, test := range tests {
3842
b := NewBatch(sp(test.s, "\n"), newCommands())

pkg/sqlcmd/commands.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func newCommands() Commands {
3737
// Commands is the set of Command implementations
3838
return map[string]*Command{
3939
"EXIT": {
40-
regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT(?:[ \t]*(\(?.*\)?$)|$)`),
40+
regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`),
4141
action: exitCommand,
4242
name: "EXIT",
4343
},
@@ -186,15 +186,17 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error {
186186
}
187187
}
188188
query = strings.TrimSpace(params[1 : len(params)-1])
189-
s.batch.Reset([]rune(query))
190-
_, _, err := s.batch.Next()
191-
if err != nil {
192-
return err
193-
}
194-
query = s.batch.String()
195-
if s.batch.String() != "" {
196-
query = s.getRunnableQuery(query)
197-
s.Exitcode, _ = s.runQuery(query)
189+
if len(query) > 0 {
190+
s.batch.Reset([]rune(query))
191+
_, _, err := s.batch.Next()
192+
if err != nil {
193+
return err
194+
}
195+
query = s.batch.String()
196+
if s.batch.String() != "" {
197+
query = s.getRunnableQuery(query)
198+
s.Exitcode, _ = s.runQuery(query)
199+
}
198200
}
199201
return ErrExitRequested
200202
}

pkg/sqlcmd/commands_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ func TestCommandParsing(t *testing.T) {
4141
{` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}},
4242
{`:Setvar A1 "some value" `, "SETVAR", []string{`A1 "some value" `}},
4343
{` :Listvar`, "LISTVAR", []string{""}},
44-
{`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}},
45-
{`:EXIT ( )`, "EXIT", []string{"( )"}},
46-
{`EXIT `, "EXIT", []string{""}},
44+
{`:EXIT (select 100 as count)`, "EXIT", []string{" (select 100 as count)"}},
45+
{`:EXIT ( )`, "EXIT", []string{" ( )"}},
46+
{`EXIT `, "EXIT", []string{" "}},
4747
{`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}},
4848
{`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}},
4949
{`:!! notepad`, "EXEC", []string{" notepad"}},

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ func TestConnectionStringFromSqlCmd(t *testing.T) {
6767
}
6868
}
6969

70-
/* The following tests require a working SQL instance and rely on SqlCmd environment variables
70+
/*
71+
The following tests require a working SQL instance and rely on SqlCmd environment variables
72+
7173
to manage the initial connection string. The default connection when no environment variables are
7274
set will be to localhost using Windows auth.
73-
7475
*/
7576
func TestSqlCmdConnectDb(t *testing.T) {
7677
v := InitializeVariables(true)
@@ -185,6 +186,34 @@ func TestIncludeFileWithVariables(t *testing.T) {
185186
}
186187
}
187188

189+
func TestIncludeFileMultilineString(t *testing.T) {
190+
s, buf := setupSqlCmdWithMemoryOutput(t)
191+
defer buf.Close()
192+
dataPath := "testdata" + string(os.PathSeparator)
193+
err := s.IncludeFile(dataPath+"blanks.sql", true)
194+
if assert.NoError(t, err, "IncludeFile blanks.sql true") {
195+
assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile blanks.sql true")
196+
assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile blanks.sql true")
197+
s.SetOutput(nil)
198+
o := buf.buf.String()
199+
assert.Equal(t, "line 1"+SqlcmdEol+SqlcmdEol+SqlcmdEol+SqlcmdEol+"line2"+SqlcmdEol+SqlcmdEol, o)
200+
}
201+
}
202+
203+
func TestIncludeFileQuotedIdentifiers(t *testing.T) {
204+
s, buf := setupSqlCmdWithMemoryOutput(t)
205+
defer buf.Close()
206+
dataPath := "testdata" + string(os.PathSeparator)
207+
err := s.IncludeFile(dataPath+"quotedidentifiers.sql", true)
208+
if assert.NoError(t, err, "IncludeFile quotedidentifiers.sql true") {
209+
assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile quotedidentifiers.sql true")
210+
assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile quotedidentifiers.sql true")
211+
s.SetOutput(nil)
212+
o := buf.buf.String()
213+
assert.Equal(t, `ab 1 a"b`+SqlcmdEol+SqlcmdEol, o)
214+
}
215+
}
216+
188217
func TestGetRunnableQuery(t *testing.T) {
189218
v := InitializeVariables(false)
190219
v.Set("var1", "v1")

pkg/sqlcmd/testdata/blanks.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
set nocount on
2+
:setvar l line2
3+
select 'line 1
4+
5+
6+
7+
$(l)'
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set nocount on
2+
set quoted_identifier on
3+
select [a]]b] = 'ab', "a'b" = 1, [a"b] = 'a"b'
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
set nocount on
22
:setvar hundred 100
3-
-- comment
3+
4+
-- verify fix for https://github.com/microsoft/go-sqlcmd/issues/197
5+
6+
-- Correctly handle the first line of a batch having a variable after an empty line
7+
8+
GO
9+
410
select $(hundred)
511

12+
GO

0 commit comments

Comments
 (0)