@@ -138,8 +138,8 @@ func TestRunInputFiles(t *testing.T) {
138
138
vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
139
139
vars .Set (sqlcmd .SQLCMDMAXVARTYPEWIDTH , "0" )
140
140
setVars (vars , & args )
141
-
142
- exitCode , err := run (vars , & args )
141
+ sqlcmdConsole := getSqlcmdConsole ( args )
142
+ exitCode , err := run (vars , & args , sqlcmdConsole )
143
143
assert .NoError (t , err , "run" )
144
144
assert .Equal (t , 0 , exitCode , "exitCode" )
145
145
bytes , err := os .ReadFile (o .Name ())
@@ -162,8 +162,8 @@ func TestUnicodeOutput(t *testing.T) {
162
162
}
163
163
vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
164
164
setVars (vars , & args )
165
-
166
- exitCode , err := run (vars , & args )
165
+ sqlcmdConsole := getSqlcmdConsole ( args )
166
+ exitCode , err := run (vars , & args , sqlcmdConsole )
167
167
assert .NoError (t , err , "run" )
168
168
assert .Equal (t , 0 , exitCode , "exitCode" )
169
169
bytes , err := os .ReadFile (o .Name ())
@@ -214,7 +214,8 @@ func TestUnicodeInput(t *testing.T) {
214
214
}
215
215
vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
216
216
setVars (vars , & args )
217
- exitCode , err := run (vars , & args )
217
+ sqlcmdConsole := getSqlcmdConsole (args )
218
+ exitCode , err := run (vars , & args , sqlcmdConsole )
218
219
assert .NoError (t , err , "run" )
219
220
assert .Equal (t , 0 , exitCode , "exitCode" )
220
221
bytes , err := os .ReadFile (o .Name ())
@@ -244,8 +245,8 @@ func TestQueryAndExit(t *testing.T) {
244
245
vars .Set (sqlcmd .SQLCMDMAXVARTYPEWIDTH , "0" )
245
246
vars .Set ("VAR1" , "100" )
246
247
setVars (vars , & args )
247
-
248
- exitCode , err := run (vars , & args )
248
+ sqlcmdConsole := getSqlcmdConsole ( args )
249
+ exitCode , err := run (vars , & args , sqlcmdConsole )
249
250
assert .NoError (t , err , "run" )
250
251
assert .Equal (t , 0 , exitCode , "exitCode" )
251
252
bytes , err := os .ReadFile (o .Name ())
@@ -268,17 +269,16 @@ func TestExitOnError(t *testing.T) {
268
269
269
270
vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
270
271
setVars (vars , & args )
271
-
272
- exitCode , err := run (vars , & args )
272
+ sqlcmdConsole := getSqlcmdConsole ( args )
273
+ exitCode , err := run (vars , & args , sqlcmdConsole )
273
274
assert .NoError (t , err , "run" )
274
275
assert .Equal (t , 0 , exitCode , "exitCode" )
275
276
276
277
args .InputFile = []string {"testdata/bad.sql" }
277
278
278
279
vars = sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
279
280
setVars (vars , & args )
280
-
281
- exitCode , err = run (vars , & args )
281
+ exitCode , err = run (vars , & args , sqlcmdConsole )
282
282
assert .NoError (t , err , "run" )
283
283
assert .Equal (t , 1 , exitCode , "exitCode" )
284
284
@@ -301,7 +301,8 @@ func TestAzureAuth(t *testing.T) {
301
301
vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
302
302
vars .Set (sqlcmd .SQLCMDMAXVARTYPEWIDTH , "0" )
303
303
setVars (vars , & args )
304
- exitCode , err := run (vars , & args )
304
+ sqlcmdConsole := getSqlcmdConsole (args )
305
+ exitCode , err := run (vars , & args , sqlcmdConsole )
305
306
assert .NoError (t , err , "run" )
306
307
assert .Equal (t , 0 , exitCode , "exitCode" )
307
308
bytes , err := os .ReadFile (o .Name ())
@@ -320,13 +321,65 @@ func TestMissingInputFile(t *testing.T) {
320
321
321
322
vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
322
323
setVars (vars , & args )
323
-
324
- exitCode , err := run (vars , & args )
324
+ sqlcmdConsole := getSqlcmdConsole ( args )
325
+ exitCode , err := run (vars , & args , sqlcmdConsole )
325
326
assert .Error (t , err , "run" )
326
327
assert .Contains (t , err .Error (), "Error occurred while opening or operating on file" , "Unexpected error: " + err .Error ())
327
328
assert .Equal (t , 1 , exitCode , "exitCode" )
328
329
}
329
330
331
+ func TestPasswordPrompt (t * testing.T ) {
332
+ prompted := false
333
+ validationConsole := & testConsole {
334
+ OnPasswordPrompt : func (prompt string ) ([]byte , error ) {
335
+ assert .Equal (t , "Password:" , prompt , "Incorrect password prompt" )
336
+ prompted = true
337
+ return []byte {}, nil
338
+ },
339
+ OnReadLine : func () (string , error ) {
340
+ assert .Fail (t , "ReadLine should not be called" )
341
+ return "" , nil
342
+ },
343
+ }
344
+ args = newArguments ()
345
+ args .InputFile = []string {"testdata/select100.sql" }
346
+ vars := sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
347
+ setVars (vars , & args )
348
+ exitCode , err := run (vars , & args , validationConsole )
349
+ assert .False (t , prompted , "Password prompt shown when not expected" )
350
+ assert .NoError (t , err , "run" )
351
+ assert .Equal (t , 0 , exitCode , "exitCode" )
352
+
353
+ args .UserName = "someuser"
354
+ vars = sqlcmd .InitializeVariables (! args .DisableCmdAndWarn )
355
+ setVars (vars , & args )
356
+ exitCode , err = run (vars , & args , validationConsole )
357
+ assert .True (t , prompted , "Password prompt not displayed for -U" )
358
+ assert .Error (t , err , "run" )
359
+ assert .Equal (t , 1 , exitCode , "exitCode" )
360
+ }
361
+
362
+ type testConsole struct {
363
+ PromptText string
364
+ OnPasswordPrompt func (prompt string ) ([]byte , error )
365
+ OnReadLine func () (string , error )
366
+ }
367
+
368
+ func (tc * testConsole ) Readline () (string , error ) {
369
+ return tc .OnReadLine ()
370
+ }
371
+
372
+ func (tc * testConsole ) ReadPassword (prompt string ) ([]byte , error ) {
373
+ return tc .OnPasswordPrompt (prompt )
374
+ }
375
+
376
+ func (tc * testConsole ) SetPrompt (s string ) {
377
+ tc .PromptText = s
378
+ }
379
+
380
+ func (tc * testConsole ) Close () {
381
+ }
382
+
330
383
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
331
384
func canTestAzureAuth () bool {
332
385
server := os .Getenv (sqlcmd .SQLCMDSERVER )
0 commit comments