Skip to content

Commit cad1251

Browse files
Adapt testcase to specific changes
1 parent 99140d4 commit cad1251

File tree

6 files changed

+31
-67
lines changed

6 files changed

+31
-67
lines changed

cmd/sqlcmd/main.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
201201
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
202202
}
203203

204-
type allocConsole func(historyFile string) sqlcmd.Console
205-
206-
var consoleAllocator allocConsole = console.NewConsole
204+
func IsConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool {
205+
iactive := args.InputFile == nil && args.Query == ""
206+
return iactive || connect.RequiresPassword()
207+
}
207208

208209
func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
209210
wd, err := os.Getwd()
@@ -214,8 +215,8 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
214215
var connectConfig sqlcmd.ConnectSettings
215216
setConnect(&connectConfig, args, vars)
216217
var line sqlcmd.Console = nil
217-
if connectConfig.RequiresPassword() {
218-
line = consoleAllocator("")
218+
if IsConsoleInitializationRequired(&connectConfig, args) {
219+
line = console.NewConsole("")
219220
defer line.Close()
220221
}
221222

@@ -232,7 +233,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
232233
return 1, err
233234
}
234235

235-
setConnect(&s.Connect, args, vars)
236+
s.Connect = &connectConfig
236237
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
237238
if args.OutputFile != "" {
238239
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
@@ -268,7 +269,6 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
268269
}
269270

