diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 9665871c..8cb8facc 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -8,6 +8,7 @@ import ( "os" "regexp" "sort" + "strconv" "strings" "github.com/alecthomas/kong" @@ -57,8 +58,7 @@ func newCommands() Commands { regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), action: errorCommand, name: "ERROR", - }, - "READFILE": { + }, "READFILE": { regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), action: readFileCommand, name: "READFILE", @@ -143,7 +143,13 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error { } } query = strings.TrimSpace(params[1 : len(params)-1]) - if query != "" { + s.batch.Reset([]rune(query)) + _, _, err := s.batch.Next() + if err != nil { + return err + } + query = s.batch.String() + if s.batch.String() != "" { query = s.getRunnableQuery(query) s.Exitcode, _ = s.runQuery(query) } @@ -239,7 +245,7 @@ func readFileCommand(s *Sqlcmd, args []string, line uint) error { if args == nil || len(args) != 1 { return InvalidCommandError(":R", line) } - return s.IncludeFile(args[0], false) + return s.IncludeFile(resolveArgumentVariables(s, []rune(args[0])), false) } // setVarCommand parses a variable setting and applies it to the current Sqlcmd variables @@ -313,12 +319,17 @@ type connectData struct { Database string `short:"D"` Username string `short:"U"` Password string `short:"P"` - LoginTimeout int `short:"l"` + LoginTimeout string `short:"l"` AuthenticationMethod string `short:"G"` } func connectCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || strings.TrimSpace(args[0]) == "" { + + if len(args) == 0 { + return InvalidCommandError("CONNECT", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { return InvalidCommandError("CONNECT", line) } arguments := &connectData{} @@ -326,16 +337,25 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error { if err != nil { return InvalidCommandError("CONNECT", line) } - if _, err = parser.Parse(strings.Split(args[0], " ")); err != nil { + + // Fields removes extra whitespace. + // Note :connect doesn't support passwords with spaces + if _, err = parser.Parse(strings.Fields(cmdLine)); err != nil { return InvalidCommandError("CONNECT", line) } connect := s.Connect - connect.UserName = arguments.Username - connect.Password = arguments.Password - connect.ServerName = arguments.Server - if arguments.LoginTimeout > 0 { - connect.LoginTimeoutSeconds = arguments.LoginTimeout + connect.UserName = resolveArgumentVariables(s, []rune(arguments.Username)) + connect.Password = resolveArgumentVariables(s, []rune(arguments.Password)) + connect.ServerName = resolveArgumentVariables(s, []rune(arguments.Server)) + timeout := resolveArgumentVariables(s, []rune(arguments.LoginTimeout)) + if timeout != "" { + if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { + if timeoutSeconds < 0 { + return InvalidCommandError("CONNECT", line) + } + connect.LoginTimeoutSeconds = int(timeoutSeconds) + } } connect.AuthenticationMethod = arguments.AuthenticationMethod // If no user name is provided we switch to integrated auth @@ -343,3 +363,47 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error { // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option return nil } + +func resolveArgumentVariables(s *Sqlcmd, arg []rune) string { + var b *strings.Builder + end := len(arg) + for i := 0; i < end; { + c, next := arg[i], grab(arg, i+1, end) + switch { + case c == '$' && next == '(': + vl, ok := readVariableReference(arg, i+2, end) + if ok { + varName := string(arg[i+2 : vl]) + val, ok := s.resolveVariable(varName) + if ok { + if b == nil { + b = new(strings.Builder) + b.Grow(len(arg)) + b.WriteString(string(arg[0:i])) + } + b.WriteString(val) + } else { + _, _ = s.GetError().Write([]byte(UndefinedVariable(varName).Error() + SqlcmdEol)) + if b != nil { + b.WriteString(string(arg[i : vl+1])) + } + } + i += ((vl - i) + 1) + } else { + if b != nil { + b.WriteString("$(") + } + i += 2 + } + default: + if b != nil { + b.WriteRune(c) + } + i++ + } + } + if b == nil { + return string(arg) + } + return b.String() +} diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 3a0b1382..90a32928 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "github.com/microsoft/go-mssqldb/azuread" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -44,6 +45,7 @@ func TestCommandParsing(t *testing.T) { {`:EXIT ( )`, "EXIT", []string{"( )"}}, {`EXIT `, "EXIT", []string{""}}, {`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}}, + {`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}}, } for _, test := range commands { @@ -156,7 +158,7 @@ func TestListCommand(t *testing.T) { } func TestConnectCommand(t *testing.T) { - s, _ := setupSqlCmdWithMemoryOutput(t) + s, buf := setupSqlCmdWithMemoryOutput(t) prompted := false s.lineIo = &testConsole{ OnPasswordPrompt: func(prompt string) ([]byte, error) { @@ -174,19 +176,26 @@ func TestConnectCommand(t *testing.T) { c := newConnect(t) authenticationMethod := "" - if c.Password == "" { - c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID") - c.Password = os.Getenv("AZURE_CLIENT_SECRET") - authenticationMethod = "-G ActiveDirectoryServicePrincipal" - if c.Password == "" { - t.Log("Not trying :Connect with valid password due to no password being available") - return - } - err = connectCommand(s, []string{fmt.Sprintf("%s -U %s -P %s %s", c.ServerName, c.UserName, c.Password, authenticationMethod)}, 3) - assert.NoError(t, err, "connectCommand with valid parameters should not return an error") + password := "" + username := "" + if canTestAzureAuth() { + authenticationMethod = "-G " + azuread.ActiveDirectoryDefault + } + if c.Password != "" { + password = "-P " + c.Password + } + if c.UserName != "" { + username = "-U " + c.UserName + } + s.vars.Set("servername", c.ServerName) + s.vars.Set("to", "111") + buf.buf.Reset() + err = connectCommand(s, []string{fmt.Sprintf("$(servername) %s %s %s -l $(to)", username, password, authenticationMethod)}, 3) + if assert.NoError(t, err, "connectCommand with valid parameters should not return an error") { // not using assert to avoid printing passwords in the log - if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password { - t.Fatal("After connect, sqlCmd.Connect is not updated") + assert.NotContains(t, buf.buf.String(), "$(servername)", "ConnectDB should have succeeded") + if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password || s.Connect.LoginTimeoutSeconds != 111 { + t.Fatalf("After connect, sqlCmd.Connect is not updated %+v", s.Connect) } } } @@ -212,3 +221,30 @@ func TestErrorCommand(t *testing.T) { assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents") } } + +func TestResolveArgumentVariables(t *testing.T) { + type argTest struct { + arg string + val string + err string + } + + args := []argTest{ + {"$(var1)", "var1val", ""}, + {"$(var1", "$(var1", ""}, + {`C:\folder\$(var1)\$(var2)\$(var1)\file.sql`, `C:\folder\var1val\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."}, + {`C:\folder\$(var1\$(var2)\$(var1)\file.sql`, `C:\folder\$(var1\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."}, + } + vars := InitializeVariables(false) + s := New(nil, "", vars) + s.vars.Set("var1", "var1val") + buf := &memoryBuffer{buf: new(bytes.Buffer)} + defer buf.Close() + s.SetError(buf) + for _, test := range args { + actual := resolveArgumentVariables(s, []rune(test.arg)) + assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg) + assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg) + buf.buf.Reset() + } +} diff --git a/pkg/sqlcmd/parse.go b/pkg/sqlcmd/parse.go index 0fc6a47c..e9192c78 100644 --- a/pkg/sqlcmd/parse.go +++ b/pkg/sqlcmd/parse.go @@ -58,7 +58,7 @@ func readCommand(c Commands, r []rune, i, end int) (*Command, []string, int) { return cmd, args, i } -// readVariableReference returns the length of the variable reference or false if it's not a valid identifier +// readVariableReference returns the index of the end of the variable reference or false if it's not a valid identifier func readVariableReference(r []rune, i int, end int) (int, bool) { for ; i < end; i++ { if r[i] == ')' { diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 60f78263..2799b87b 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -217,7 +217,8 @@ func TestGetRunnableQuery(t *testing.T) { func TestExitInitialQuery(t *testing.T) { s, buf := setupSqlCmdWithMemoryOutput(t) defer buf.Close() - s.Query = "EXIT(SELECT '1200', 2100)" + _ = s.vars.Setvar("var1", "1200") + s.Query = "EXIT(SELECT '$(var1)', 2100)" err := s.Run(true, false) if assert.NoError(t, err, "s.Run(once = true)") { s.SetOutput(nil)