Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ unreleased

### Features

- Implement `require_auth` connection parameter ([#1310]).

### Fixes

- Add Redshift-specific OID mappings ([#1291], [#1317]).

- Use correct environment variable name for `PGSSLMINPROTOCOLVERSION` and
`PGSSLMAXPROTOCOLVERSION` ([#1310]).

[#1291]: https://github.com/lib/pq/pull/1291
[#1310]: https://github.com/lib/pq/pull/1310
[#1317]: https://github.com/lib/pq/pull/1317


Expand Down
24 changes: 21 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net"
"os"
"reflect"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -1255,6 +1256,7 @@ func (cn *conn) startup(cfg Config) error {
return err
}

var didauth bool
for {
t, r, err := cn.recv()
if err != nil {
Expand All @@ -1271,7 +1273,11 @@ func (cn *conn) startup(cfg Config) error {
case proto.ParameterStatus:
cn.processParameterStatus(r)
case proto.AuthenticationRequest:
err := cn.auth(r, cfg)
code := proto.AuthCode(r.int32())
if code != proto.AuthReqOk {
didauth = true
}
err := cn.auth(code, r, cfg)
if err != nil {
return err
}
Expand All @@ -1282,6 +1288,9 @@ func (cn *conn) startup(cfg Config) error {
return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor)
}
case proto.ReadyForQuery:
if len(cn.cfg.RequireAuth) > 0 && !didauth && !slices.Contains(cn.cfg.RequireAuth, RequireAuthNone) {
return fmt.Errorf("pq: authentication method requirement %q failed: server did not perform any authentication", cn.cfg.RequireAuth)
}
cn.processReadyForQuery(r)
return nil
default:
Expand All @@ -1290,8 +1299,8 @@ func (cn *conn) startup(cfg Config) error {
}
}

func (cn *conn) auth(r *readBuf, cfg Config) error {
switch code := proto.AuthCode(r.int32()); code {
func (cn *conn) auth(code proto.AuthCode, r *readBuf, cfg Config) error {
switch code {
default:
return fmt.Errorf("pq: unknown authentication response: %s", code)
case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI:
Expand All @@ -1300,13 +1309,19 @@ func (cn *conn) auth(r *readBuf, cfg Config) error {
return nil

case proto.AuthReqPassword:
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthPassword) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthPassword)
}
w := cn.writeBuf(proto.PasswordMessage)
w.string(cfg.Password)
// Don't need to check AuthOk response here; auth() is called in a loop,
// which catches the errors and AuthReqOk responses.
return cn.send(w)

case proto.AuthReqMD5:
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthMD5) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthMD5)
}
s := string(r.next(4))
w := cn.writeBuf(proto.PasswordMessage)
w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s))
Expand Down Expand Up @@ -1369,6 +1384,9 @@ func (cn *conn) auth(r *readBuf, cfg Config) error {
return nil

case proto.AuthReqSASL:
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthScramSHA256) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthScramSHA256)
}
sc := scram.NewClient(sha256.New, cfg.User, cfg.Password)
sc.Step(nil)
if sc.Err() != nil {
Expand Down
32 changes: 28 additions & 4 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1607,18 +1607,18 @@ func TestCommitInFailedTransactionWithCancelContext(t *testing.T) {

func TestAuth(t *testing.T) {
tests := []struct {
buf readBuf
code proto.AuthCode
wantErr string
}{
{readBuf{0, 0, 0, 9}, `pq: unsupported authentication method: SSPI (9)`},
{readBuf{0, 0, 0, 99}, `unknown authentication response: <unknown> (99)`},
{proto.AuthCode(9), `pq: unsupported authentication method: SSPI (9)`},
{proto.AuthCode(99), `unknown authentication response: <unknown> (99)`},
}

t.Parallel()
for _, tt := range tests {
t.Run("", func(t *testing.T) {
t.Run("unsupported auth", func(t *testing.T) {
err := (&conn{}).auth(&tt.buf, Config{})
err := (&conn{}).auth(tt.code, &readBuf{}, Config{})
if !pqtest.ErrorContains(err, tt.wantErr) {
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
}
Expand Down Expand Up @@ -1646,10 +1646,34 @@ func TestAuth(t *testing.T) {
{"user=pqgoscram password=wordpass", ``},

{"user=pqgounknown password=wordpass", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`},
{"user=pqgounknown password=wordpass require_auth=md5", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`},

// require_auth
{"user=pqgomd5 password=wordpass require_auth=md5,password", ``},
{"user=pqgopassword password=wordpass require_auth=md5,password", ``},
{"user=pqgoscram password=wordpass require_auth=md5,password,scram-sha-256", ``},
{"user=pqgomd5 password=wordpass require_auth=!none", ``},
{"user=pqgopassword password=wordpass require_auth=!none", ``},
{"user=pqgoscram password=wordpass require_auth=!none", ``},

{"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`},
{"user=pqgopassword password=wordpass require_auth=md5", `"md5" failed: server requested "password"`},
{"user=pqgoscram password=wordpass require_auth=md5,password", `authentication method requirement "md5,password" failed: server requested "scram-sha-256"`},
{"user=pqgomd5 password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "md5"`},
{"user=pqgopassword password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "password"`},
{"user=pqgoscram password=wordpass require_auth=!md5,!password,!scram-sha-256", `"!md5,!password,!scram-sha-256" failed: server requested "scram-sha-256"`},
{"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`},

{"user=pqgo password=unused require_auth=none", ``},
{"user=pqgo password=unused require_auth=!none", `"!none" failed: server did not perform any authentication`},
{"user=pqgo password=unused require_auth=md5,password,scram-sha-256", `"md5,password,scram-sha-256" failed: server did not perform any authentication`},
}

for _, tt := range tests {
t.Run(tt.conn, func(t *testing.T) {
if strings.Contains(tt.conn, "md5") {
pqtest.SkipCockroach(t) // md5 not supported
}
_, err := pqtest.DB(t, tt.conn)
if !pqtest.ErrorContains(err, tt.wantErr) {
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
Expand Down
101 changes: 94 additions & 7 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ type (
// SSLProtocolVersion is a ssl_min_protocol_version or
// ssl_max_protocol_version setting.
SSLProtocolVersion string

// RequireAuth is a require_auth setting.
RequireAuth string

// RequireAuths is a require_auth setting.
RequireAuths []RequireAuth
)

// Values for [SSLMode] that pq supports.
Expand Down Expand Up @@ -179,6 +185,41 @@ func (s SSLProtocolVersion) tlsconf() uint16 {
}
}

// Values for [RequireAuth] that pq supports.
const (
RequireAuthNone = RequireAuth("none")
RequireAuthPassword = RequireAuth("password")
RequireAuthMD5 = RequireAuth("md5")
RequireAuthGSS = RequireAuth("gss")
RequireAuthScramSHA256 = RequireAuth("scram-sha-256")
RequireAuthAny = RequireAuth("!none")
RequireAuthNotPassword = RequireAuth("!password")
RequireAuthNotMD5 = RequireAuth("!md5")
RequireAuthNotGSS = RequireAuth("!gss")
RequireAuthNotScramSHA256 = RequireAuth("!scram-sha-256")

// Not (yet) supported by pq
// RequireAuthSSPI = "sspi"
// RequireAuthOAuth = "oauth"
// RequireAuthNotSSPI = "!sspi"
// RequireAuthNotOAuth = "!oauth"
)

var requireAuths = []RequireAuth{RequireAuthNone, RequireAuthPassword, RequireAuthMD5,
RequireAuthGSS, RequireAuthScramSHA256, RequireAuthAny, RequireAuthNotPassword,
RequireAuthNotMD5, RequireAuthNotGSS, RequireAuthNotScramSHA256}

func (r RequireAuths) String() string {
var b strings.Builder
for i, rr := range r {
if i > 0 {
b.WriteString(",")
}
b.WriteString(string(rr))
}
return b.String()
}

// Connector represents a fixed configuration for the pq driver with a given
// dsn. Connector satisfies the [database/sql/driver.Connector] interface and
// can be used to create any number of DB Conn's via [sql.OpenDB].
Expand Down Expand Up @@ -341,14 +382,14 @@ type Config struct {
//
// The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at
// the time of writing.
SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION"`
SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"PGSSLMINPROTOCOLVERSION"`

// Maximum SSL/TLS protocol version to allow for the connection. If not set,
// this parameter is ignored and the connection will use the maximum bound
// defined by the backend, if set. Setting the maximum protocol version is
// mainly useful for testing or if some component has issues working with a
// newer protocol.
SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION"`
SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"PGSSLMAXPROTOCOLVERSION"`

// Interpert sslcert and sslkey as PEM encoded data, rather than a path to a
// PEM file. This is a pq extension, not supported in libpq.
Expand Down Expand Up @@ -431,6 +472,25 @@ type Config struct {
// Path to connection service file. Defaults to ~/.pg_service.conf.
ServiceFile string `postgres:"-" env:"PGSERVICEFILE"`

// Require an authentication method from the server and refuse to connect if
// the server does not use the requested method.
//
// This accepts a comma-separated list.
//
// Methods may be negated with a ! prefix, in which case the server must
// *not* attempt the listed method, and the server is free not to
// authenticate the client at all. Negated and non-negated forms may not be
// combined in the same setting with a comma-separated list.
//
// As a special case the "none" method requires the server not to use an
// authentication challenge. This does not prohibit client certificate
// authentication via TLS or GSS authentication via its encrypted transport.
// This can be negated to require some form of authentication.
//
// By default any authentication method is accepted and the server is free
// to skip authentication altogether.
RequireAuth RequireAuths `postgres:"require_auth" env:"PGREQUIREAUTH"`

// Runtime parameters: any unrecognized parameter in the DSN will be added
// to this and sent to PostgreSQL during startup.
Runtime map[string]string `postgres:"-" env:"-"`
Expand Down Expand Up @@ -517,7 +577,8 @@ func NewConfig(dsn string) (Config, error) {
// Clone returns a copy of the [Config].
func (cfg Config) Clone() Config {
c := cfg
c.Runtime, c.Multi, c.set = maps.Clone(cfg.Runtime), slices.Clone(cfg.Multi), slices.Clone(cfg.set)
c.Runtime, c.Multi, c.RequireAuth, c.set = maps.Clone(cfg.Runtime), slices.Clone(cfg.Multi),
slices.Clone(cfg.RequireAuth), slices.Clone(cfg.set)
return c
}

Expand Down Expand Up @@ -672,8 +733,8 @@ func (cfg *Config) fromEnv(env []string) error {
switch k {
case "PGREQUIRESSL", "PGSSLCOMPRESSION", // Deprecated.
"PGREALM", "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB", // krb stuff
"PGREQUIREAUTH", "PGCHANNELBINDING",
"PGSSLCERTMODE", "PGSSLCRL", "PGSSLCRLDIR", "PGREQUIREPEER":
"PGCHANNELBINDING", "PGSSLCRL", "PGSSLCRLDIR",
"PGSSLCERTMODE", "PGREQUIREPEER":
return fmt.Errorf("pq: environment variable $%s is not supported", k)
case "PGKRBSRVNAME":
if newGss == nil {
Expand Down Expand Up @@ -833,8 +894,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS")
minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION")
maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION")
sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION")
sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION")
sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "PGSSLMINPROTOCOLVERSION")
sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "PGSSLMAXPROTOCOLVERSION")
requireauth = (tag == "postgres" && k == "require_auth") || (tag == "env" && k == "PGREQUIREAUTH")
)
if k == "" || k == "-" {
continue
Expand Down Expand Up @@ -908,6 +970,31 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
cfg.multiHost = append(cfg.multiHost, vv[1:]...)
}
rv.SetString(v)
case reflect.Slice:
if requireauth {
if v == "" {
rv.Set(reflect.ValueOf((RequireAuths)(nil)))
continue
}
var (
vv = strings.Split(v, ",")
s = make(RequireAuths, len(vv))
neg = len(vv) > 0 && strings.HasPrefix(vv[0], "!")
)
for i := range vv {
if !slices.Contains(requireAuths, RequireAuth(vv[i])) {
return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, vv[i], pqutil.Join(requireAuths))
}
if neg && !strings.HasPrefix(vv[i], "!") {
return fmt.Errorf(f+`require_auth method %q cannot be mixed with negative methods`, k, vv[i])
}
if !neg && strings.HasPrefix(vv[i], "!") {
return fmt.Errorf(f+`negative require_auth method %q cannot be mixed with non-negative methods`, k, vv[i])
}
s[i] = RequireAuth(vv[i])
}
rv.Set(reflect.ValueOf(s))
}
case reflect.Int64:
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,15 @@ func TestNewConfig(t *testing.T) {
{"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`},
{"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`},
{"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`},

// requireauth
{"require_auth=", nil, "require_auth=''", ``},
{"require_auth=none", nil, "require_auth=none", ""},
{"require_auth=md5,scram-sha-256", nil, "require_auth=md5,scram-sha-256", ""},
{"require_auth=md5,scram-sha256", nil, "", `wrong value for "require_auth": "scram-sha256" is not supported`},
{"require_auth=!md5,!scram-sha-256", nil, "require_auth=!md5,!scram-sha-256", ""},
{"require_auth=md5,!password", nil, "", `negative require_auth method "!password" cannot be mixed with non-negative methods`},
{"require_auth=!md5,password", nil, "", `require_auth method "password" cannot be mixed with negative methods`},
}

t.Parallel()
Expand Down