Skip to content

Commit b694b18

Browse files
Add testcase to validate the password prompt
1 parent a2b500c commit b694b18

File tree

2 files changed

+81
-23
lines changed

2 files changed

+81
-23
lines changed

cmd/sqlcmd/main.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@ func (a SQLCmdArguments) authenticationMethod(hasPassword bool) string {
102102
return a.AuthenticationMethod
103103
}
104104

105+
func getSqlcmdConsole(args SQLCmdArguments) sqlcmd.Console {
106+
iactive := args.InputFile == nil && args.Query == ""
107+
uactive := args.UserName != ""
108+
var sqlcmdConsole sqlcmd.Console = nil
109+
if iactive || uactive {
110+
sqlcmdConsole = console.NewConsole("")
111+
defer sqlcmdConsole.Close()
112+
}
113+
return sqlcmdConsole
114+
}
115+
105116
func main() {
106117
ctx := kong.Parse(&args, kong.NoDefaultHelp())
107118
if args.Help {
@@ -110,9 +121,9 @@ func main() {
110121
}
111122
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
112123
setVars(vars, &args)
113-
124+
line := getSqlcmdConsole(args)
114125
// so far sqlcmd prints all the errors itself so ignore it
115-
exitCode, _ := run(vars, &args)
126+
exitCode, _ := run(vars, &args, line)
116127
os.Exit(exitCode)
117128
}
118129

@@ -201,19 +212,13 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
201212
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
202213
}
203214

