Skip to content

Commit 12e464c

Browse files
committed
Allow multiple matches and regexps in pqtest.ErrorContains()
This makes it a bit easier to test against different servers that return different messages.
1 parent 6d77ced commit 12e464c

4 files changed

Lines changed: 38 additions & 44 deletions

File tree

connector_test.go

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,17 @@ func TestNewConnector(t *testing.T) {
115115
})
116116

117117
t.Run("database=", func(t *testing.T) {
118-
want1, want2 := `pq: database "err" does not exist (3D000)`,
119-
`pq: database "two" does not exist (3D000)`
120-
if pqtest.Pgbouncer() {
121-
want1, want2 = `pq: no such database: err (08P01)`, `pq: no such database: two (08P01)`
122-
}
118+
want1, want2 := `or:pq: database "err" does not exist (3D000)|pq: no such database: err (08P01)`,
119+
`or:pq: database "two" does not exist (3D000)|pq: no such database: two (08P01)`
123120

124121
// Make sure database= consistently take precedence over dbname=
125122
for i := 0; i < 10; i++ {
126123
_, err := pqtest.DB(t, "database=err")
127-
if err == nil || err.Error() != want1 {
124+
if !pqtest.ErrorContains(err, want1) {
128125
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, want1)
129126
}
130127
_, err = pqtest.DB(t, "dbname=one database=two")
131-
if err == nil || err.Error() != want2 {
128+
if !pqtest.ErrorContains(err, want2) {
132129
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, want2)
133130
}
134131
}
@@ -198,7 +195,7 @@ func TestRuntimeParameters(t *testing.T) {
198195
wantErr string
199196
skipPgbouncer bool
200197
}{
201-
{"DOESNOTEXIST=foo", "", "", "unrecognized configuration parameter", false},
198+
{"DOESNOTEXIST=foo", "", "", `or:unrecognized configuration parameter|unsupported startup parameter`, false},
202199

203200
// we can only work with a specific value for these two
204201
{"client_encoding=SQL_ASCII", "", "", `unsupported client_encoding "SQL_ASCII": must be absent or "UTF8"`, false},
@@ -233,9 +230,6 @@ func TestRuntimeParameters(t *testing.T) {
233230
if tt.skipPgbouncer {
234231
pqtest.SkipPgbouncer(t)
235232
}
236-
if pqtest.Pgbouncer() && tt.wantErr == "unrecognized configuration parameter" {
237-
tt.wantErr = `unsupported startup parameter`
238-
}
239233

240234
db, err := pqtest.DB(t, tt.conninfo)
241235
if !pqtest.ErrorContains(err, tt.wantErr) {

internal/pqtest/ztest.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
package pqtest
44

55
import (
6+
"fmt"
67
"io/fs"
78
"os"
89
"path/filepath"
10+
"regexp"
911
"strings"
1012
"testing"
1113
)
@@ -15,13 +17,35 @@ import (
1517
//
1618
// This is safe when have is nil. Use an empty string for want if you want to
1719
// test that err is nil.
20+
//
21+
// Uses a regexp match if want starts with "re:".
22+
//
23+
// Matches several strings if wants starts with "or:". For example
24+
// "or:one|two|three" matches any error which contains "one", "two", or "three".
25+
// This is similar to "re:(one|two|three)" except that it's not required to
26+
// escape regexp meta characters. It's not possible to escape | with or:
1827
func ErrorContains(have error, want string) bool {
1928
if have == nil {
2029
return want == ""
2130
}
2231
if want == "" {
2332
return false
2433
}
34+
if strings.HasPrefix(want, "re:") {
35+
m, err := regexp.MatchString(want[3:], have.Error())
36+
if err != nil {
37+
panic(fmt.Errorf("pqtest.ErrorContains: %w", err))
38+
}
39+
return m
40+
}
41+
if strings.HasPrefix(want, "or:") {
42+
for _, w := range strings.Split(want[3:], "|") {
43+
if strings.Contains(have.Error(), w) {
44+
return true
45+
}
46+
}
47+
return false
48+
}
2549
return strings.Contains(have.Error(), want)
2650
}
2751

notify_test.go

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"database/sql"
66
"database/sql/driver"
77
"fmt"
8-
"io"
98
"math/big"
109
"net"
1110
"runtime"
@@ -360,23 +359,15 @@ func TestListenerReconnect(t *testing.T) {
360359
wantNotification(t, l.Notify, n, "")
361360

362361
// Kill the connection and make sure it comes back up.
363-
ok, err := l.cn.ExecSimpleQuery("SELECT pg_terminate_backend(pg_backend_pid())")
362+
ok, err := l.cn.ExecSimpleQuery("select pg_terminate_backend(pg_backend_pid())")
364363
if ok {
365364
t.Fatalf("could not kill the connection: %v", err)
366365
}
367-
if pqtest.Pgbouncer() {
368-
if !pqtest.ErrorContains(err, "server conn crashed") {
369-
t.Fatalf("unexpected error %T: %[1]s", err)
370-
}
371-
} else if pqtest.Pgpool() {
372-
if !pqtest.ErrorContains(err, "unable to forward message to frontend") {
373-
t.Fatalf("unexpected error %T: %[1]s", err)
374-
}
375-
} else {
376-
if err != io.EOF {
377-
t.Fatalf("unexpected error %T: %[1]s", err)
378-
}
366+
// PostgreSQL, pgbouncer, and pgpool all use different errors.
367+
if !pqtest.ErrorContains(err, `or:EOF|pq: server conn crashed? (08P01)|pq: unable to forward message to frontend (XX000)`) {
368+
t.Fatalf("unexpected error %T: %[1]s", err)
379369
}
370+
380371
wantEvent(t, ch, ListenerEventDisconnected)
381372
wantEvent(t, ch, ListenerEventReconnected)
382373

ssl_test.go

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,9 @@ import (
2020

2121
// Environment sanity check: should fail without SSL.
2222
func startSSLTest(t *testing.T, user string) {
23-
wantErr := `invalid_authorization_specification`
24-
if pqtest.Pgbouncer() {
25-
wantErr = "protocol_violation"
26-
} else if pqtest.Pgpool() {
27-
wantErr = "internal_error"
28-
}
2923
_, err := pqtest.DB(t, "sslmode=disable user="+user)
30-
pqErr := mustAs(t, err)
31-
if pqErr.Code.Name() != wantErr {
32-
t.Fatalf("wrong error code %q", pqErr.Code.Name())
24+
if !pqtest.ErrorContains(err, `or:no encryption (28000)|login rejected (08P01)`) {
25+
t.Fatalf("wrong error: %q", err)
3326
}
3427
}
3528

@@ -94,7 +87,8 @@ func TestSSLMode(t *testing.T) {
9487
{"sslmode=allow " + f.DSN(), ""}, // Idem
9588

9689
// sslmode=disable
97-
{"sslmode=disable user=pqgossl", "no encryption"},
90+
// pgbouncer uses "login rejected (08P01)", so allow that too.
91+
{"sslmode=disable user=pqgossl", "or:no encryption|login rejected (08P01)"},
9892

9993
// sslnegotiation=direct should fail if ssl isn't required, like libpq:
10094
// psql: error: weak sslmode "allow" may not be used with sslnegotiation=direct (use "require", "verify-ca", or "verify-full")
@@ -108,15 +102,6 @@ func TestSSLMode(t *testing.T) {
108102
t.Run("", func(t *testing.T) {
109103
t.Parallel()
110104

111-
if tt.wantErr == "no encryption" && pqtest.Pgbouncer() {
112-
// PostgreSQL repsonds with:
113-
// pq: pg_hba.conf rejects connection for host "172.18.0.1", user "pqgossl", database "pqgo", no encryption (28000)
114-
//
115-
// But pgbouncer has a different message and code:
116-
// pq: login rejected (08P01)
117-
tt.wantErr = "login rejected"
118-
}
119-
120105
_, err := pqtest.DB(t, tt.connect)
121106
switch {
122107
case tt.wantErr == "" && err != nil:

0 commit comments

Comments
 (0)