270271
iactive := args.InputFile == nil && args.Query == ""
271-
272272
if iactive || s.Query != "" {
273273
err = s.Run(once, false)
274274
} else {

cmd/sqlcmd/main_test.go

Lines changed: 14 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"testing"
1010

1111
"github.com/alecthomas/kong"
12-
"github.com/microsoft/go-sqlcmd/pkg/console"
1312
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
1413
"github.com/stretchr/testify/assert"
1514
"github.com/stretchr/testify/require"
@@ -328,67 +327,31 @@ func TestMissingInputFile(t *testing.T) {
328327
assert.Equal(t, 1, exitCode, "exitCode")
329328
}
330329

331-
func TestPasswordPrompt(t *testing.T) {
332-
333-
prompted := false
334-
consoleAllocator = func(historyFile string) sqlcmd.Console {
335-
console := &testConsole{
336-
OnPasswordPrompt: func(prompt string) ([]byte, error) {
337-
assert.Equal(t, "Password:", prompt, "Incorrect password prompt")
338-
prompted = true
339-
return []byte{}, nil
340-
},
341-
OnReadLine: func() (string, error) {
342-
assert.Fail(t, "ReadLine should not be called")
343-
return "", nil
344-
},
345-
}
346-
return console
347-
}
348-
349-
args = newArguments()
330+
func TestConditionsForPasswordPrompt(t *testing.T) {
331+
args := newArguments()
350332
if canTestAzureAuth() {
351333
args.UseAad = true
352334
}
353-
args.InputFile = []string{"testdata/select100.sql"}
335+
args.InputFile = []string{"testdata/missingFile.sql"}
354336
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
355337
setVars(vars, &args)
338+
var connectConfig sqlcmd.ConnectSettings
339+
setConnect(&connectConfig, &args, vars)
340+
validateExpectedConditionsforPwdPrompt(t, &connectConfig, &args, false)
356341

357-
exitCode, err := run(vars, &args)
358-
assert.False(t, prompted, "Password prompt was not expected")
359-
assert.NoError(t, err, "run")
360-
assert.Equal(t, 0, exitCode, "exitCode")
361-
342+
args.InputFile = []string{"testdata/missingFile.sql"}
362343
args.UserName = "someuser"
363-
os.Setenv("SQLCMDPASSWORD", "")
364344
vars = sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
365345
setVars(vars, &args)
366-
exitCode, err = run(vars, &args)
367-
assert.True(t, prompted, "Password prompt not displayed for -U")
368-
assert.Error(t, err, "run")
369-
assert.Equal(t, 1, exitCode, "exitCode")
370-
consoleAllocator = console.NewConsole
371-
}
372-
373-
type testConsole struct {
374-
PromptText string
375-
OnPasswordPrompt func(prompt string) ([]byte, error)
376-
OnReadLine func() (string, error)
377-
}
378-
379-
func (tc *testConsole) Readline() (string, error) {
380-
return tc.OnReadLine()
381-
}
382-
383-
func (tc *testConsole) ReadPassword(prompt string) ([]byte, error) {
384-
return tc.OnPasswordPrompt(prompt)
385-
}
386-
387-
func (tc *testConsole) SetPrompt(s string) {
388-
tc.PromptText = s
346+
setConnect(&connectConfig, &args, vars)
347+
connectConfig.Password = ""
348+
validateExpectedConditionsforPwdPrompt(t, &connectConfig, &args, true)
389349
}
390350

391-
func (tc *testConsole) Close() {
351+
func validateExpectedConditionsforPwdPrompt(t *testing.T, connectConfig *sqlcmd.ConnectSettings, args *SQLCmdArguments, expectedValue bool) {
352+
consoleRequired := IsConsoleInitializationRequired(connectConfig, args)
353+
pwdPromptExpected := connectConfig.RequiresPassword() && connectConfig.Password == ""
354+
assert.Equal(t, expectedValue, consoleRequired && pwdPromptExpected, "Expected condition for pwd prompt did not match")
392355
}
393356

394357
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set

pkg/sqlcmd/commands.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
339339
}
340340
connect.AuthenticationMethod = arguments.AuthenticationMethod
341341
// If no user name is provided we switch to integrated auth
342-
_ = s.ConnectDb(&connect, s.lineIo == nil)
342+
_ = s.ConnectDb(connect, s.lineIo == nil)
343343
// ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option
344344
return nil
345345
}

pkg/sqlcmd/commands_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func TestConnectCommand(t *testing.T) {
167167
err := connectCommand(s, []string{"someserver -U someuser"}, 1)
168168
assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure")
169169
assert.True(t, prompted, "connectCommand with user name and no password should prompt for password")
170-
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs")
170+
assert.Equal(t, "someserver", s.Connect.ServerName, "servername should match with the input parameter")
171171

172172
err = connectCommand(s, []string{}, 2)
173173
assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error")

pkg/sqlcmd/sqlcmd.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ type Sqlcmd struct {
6262
batch *Batch
6363
// Exitcode is returned to the operating system when the process exits
6464
Exitcode int
65-
Connect ConnectSettings
65+
Connect *ConnectSettings
6666
vars *Variables
6767
Format Formatter
6868
Query string
@@ -79,6 +79,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd {
7979
workingDirectory: workingDirectory,
8080
vars: vars,
8181
Cmd: newCommands(),
82+
Connect: &ConnectSettings{},
8283
}
8384
s.batch = NewBatch(s.scanNext, s.Cmd)
8485
mssql.SetContextLogger(s)
@@ -213,7 +214,7 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) {
213214
func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
214215
newConnection := connect != nil
215216
if connect == nil {
216-
connect = &s.Connect
217+
connect = s.Connect
217218
}
218219

219220
var connector driver.Connector
@@ -259,7 +260,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
259260
s.vars.Set(SQLCMDUSER, u.Username)
260261
}
261262
if newConnection {
262-
s.Connect = *connect
263+
s.Connect = connect
263264
}
264265
if s.batch != nil {
265266
s.batch.batchline = 1

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,10 @@ func TestPromptForPasswordPositive(t *testing.T) {
366366
v := InitializeVariables(true)
367367
s := New(console, "", v)
368368
// attempt without password prompt
369-
err := s.ConnectDb(&c, true)
369+
err := s.ConnectDb(c, true)
370370
assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password")
371371
assert.Error(t, err, "ConnectDb with nopw==true and no password provided")
372-
err = s.ConnectDb(&c, false)
372+
err = s.ConnectDb(c, false)
373373
assert.True(t, prompted, "ConnectDb with !nopw should prompt for password")
374374
assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt")
375375
if s.Connect.Password != password {
@@ -505,7 +505,7 @@ func canTestAzureAuth() bool {
505505
return strings.Contains(server, ".database.windows.net") && userName == ""
506506
}
507507

508-
func newConnect(t testing.TB) ConnectSettings {
508+
func newConnect(t testing.TB) *ConnectSettings {
509509
t.Helper()
510510
connect := ConnectSettings{
511511
UserName: os.Getenv(SQLCMDUSER),
@@ -517,5 +517,5 @@ func newConnect(t testing.TB) ConnectSettings {
517517
t.Log("Using ActiveDirectoryDefault")
518518
connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
519519
}
520-
return connect
520+
return &connect
521521
}

0 commit comments

Comments
 (0)