diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 31504638..e18f537f 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -9,9 +9,12 @@ jobs: name: lint-pr-changes runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/setup-go@v3 + with: + go-version: 1.18 + - uses: actions/checkout@v3 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v3 with: - version: v1.42.0 + version: latest only-new-issues: true diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 214186f0..776b310b 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -49,3 +49,5 @@ steps: env: disable.coverage.autogenerate: 'true' + - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 + displayName: ‘Component Detection’ diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 8cb8facc..4af4164a 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -88,6 +88,11 @@ func newCommands() Commands { action: connectCommand, name: "CONNECT", }, + "EXEC": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`), + action: execCommand, + name: "EXEC", + }, } } @@ -172,6 +177,9 @@ func goCommand(s *Sqlcmd, args []string, line uint) error { if len(args) > 0 { cnt := strings.TrimSpace(args[0]) if cnt != "" { + if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { + return err + } _, err = fmt.Sscanf(cnt, "%d", &n) } } @@ -245,7 +253,8 @@ func readFileCommand(s *Sqlcmd, args []string, line uint) error { if args == nil || len(args) != 1 { return InvalidCommandError(":R", line) } - return s.IncludeFile(resolveArgumentVariables(s, []rune(args[0])), false) + fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) + return s.IncludeFile(fileName, false) } // setVarCommand parses a variable setting and applies it to the current Sqlcmd variables @@ -345,10 +354,10 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error { } connect := s.Connect - connect.UserName = resolveArgumentVariables(s, []rune(arguments.Username)) - connect.Password = resolveArgumentVariables(s, []rune(arguments.Password)) - connect.ServerName = resolveArgumentVariables(s, []rune(arguments.Server)) - timeout := resolveArgumentVariables(s, []rune(arguments.LoginTimeout)) + connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false) + connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false) + connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false) + timeout, _ := resolveArgumentVariables(s, []rune(arguments.LoginTimeout), false) if timeout != "" { if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { if timeoutSeconds < 0 { @@ -364,7 +373,26 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error { return nil } -func resolveArgumentVariables(s *Sqlcmd, arg []rune) string { +func execCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("EXEC", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { + return InvalidCommandError("EXEC", line) + } + if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { + return err + } else { + cmd := sysCommand(cmdLine) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + _ = cmd.Run() + } + return nil +} + +func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { var b *strings.Builder end := len(arg) for i := 0; i < end; { @@ -383,6 +411,9 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune) string { } b.WriteString(val) } else { + if failOnUnresolved { + return "", UndefinedVariable(varName) + } _, _ = s.GetError().Write([]byte(UndefinedVariable(varName).Error() + SqlcmdEol)) if b != nil { b.WriteString(string(arg[i : vl+1])) @@ -403,7 +434,7 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune) string { } } if b == nil { - return string(arg) + return string(arg), nil } - return b.String() + return b.String(), nil } diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 90a32928..a19d626b 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -46,6 +46,8 @@ func TestCommandParsing(t *testing.T) { {`EXIT `, "EXIT", []string{""}}, {`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}}, {`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}}, + {`:!! notepad`, "EXEC", []string{"notepad"}}, + {` !! dir c:\`, "EXEC", []string{`dir c:\`}}, } for _, test := range commands { @@ -242,9 +244,26 @@ func TestResolveArgumentVariables(t *testing.T) { defer buf.Close() s.SetError(buf) for _, test := range args { - actual := resolveArgumentVariables(s, []rune(test.arg)) + actual, _ := resolveArgumentVariables(s, []rune(test.arg), false) assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg) assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg) buf.buf.Reset() } + actual, err := resolveArgumentVariables(s, []rune("$(var1)$(var2)"), true) + if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") { + assert.Empty(t, actual, "fail on unresolved variable") + } +} + +func TestExecCommand(t *testing.T) { + vars := InitializeVariables(false) + s := New(nil, "", vars) + s.vars.Set("var1", "hello") + buf := &memoryBuffer{buf: new(bytes.Buffer)} + defer buf.Close() + s.SetOutput(buf) + err := execCommand(s, []string{`echo $(var1)`}, 1) + if assert.NoError(t, err, "execCommand with valid arguments") { + assert.Equal(t, buf.buf.String(), "hello"+SqlcmdEol, "echo output should be in sqlcmd output") + } } diff --git a/pkg/sqlcmd/exec_darwin.go b/pkg/sqlcmd/exec_darwin.go new file mode 100644 index 00000000..298c6a73 --- /dev/null +++ b/pkg/sqlcmd/exec_darwin.go @@ -0,0 +1,16 @@ +package sqlcmd + +import ( + "os/exec" +) + +func sysCommand(arg string) *exec.Cmd { + cmd := exec.Command(comSpec(), "-c", arg) + return cmd +} + +// comSpec returns the path of the command shell executable +func comSpec() string { + // /bin/sh will be a link to the shell + return `/bin/sh` +} diff --git a/pkg/sqlcmd/exec_linux.go b/pkg/sqlcmd/exec_linux.go new file mode 100644 index 00000000..298c6a73 --- /dev/null +++ b/pkg/sqlcmd/exec_linux.go @@ -0,0 +1,16 @@ +package sqlcmd + +import ( + "os/exec" +) + +func sysCommand(arg string) *exec.Cmd { + cmd := exec.Command(comSpec(), "-c", arg) + return cmd +} + +// comSpec returns the path of the command shell executable +func comSpec() string { + // /bin/sh will be a link to the shell + return `/bin/sh` +} diff --git a/pkg/sqlcmd/exec_windows.go b/pkg/sqlcmd/exec_windows.go new file mode 100644 index 00000000..b33ee45f --- /dev/null +++ b/pkg/sqlcmd/exec_windows.go @@ -0,0 +1,25 @@ +package sqlcmd + +import ( + "os" + "os/exec" + "syscall" +) + +func sysCommand(arg string) *exec.Cmd { + cmd := exec.Command(comSpec()) + cmd.SysProcAttr = &syscall.SysProcAttr{CmdLine: cmd.Path + " " + comArgs(arg)} + return cmd +} + +// comSpec returns the path of the command shell executable +func comSpec() string { + if cmd, ok := os.LookupEnv("COMSPEC"); ok { + return cmd + } + return `C:\Windows\System32\cmd.exe` +} + +func comArgs(args string) string { + return `/c ` + args +}