diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e13bed0..6d112bf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: run-tests: strategy: matrix: - go: ['1.20'] + go: ['1.23'] platform: [ubuntu-latest] runs-on: ubuntu-latest diff --git a/go.mod b/go.mod index 430e4e2..722557b 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,10 @@ module gorm.io/driver/postgres -go 1.20 +go 1.23.0 require ( - github.com/jackc/pgx/v5 v5.6.0 - gorm.io/gorm v1.25.10 + github.com/jackc/pgx/v5 v5.7.5 + gorm.io/gorm v1.30.0 ) require ( @@ -13,9 +13,9 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - golang.org/x/crypto v0.31.0 // indirect - golang.org/x/sync v0.10.0 // indirect - golang.org/x/text v0.21.0 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/text v0.27.0 // indirect ) retract v1.5.5 // Published accidentally. diff --git a/go.sum b/go.sum index 50dd830..45eada3 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,12 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= -github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= +github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -18,14 +19,16 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= -gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= +gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/in_optimiser.go b/in_optimiser.go new file mode 100644 index 0000000..ae8c3aa --- /dev/null +++ b/in_optimiser.go @@ -0,0 +1,185 @@ +package postgres + +import ( + "database/sql/driver" + "fmt" + "reflect" + "regexp" + + "gorm.io/gorm/clause" +) + +// Rewrites conditions of WHERE clauses to replace `col IN (?)` and `col NOT IN (?)` +// with `col = ANY(?)` and `col != ALL(?)`, respectively. The difference between +// the two forms is in their interplay with prepared statements: +// +// 1. A condition .Where("col IN (?)", values) expands to `col IN ($1,$2,...)` +// where the list has len(values) items. Every value[i] is sent to postgres +// as a separate query argument. +// 2. A condition .Where("col = ANY(?)", values) always expands to `col = ANY($1)`, +// and values are sent to postgres as exactly one query argument (of array type). +// +// The option 1 does not iteract well with prepared statements. It produces +// a different query for different len(values). Option 2, on the other hand, +// needs only one prepared statement for any len(values). +func rewriteWhereClauses(e clause.Expression) clause.Expression { + var r inClausesRewriter + return r.rewriteExpression(e) +} + +type inClausesRewriter struct{} + +func (r inClausesRewriter) rewriteExpression(e clause.Expression) clause.Expression { + switch e := e.(type) { + case clause.Expr: + return r.rewriteExpr(e) + case clause.NamedExpr: + return r.rewriteNamedExpr(e) + case clause.IN: + return r.rewriteInExpr(e) + + case clause.Where: + return clause.Where{Exprs: r.rewriteArray(e.Exprs)} + case clause.OrConditions: + return clause.OrConditions{Exprs: r.rewriteArray(e.Exprs)} + case clause.AndConditions: + return clause.AndConditions{Exprs: r.rewriteArray(e.Exprs)} + case clause.NotConditions: + return clause.NotConditions{Exprs: r.rewriteArray(e.Exprs)} + + default: + return e + } +} + +func (r inClausesRewriter) rewriteArray(in []clause.Expression) (out []clause.Expression) { + out = make([]clause.Expression, len(in)) + for i := range in { + out[i] = r.rewriteExpression(in[i]) + } + return out +} + +var ( + // NOTE: this does not exactly follow the SQL syntax. Quoted identifier names + // may contain any non-NULL characters, but we only allow \w = [a-z0-9_]. + // Also, these regexps allow column names like "123abc". This does not matter + // because postgres will check the syntax, anyway. + columnInRe = regexp.MustCompile(`(?i)^\s*((\w+\.)|("\w+"\.))?((\w+)|("\w+"))\s+in\s*\(\?\)\s*$`) + columnNotInRe = regexp.MustCompile(`(?i)^\s*((\w+\.)|("\w+"\.))?((\w+)|("\w+"))\s+not\s+in\s*\(\?\)\s*$`) +) + +func (r inClausesRewriter) rewriteExpr(in clause.Expr) (out clause.Expr) { + mIn := columnInRe.FindStringSubmatch(in.SQL) + mNotIn := columnNotInRe.FindStringSubmatch(in.SQL) + if mIn == nil && mNotIn == nil { + return in + } + if len(in.Vars) != 1 { + return in + } + + vars := r.rewriteExprVar(in.Vars[0]) + if vars == nil { + return in + } + + if mIn != nil { + return clause.Expr{ + SQL: fmt.Sprintf("%s%s = ANY(?)", mIn[1], mIn[4]), + Vars: []any{passthroughValuer{vars}}, + WithoutParentheses: in.WithoutParentheses, + } + } else { + return clause.Expr{ + SQL: fmt.Sprintf("%s%s != ALL(?)", mNotIn[1], mNotIn[4]), + Vars: []any{passthroughValuer{vars}}, + WithoutParentheses: in.WithoutParentheses, + } + } +} + +func (r inClausesRewriter) rewriteNamedExpr(in clause.NamedExpr) (out clause.NamedExpr) { + e := r.rewriteExpr(clause.Expr{SQL: in.SQL, Vars: in.Vars}) + return clause.NamedExpr{SQL: e.SQL, Vars: e.Vars} +} + +func (r inClausesRewriter) rewriteInExpr(in clause.IN) (out clause.Expression) { + values := r.rewriteInValues(in.Values) + if values != nil { + return EqANY{Column: in.Column, Values: passthroughValuer{values}} + } else { + return in + } +} + +func (r inClausesRewriter) rewriteExprVar(in any) (out any) { + v := reflect.ValueOf(in) + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + return nil + } + + if r.isSupportedArrayElementType(v.Type().Elem()) { + return in + } else { + return nil + } +} + +func (r inClausesRewriter) rewriteInValues(in []any) (out []any) { + if len(in) == 0 { + return in + } + + et := reflect.TypeOf(in[0]) + for _, v := range in { + if reflect.TypeOf(v) != et { + return nil + } + } + + if r.isSupportedArrayElementType(et) { + return in + } else { + return nil + } +} + +func (r inClausesRewriter) isSupportedArrayElementType(t reflect.Type) bool { + k := t.Kind() + if k == reflect.Bool || + k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || k == reflect.Int32 || k == reflect.Int64 || + k == reflect.Uint || k == reflect.Uint8 || k == reflect.Uint16 || k == reflect.Uint32 || k == reflect.Uint64 || + k == reflect.Float32 || k == reflect.Float64 || + k == reflect.String { + return true + } + if t.Implements(reflect.TypeFor[driver.Valuer]()) { + return true + } + return false +} + +// This driver.Valuer hides arrays from GORM. Statement's AddVar() expands arrays +// into multiple values in the query, and binds each of them as a separate query +// argument. We need the whole array to be one query argument. +type passthroughValuer struct { + val any +} + +func (v passthroughValuer) Value() (driver.Value, error) { + return v.val, nil +} + +type EqANY struct { + Column any + Values any +} + +func (eqANY EqANY) Build(builder clause.Builder) { + builder.WriteQuoted(eqANY.Column) + + builder.WriteString(" = ANY(") + builder.AddVar(builder, eqANY.Values) + builder.WriteByte(')') +} diff --git a/in_optimiser_test.go b/in_optimiser_test.go new file mode 100644 index 0000000..38f7995 --- /dev/null +++ b/in_optimiser_test.go @@ -0,0 +1,111 @@ +package postgres + +import "testing" + +func TestColumnInRe(t *testing.T) { + type validExample struct { + e string + tbl string + col string + } + validExamples := []validExample{ + {e: `c in (?)`, tbl: ``, col: `c`}, + {e: `c IN (?)`, tbl: ``, col: `c`}, + {e: ` c in (?) `, tbl: ``, col: `c`}, + {e: `"c" in (?)`, tbl: ``, col: `"c"`}, + {e: `tbl.c in (?)`, tbl: `tbl.`, col: `c`}, + {e: `tbl."c" in (?)`, tbl: `tbl.`, col: `"c"`}, + {e: `"tbl".c in (?)`, tbl: `"tbl".`, col: `c`}, + {e: `"tbl"."c" in (?)`, tbl: `"tbl".`, col: `"c"`}, + {e: `column_name in (?)`, tbl: ``, col: `column_name`}, + {e: `abc123 in (?)`, tbl: ``, col: `abc123`}, + } + + for _, e := range validExamples { + t.Run(e.e, func(t *testing.T) { + m := columnInRe.FindStringSubmatch(e.e) + if m == nil { + t.Fatalf("must be a valid IN expression: %q", e.e) + } + if len(m) != 1+6 { + t.Fatalf("columnInRe is expected to have 6 capture groups") + } + if m[1] != e.tbl || m[4] != e.col { + t.Fatalf("columnInRe fails to capture the table and column names") + } + }) + } + + invalidExamples := []string{ + `tbl c in (?)`, + `c not in (?)`, + `tbl.c = ANY(?)`, + `column-name in (?)`, + // NOTE: this one is a valid escaped column name (it may contain + // any characters except NULL), but let us not handle this case. + `"tbl.c" in (?)`, + } + + for _, e := range invalidExamples { + t.Run(e, func(t *testing.T) { + if columnInRe.MatchString(e) { + t.Fatalf("must be a invalid IN expression: %q", e) + } + }) + } +} + +func TestColumnNotInRe(t *testing.T) { + type validExample struct { + e string + tbl string + col string + } + validExamples := []validExample{ + {e: `c not in (?)`, tbl: ``, col: `c`}, + {e: `c not IN (?)`, tbl: ``, col: `c`}, + {e: ` c NOT in (?) `, tbl: ``, col: `c`}, + {e: `"c" not in (?)`, tbl: ``, col: `"c"`}, + {e: `tbl.c not in (?)`, tbl: `tbl.`, col: `c`}, + {e: `tbl."c" not in (?)`, tbl: `tbl.`, col: `"c"`}, + {e: `"tbl".c not in (?)`, tbl: `"tbl".`, col: `c`}, + {e: `"tbl"."c" not in (?)`, tbl: `"tbl".`, col: `"c"`}, + {e: `column_name not in (?)`, tbl: ``, col: `column_name`}, + {e: `abc123 not in (?)`, tbl: ``, col: `abc123`}, + } + + for _, e := range validExamples { + t.Run(e.e, func(t *testing.T) { + m := columnNotInRe.FindStringSubmatch(e.e) + if m == nil { + t.Fatalf("must be a valid NOT IN expression: %q", e.e) + } + if len(m) != 1+6 { + t.Fatalf("columnNotInRe is expected to have 6 capture groups") + } + if m[1] != e.tbl || m[4] != e.col { + t.Fatalf("columnNotInRe fails to capture the table and column names") + } + }) + } + + invalidExamples := []string{ + `tbl c not in (?)`, + `c in (?)`, + `c not (?)`, + `c in not (?)`, + `tbl.c != ALL(?)`, + `column-name not in (?)`, + // NOTE: this one is a valid escaped column name (it may contain + // any characters except NULL), but let us not handle this case. + `"tbl.c" not in (?)`, + } + + for _, e := range invalidExamples { + t.Run(e, func(t *testing.T) { + if columnNotInRe.MatchString(e) { + t.Fatalf("must be a invalid NOT IN expression: %q", e) + } + }) + } +} diff --git a/postgres.go b/postgres.go index 2d8fd99..edb42b9 100644 --- a/postgres.go +++ b/postgres.go @@ -120,6 +120,33 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } db.ConnPool = stdlib.OpenDB(*config, options...) } + if err != nil { + return + } + + db.ClauseBuilders = make(map[string]clause.ClauseBuilder) + + db.ClauseBuilders["WHERE"] = func(c clause.Clause, b clause.Builder) { + c.Expression = rewriteWhereClauses(c.Expression) + c.Build(b) + } + + if !dialector.Config.PreferSimpleProtocol { + // When inserting N rows with a bulk INSERT INTO, GORM produces a + // unique query string for every value of N. This means a prepared + // statement must be produced for every N, which overflows pgx's + // cache of prepared statements quickly. + // + // Just do not use prepared statements for bulk inserts. + db.ClauseBuilders["VALUES"] = func(c clause.Clause, b clause.Builder) { + values := c.Expression.(clause.Values) + if len(values.Values) > 1 { + b.AddVar(b, sql.NamedArg{Value: pgx.QueryExecModeDescribeExec}) + } + c.Build(b) + } + } + return }