204-
func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
215+
func run(vars *sqlcmd.Variables, args *SQLCmdArguments, line sqlcmd.Console) (int, error) {
205216
wd, err := os.Getwd()
206217
if err != nil {
207218
return 1, err
208219
}
209220

210221
iactive := args.InputFile == nil && args.Query == ""
211-
uactive := args.UserName != ""
212-
var line sqlcmd.Console = nil
213-
if iactive || uactive {
214-
line = console.NewConsole("")
215-
defer line.Close()
216-
}
217222

218223
s := sqlcmd.New(line, wd, vars)
219224
s.UnicodeOutputFile = args.UnicodeOutputFile

cmd/sqlcmd/main_test.go

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ func TestRunInputFiles(t *testing.T) {
138138
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
139139
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
140140
setVars(vars, &args)
141-
142-
exitCode, err := run(vars, &args)
141+
sqlcmdConsole := getSqlcmdConsole(args)
142+
exitCode, err := run(vars, &args, sqlcmdConsole)
143143
assert.NoError(t, err, "run")
144144
assert.Equal(t, 0, exitCode, "exitCode")
145145
bytes, err := os.ReadFile(o.Name())
@@ -162,8 +162,8 @@ func TestUnicodeOutput(t *testing.T) {
162162
}
163163
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
164164
setVars(vars, &args)
165-
166-
exitCode, err := run(vars, &args)
165+
sqlcmdConsole := getSqlcmdConsole(args)
166+
exitCode, err := run(vars, &args, sqlcmdConsole)
167167
assert.NoError(t, err, "run")
168168
assert.Equal(t, 0, exitCode, "exitCode")
169169
bytes, err := os.ReadFile(o.Name())
@@ -214,7 +214,8 @@ func TestUnicodeInput(t *testing.T) {
214214
}
215215
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
216216
setVars(vars, &args)
217-
exitCode, err := run(vars, &args)
217+
sqlcmdConsole := getSqlcmdConsole(args)
218+
exitCode, err := run(vars, &args, sqlcmdConsole)
218219
assert.NoError(t, err, "run")
219220
assert.Equal(t, 0, exitCode, "exitCode")
220221
bytes, err := os.ReadFile(o.Name())
@@ -244,8 +245,8 @@ func TestQueryAndExit(t *testing.T) {
244245
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
245246
vars.Set("VAR1", "100")
246247
setVars(vars, &args)
247-
248-
exitCode, err := run(vars, &args)
248+
sqlcmdConsole := getSqlcmdConsole(args)
249+
exitCode, err := run(vars, &args, sqlcmdConsole)
249250
assert.NoError(t, err, "run")
250251
assert.Equal(t, 0, exitCode, "exitCode")
251252
bytes, err := os.ReadFile(o.Name())
@@ -268,17 +269,16 @@ func TestExitOnError(t *testing.T) {
268269

269270
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
270271
setVars(vars, &args)
271-
272-
exitCode, err := run(vars, &args)
272+
sqlcmdConsole := getSqlcmdConsole(args)
273+
exitCode, err := run(vars, &args, sqlcmdConsole)
273274
assert.NoError(t, err, "run")
274275
assert.Equal(t, 0, exitCode, "exitCode")
275276

276277
args.InputFile = []string{"testdata/bad.sql"}
277278

278279
vars = sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
279280
setVars(vars, &args)
280-
281-
exitCode, err = run(vars, &args)
281+
exitCode, err = run(vars, &args, sqlcmdConsole)
282282
assert.NoError(t, err, "run")
283283
assert.Equal(t, 1, exitCode, "exitCode")
284284

@@ -301,7 +301,8 @@ func TestAzureAuth(t *testing.T) {
301301
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
302302
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
303303
setVars(vars, &args)
304-
exitCode, err := run(vars, &args)
304+
sqlcmdConsole := getSqlcmdConsole(args)
305+
exitCode, err := run(vars, &args, sqlcmdConsole)
305306
assert.NoError(t, err, "run")
306307
assert.Equal(t, 0, exitCode, "exitCode")
307308
bytes, err := os.ReadFile(o.Name())
@@ -320,13 +321,65 @@ func TestMissingInputFile(t *testing.T) {
320321

321322
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
322323
setVars(vars, &args)
323-
324-
exitCode, err := run(vars, &args)
324+
sqlcmdConsole := getSqlcmdConsole(args)
325+
exitCode, err := run(vars, &args, sqlcmdConsole)
325326
assert.Error(t, err, "run")
326327
assert.Contains(t, err.Error(), "Error occurred while opening or operating on file", "Unexpected error: "+err.Error())
327328
assert.Equal(t, 1, exitCode, "exitCode")
328329
}
329330

331+
func TestPasswordPrompt(t *testing.T) {
332+
prompted := false
333+
validationConsole := &testConsole{
334+
OnPasswordPrompt: func(prompt string) ([]byte, error) {
335+
assert.Equal(t, "Password:", prompt, "Incorrect password prompt")
336+
prompted = true
337+
return []byte{}, nil
338+
},
339+
OnReadLine: func() (string, error) {
340+
assert.Fail(t, "ReadLine should not be called")
341+
return "", nil
342+
},
343+
}
344+
args = newArguments()
345+
args.InputFile = []string{"testdata/select100.sql"}
346+
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
347+
setVars(vars, &args)
348+
exitCode, err := run(vars, &args, validationConsole)
349+
assert.False(t, prompted, "Password prompt shown when not expected")
350+
assert.NoError(t, err, "run")
351+
assert.Equal(t, 0, exitCode, "exitCode")
352+
353+
args.UserName = "someuser"
354+
vars = sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
355+
setVars(vars, &args)
356+
exitCode, err = run(vars, &args, validationConsole)
357+
assert.True(t, prompted, "Password prompt not displayed for -U")
358+
assert.Error(t, err, "run")
359+
assert.Equal(t, 1, exitCode, "exitCode")
360+
}
361+
362+
type testConsole struct {
363+
PromptText string
364+
OnPasswordPrompt func(prompt string) ([]byte, error)
365+
OnReadLine func() (string, error)
366+
}
367+
368+
func (tc *testConsole) Readline() (string, error) {
369+
return tc.OnReadLine()
370+
}
371+
372+
func (tc *testConsole) ReadPassword(prompt string) ([]byte, error) {
373+
return tc.OnPasswordPrompt(prompt)
374+
}
375+
376+
func (tc *testConsole) SetPrompt(s string) {
377+
tc.PromptText = s
378+
}
379+
380+
func (tc *testConsole) Close() {
381+
}
382+
330383
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
331384
func canTestAzureAuth() bool {
332385
server := os.Getenv(sqlcmd.SQLCMDSERVER)

0 commit comments

Comments
 (0)