Skip to content

Commit fe6b80a

Browse files
Initialize console if password is required (#121)
In absence of console or when console reference line is set to nil or noPwd is set to true, password prompt is disabled and when console is not initialized, password cannot be read from user. This is why when -i is specified, console is not initialized and we do not see password prompt. This commit initializes console even if password is required so that password can be read interactively from the user.
1 parent 8116d58 commit fe6b80a

File tree

7 files changed

+74
-16
lines changed

7 files changed

+74
-16
lines changed

cmd/sqlcmd/main.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,22 +201,28 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
201201
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
202202
}
203203

204+
func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool {
205+
iactive := args.InputFile == nil && args.Query == ""
206+
return iactive || connect.RequiresPassword()
207+
}
208+
204209
func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
205210
wd, err := os.Getwd()
206211
if err != nil {
207212
return 1, err
208213
}
209214

210-
iactive := args.InputFile == nil && args.Query == ""
215+
var connectConfig sqlcmd.ConnectSettings
216+
setConnect(&connectConfig, args, vars)
211217
var line sqlcmd.Console = nil
212-
if iactive {
218+
if isConsoleInitializationRequired(&connectConfig, args) {
213219
line = console.NewConsole("")
214220
defer line.Close()
215221
}
216222

217223
s := sqlcmd.New(line, wd, vars)
218224
s.UnicodeOutputFile = args.UnicodeOutputFile
219-
setConnect(&s.Connect, args, vars)
225+
220226
if args.BatchTerminator != "GO" {
221227
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
222228
if err != nil {
@@ -227,7 +233,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
227233
return 1, err
228234
}
229235

230-
setConnect(&s.Connect, args, vars)
236+
s.Connect = &connectConfig
231237
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
232238
if args.OutputFile != "" {
233239
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
@@ -257,10 +263,12 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
257263
s.Query = args.Query
258264
}
259265
// connect using no overrides
260-
err = s.ConnectDb(nil, !iactive)
266+
err = s.ConnectDb(nil, line == nil)
261267
if err != nil {
262268
return 1, err
263269
}
270+
271+
iactive := args.InputFile == nil && args.Query == ""
264272
if iactive || s.Query != "" {
265273
err = s.Run(once, false)
266274
} else {

cmd/sqlcmd/main_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/alecthomas/kong"
12+
"github.com/microsoft/go-mssqldb/azuread"
1213
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
@@ -327,6 +328,54 @@ func TestMissingInputFile(t *testing.T) {
327328
assert.Equal(t, 1, exitCode, "exitCode")
328329
}
329330

331+
func TestConditionsForPasswordPrompt(t *testing.T) {
332+
333+
type test struct {
334+
authenticationMethod string
335+
inputFile []string
336+
username string
337+
pwd string
338+
expectedResult bool
339+
}
340+
tests := []test{
341+
// Positive Testcases
342+
{sqlcmd.SqlPassword, []string{""}, "someuser", "", true},
343+
{sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "someuser", "", true},
344+
{azuread.ActiveDirectoryPassword, []string{""}, "someuser", "", true},
345+
{azuread.ActiveDirectoryPassword, []string{"testdata/someFile.sql"}, "someuser", "", true},
346+
{azuread.ActiveDirectoryServicePrincipal, []string{""}, "someuser", "", true},
347+
{azuread.ActiveDirectoryServicePrincipal, []string{"testdata/someFile.sql"}, "someuser", "", true},
348+
{azuread.ActiveDirectoryApplication, []string{""}, "someuser", "", true},
349+
{azuread.ActiveDirectoryApplication, []string{"testdata/someFile.sql"}, "someuser", "", true},
350+
351+
//Negative Testcases
352+
{sqlcmd.NotSpecified, []string{""}, "", "", false},
353+
{sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "", "", false},
354+
{azuread.ActiveDirectoryDefault, []string{""}, "someuser", "", false},
355+
{azuread.ActiveDirectoryDefault, []string{"testdata/someFile.sql"}, "someuser", "", false},
356+
{azuread.ActiveDirectoryInteractive, []string{""}, "someuser", "", false},
357+
{azuread.ActiveDirectoryInteractive, []string{"testdata/someFile.sql"}, "someuser", "", false},
358+
{azuread.ActiveDirectoryManagedIdentity, []string{""}, "someuser", "", false},
359+
{azuread.ActiveDirectoryManagedIdentity, []string{"testdata/someFile.sql"}, "someuser", "", false},
360+
}
361+
362+
for _, testcase := range tests {
363+
t.Log(testcase.authenticationMethod, testcase.inputFile, testcase.username, testcase.pwd, testcase.expectedResult)
364+
args := newArguments()
365+
args.DisableCmdAndWarn = true
366+
args.InputFile = testcase.inputFile
367+
args.UserName = testcase.username
368+
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
369+
setVars(vars, &args)
370+
var connectConfig sqlcmd.ConnectSettings
371+
setConnect(&connectConfig, &args, vars)
372+
connectConfig.AuthenticationMethod = testcase.authenticationMethod
373+
connectConfig.Password = testcase.pwd
374+
assert.Equal(t, testcase.expectedResult, isConsoleInitializationRequired(&connectConfig, &args), "Unexpected test result encountered for console initialization")
375+
assert.Equal(t, testcase.expectedResult, connectConfig.RequiresPassword() && connectConfig.Password == "", "Unexpected test result encountered for password prompt conditions")
376+
}
377+
}
378+
330379
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
331380
func canTestAzureAuth() bool {
332381
server := os.Getenv(sqlcmd.SQLCMDSERVER)

pkg/sqlcmd/commands.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
353353
return InvalidCommandError("CONNECT", line)
354354
}
355355

356-
connect := s.Connect
356+
connect := *s.Connect
357357
connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false)
358358
connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false)
359359
connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false)

