Skip to content

Commit 9a6b932

Browse files
authored
fix: EXIT parameter merge with current batch (#455)
1 parent 930c299 commit 9a6b932

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

pkg/sqlcmd/commands.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -182,28 +182,34 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error {
182182
if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") {
183183
return InvalidCommandError("EXIT", line)
184184
}
185-
// First we run the current batch
186-
query := s.batch.String()
187-
if query != "" {
188-
query = s.getRunnableQuery(query)
189-
if exitCode, err := s.runQuery(query); err != nil {
190-
s.Exitcode = exitCode
191-
return ErrExitRequested
192-
}
193-
}
194-
query = strings.TrimSpace(params[1 : len(params)-1])
195-
if len(query) > 0 {
196-
s.batch.Reset([]rune(query))
185+
// First we save the current batch
186+
query1 := s.batch.String()
187+
if len(query1) > 0 {
188+
query1 = s.getRunnableQuery(query1)
189+
}
190+
// Now parse the params of EXIT as a batch without commands
191+
cmd := s.batch.cmd
192+
s.batch.cmd = nil
193+
defer func() {
194+
s.batch.cmd = cmd
195+
}()
196+
query2 := strings.TrimSpace(params[1 : len(params)-1])
197+
if len(query2) > 0 {
198+
s.batch.Reset([]rune(query2))
197199
_, _, err := s.batch.Next()
198200
if err != nil {
199201
return err
200202
}
201-
query = s.batch.String()
202-
if s.batch.String() != "" {
203-
query = s.getRunnableQuery(query)
204-
s.Exitcode, _ = s.runQuery(query)
203+
query2 = s.batch.String()
204+
if len(query2) > 0 {
205+
query2 = s.getRunnableQuery(query2)
205206
}
206207
}
208+
209+
if len(query1) > 0 || len(query2) > 0 {
210+
query := query1 + SqlcmdEol + query2
211+
s.Exitcode, _ = s.runQuery(query)
212+
}
207213
return ErrExitRequested
208214
}
209215

pkg/sqlcmd/commands_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,23 @@ func TestEchoInput(t *testing.T) {
375375
assert.Equal(t, "set nocount on"+SqlcmdEol+"select 100"+SqlcmdEol+"100"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output with echo true")
376376
}
377377
}
378+
379+
func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) {
380+
s, buf := setupSqlCmdWithMemoryOutput(t)
381+
defer buf.Close()
382+
c := []string{"set nocount on", "declare @v integer = 2", "select 1", "exit(select @v)"}
383+
err := runSqlCmd(t, s, c)
384+
if assert.NoError(t, err, "exit should not error") {
385+
output := buf.buf.String()
386+
assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"2"+SqlcmdEol+SqlcmdEol, output, "Incorrect output")
387+
assert.Equal(t, 2, s.Exitcode, "exit should set Exitcode")
388+
}
389+
s, buf1 := setupSqlCmdWithMemoryOutput(t)
390+
defer buf1.Close()
391+
c = []string{"set nocount on", "select 1", "exit(select @v)"}
392+
err = runSqlCmd(t, s, c)
393+
if assert.NoError(t, err, "exit should not error") {
394+
assert.Equal(t, -101, s.Exitcode, "exit should not set Exitcode on script error")
395+
}
396+
397+
}

0 commit comments

Comments
 (0)