pkg/sqlcmd/commands_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func TestConnectCommand(t *testing.T) {
171171
err := connectCommand(s, []string{"someserver -U someuser"}, 1)
172172
assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure")
173173
assert.True(t, prompted, "connectCommand with user name and no password should prompt for password")
174-
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs")
174+
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On connection failure, sqlCmd.Connect does not copy inputs")
175175

176176
err = connectCommand(s, []string{}, 2)
177177
assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error")

pkg/sqlcmd/connect.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (connect ConnectSettings) sqlAuthentication() bool {
6464
(!connect.UseTrustedConnection && connect.authenticationMethod() == NotSpecified && connect.UserName != "")
6565
}
6666

67-
func (connect ConnectSettings) requiresPassword() bool {
67+
func (connect ConnectSettings) RequiresPassword() bool {
6868
requiresPassword := connect.sqlAuthentication()
6969
if !requiresPassword {
7070
switch connect.authenticationMethod() {

pkg/sqlcmd/sqlcmd.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ type Sqlcmd struct {
6262
batch *Batch
6363
// Exitcode is returned to the operating system when the process exits
6464
Exitcode int
65-
Connect ConnectSettings
65+
Connect *ConnectSettings
6666
vars *Variables
6767
Format Formatter
6868
Query string
@@ -79,6 +79,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd {
7979
workingDirectory: workingDirectory,
8080
vars: vars,
8181
Cmd: newCommands(),
82+
Connect: &ConnectSettings{},
8283
}
8384
s.batch = NewBatch(s.scanNext, s.Cmd)
8485
mssql.SetContextLogger(s)
@@ -213,12 +214,12 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) {
213214
func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
214215
newConnection := connect != nil
215216
if connect == nil {
216-
connect = &s.Connect
217+
connect = s.Connect
217218
}
218219

219220
var connector driver.Connector
220221
useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication()
221-
if connect.requiresPassword() && !nopw && connect.Password == "" {
222+
if connect.RequiresPassword() && !nopw && connect.Password == "" {
222223
var err error
223224
if connect.Password, err = s.promptPassword(); err != nil {
224225
return err
@@ -259,7 +260,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
259260
s.vars.Set(SQLCMDUSER, u.Username)
260261
}
261262
if newConnection {
262-
s.Connect = *connect
263+
s.Connect = connect
263264
}
264265
if s.batch != nil {
265266
s.batch.batchline = 1

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,10 @@ func TestPromptForPasswordPositive(t *testing.T) {
367367
v := InitializeVariables(true)
368368
s := New(console, "", v)
369369
// attempt without password prompt
370-
err := s.ConnectDb(&c, true)
370+
err := s.ConnectDb(c, true)
371371
assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password")
372372
assert.Error(t, err, "ConnectDb with nopw==true and no password provided")
373-
err = s.ConnectDb(&c, false)
373+
err = s.ConnectDb(c, false)
374374
assert.True(t, prompted, "ConnectDb with !nopw should prompt for password")
375375
assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt")
376376
if s.Connect.Password != password {
@@ -506,7 +506,7 @@ func canTestAzureAuth() bool {
506506
return strings.Contains(server, ".database.windows.net") && userName == ""
507507
}
508508

509-
func newConnect(t testing.TB) ConnectSettings {
509+
func newConnect(t testing.TB) *ConnectSettings {
510510
t.Helper()
511511
connect := ConnectSettings{
512512
UserName: os.Getenv(SQLCMDUSER),
@@ -518,5 +518,5 @@ func newConnect(t testing.TB) ConnectSettings {
518518
t.Log("Using ActiveDirectoryDefault")
519519
connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
520520
}
521-
return connect
521+
return &connect
522522
}

0 commit comments

Comments
 (0)