Skip to content

Commit a9b2048

Browse files
kvchandrew-farries
andauthored
Add support for creating foreign key constraints using create_constraint (#471)
This PR introduces a new constraint `type` to `create_constraint` operation called `foreign_key`. Now it is possible to create FK constraints on multiple columns. ### Examples #### Foreign key ```json { "name": "44_add_foreign_key_table_reference_constraint", "operations": [ { "create_constraint": { "type": "foreign_key", "table": "tickets", "name": "fk_sellers", "columns": [ "sellers_name", "sellers_zip" ], "references": { "table": "sellers", "columns": [ "name", "zip" ], "on_delete": "CASCADE" }, "up": { "sellers_name": "sellers_name", "sellers_zip": "sellers_zip" }, "down": { "sellers_name": "sellers_name", "sellers_zip": "sellers_zip" } } } ] } ``` Closes #81 --------- Co-authored-by: Andrew Farries <[email protected]>
1 parent 81dd0a7 commit a9b2048

13 files changed

+420
-43
lines changed

docs/README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ Example **create table** migrations:
11021102

11031103
A create constraint operation adds a new constraint to an existing table.
11041104

1105-
Only `UNIQUE` and `CHECK` constraints are supported.
1105+
`UNIQUE`, `CHECK` and `FOREIGN KEY` constraints are supported.
11061106

11071107
Required fields: `name`, `table`, `type`, `up`, `down`.
11081108

@@ -1114,7 +1114,14 @@ Required fields: `name`, `table`, `type`, `up`, `down`.
11141114
"table": "name of table",
11151115
"name": "my_unique_constraint",
11161116
"columns": ["col1", "col2"],
1117-
"type": "unique"
1117+
"type": "unique"| "check" | "foreign_key",
1118+
"check": "SQL expression for CHECK constraint",
1119+
"references": {
1120+
"name": "name of foreign key reference",
1121+
"table": "name of referenced table",
1122+
"columns": "[names of referenced columns]",
1123+
"on_delete": "ON DELETE behaviour, can be CASCADE, SET NULL, RESTRICT, or NO ACTION. Default is NO ACTION",
1124+
},
11181125
"up": {
11191126
"col1": "col1 || random()",
11201127
"col2": "col2 || random()"
@@ -1131,7 +1138,7 @@ Example **create constraint** migrations:
11311138

11321139
* [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json)
11331140
* [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json)
1134-
1141+
* [46_add_table_foreign_key_constraint.json](../examples/46_add_table_foreign_key_constraint.json)
11351142

11361143
### Drop column
11371144

examples/.ledger

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@
4444
44_add_table_unique_constraint.json
4545
45_add_table_check_constraint.json
4646
46_alter_column_drop_default.json
47+
47_add_table_foreign_key_constraint.json
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
{
2+
"name": "47_add_table_foreign_key_constraint",
3+
"operations": [
4+
{
5+
"create_constraint": {
6+
"type": "foreign_key",
7+
"table": "tickets",
8+
"name": "fk_sellers",
9+
"columns": [
10+
"sellers_name",
11+
"sellers_zip"
12+
],
13+
"references": {
14+
"table": "sellers",
15+
"columns": [
16+
"name",
17+
"zip"
18+
]
19+
},
20+
"up": {
21+
"sellers_name": "sellers_name",
22+
"sellers_zip": "sellers_zip"
23+
},
24+
"down": {
25+
"sellers_name": "sellers_name",
26+
"sellers_zip": "sellers_zip"
27+
}
28+
}
29+
}
30+
]
31+
}

pkg/migrations/duplicate.go

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
4444
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
4545
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
46+
cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
4647
)
4748

4849
// NewColumnDuplicator creates a new Duplicator for a column.
@@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
9192
colNames = append(colNames, name)
9293

9394
// Duplicate the column with the new type
94-
// and check and fk constraints
9595
if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" {
9696
_, err := d.conn.ExecContext(ctx, sql)
9797
if err != nil {
@@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
108108
}
109109
}
110110

111+
// Duplicate the column's comment
111112
if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" {
112113
_, err := d.conn.ExecContext(ctx, sql)
113114
if err != nil {
@@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
120121
// if the check constraint is not valid for the new column type, in which case
121122
// the error is ignored.
122123
for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) {
123-
// Update the check constraint expression to use the new column names if any of the columns are duplicated
124124
_, err := d.conn.ExecContext(ctx, sql)
125125
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
126126
if err != nil {
@@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
132132
// The constraint is duplicated by adding a unique index on the column concurrently.
133133
// The index is converted into a unique constraint on migration completion.
134134
for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) {
135-
// Update the unique constraint columns to use the new column names if any of the columns are duplicated
136135
if _, err := d.conn.ExecContext(ctx, sql); err != nil {
137136
return err
138137
}
139138
}
140139

140+
// Generate SQL to duplicate any foreign key constraints on the columns.
141+
// If the foreign key constraint is not valid for a new column type, the error is ignored.
142+
for _, sql := range d.stmtBuilder.duplicateForeignKeyConstraints(d.withoutConstraint, colNames...) {
143+
_, err := d.conn.ExecContext(ctx, sql)
144+
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
145+
if err != nil {
146+
return err
147+
}
148+
}
149+
141150
return nil
142151
}
143152

@@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s
175184
return stmts
176185
}
177186

187+
func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint []string, colNames ...string) []string {
188+
stmts := make([]string, 0, len(d.table.ForeignKeys))
189+
for _, fk := range d.table.ForeignKeys {
190+
if slices.Contains(withoutConstraint, fk.Name) {
191+
continue
192+
}
193+
if duplicatedMember, constraintColumns := d.allConstraintColumns(fk.Columns, colNames...); duplicatedMember {
194+
stmts = append(stmts, fmt.Sprintf(cAlterTableAddForeignKeySQL,
195+
pq.QuoteIdentifier(d.table.Name),
196+
pq.QuoteIdentifier(DuplicationName(fk.Name)),
197+
strings.Join(quoteColumnNames(constraintColumns), ", "),
198+
pq.QuoteIdentifier(fk.ReferencedTable),
199+
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
200+
fk.OnDelete,
201+
))
202+
}
203+
}
204+
return stmts
205+
}
206+
178207
// duplicatedConstraintColumns returns a new slice of constraint columns with
179208
// the columns that are duplicated replaced with temporary names.
180209
func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string {
@@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
213242
) string {
214243
const (
215244
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
216-
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
217245
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
218246
)
219247

@@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
232260
)
233261
}
234262

235-
// Generate SQL to duplicate any foreign key constraints on the column
236-
for _, fk := range d.table.ForeignKeys {
237-
if slices.Contains(withoutConstraint, fk.Name) {
238-
continue
239-
}
240-
241-
if slices.Contains(fk.Columns, column.Name) {
242-
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
243-
pq.QuoteIdentifier(DuplicationName(fk.Name)),
244-
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "),
245-
pq.QuoteIdentifier(fk.ReferencedTable),
246-
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
247-
fk.OnDelete,
248-
)
249-
}
250-
}
251-
252263
return sql
253264
}
254265

@@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string {
295306
return strings.TrimPrefix(name, "_pgroll_dup_")
296307
}
297308

298-
func copyAndReplace(xs []string, oldValue, newValue string) []string {
299-
ys := slices.Clone(xs)
300-
301-
for i, c := range ys {
302-
if c == oldValue {
303-
ys[i] = newValue
304-
}
305-
}
306-
return ys
307-
}
308-
309309
func errorIgnoringErrorCode(err error, code pq.ErrorCode) error {
310310
pqErr := &pq.Error{}
311311
if ok := errors.As(err, &pqErr); ok {

pkg/migrations/duplicate_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ var table = &schema.Table{
3232
"new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`},
3333
"different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`},
3434
},
35+
ForeignKeys: map[string]schema.ForeignKey{
36+
"fk_city": {Name: "fk_city", Columns: []string{"city"}, ReferencedTable: "cities", ReferencedColumns: []string{"id"}, OnDelete: "NO ACTION"},
37+
"fk_name_nick": {Name: "fk_name_nick", Columns: []string{"name", "nick"}, ReferencedTable: "users", ReferencedColumns: []string{"name", "nick"}, OnDelete: "CASCADE"},
38+
},
3539
}
3640

3741
func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) {
@@ -121,3 +125,52 @@ func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) {
121125
})
122126
}
123127
}
128+
129+
func TestDuplicateStmtBuilderForeignKeyConstraints(t *testing.T) {
130+
d := &duplicatorStmtBuilder{table}
131+
for name, testCases := range map[string]struct {
132+
columns []string
133+
expectedStmts []string
134+
}{
135+
"duplicate single column with no FK constraint": {
136+
columns: []string{"description"},
137+
expectedStmts: []string{},
138+
},
139+
"single-column FK with single column duplicated": {
140+
columns: []string{"city"},
141+
expectedStmts: []string{
142+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
143+
},
144+
},
145+
"single-column FK with multiple columns duplicated": {
146+
columns: []string{"city", "description"},
147+
expectedStmts: []string{
148+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
149+
},
150+
},
151+
"multi-column FK with single column duplicated": {
152+
columns: []string{"name"},
153+
expectedStmts: []string{
154+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
155+
},
156+
},
157+
"multi-column FK with multiple unrelated column duplicated": {
158+
columns: []string{"name", "description"},
159+
expectedStmts: []string{
160+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
161+
},
162+
},
163+
"multi-column FK with multiple columns": {
164+
columns: []string{"name", "nick"},
165+
expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "_pgroll_new_nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`},
166+
},
167+
} {
168+
t.Run(name, func(t *testing.T) {
169+
stmts := d.duplicateForeignKeyConstraints(nil, testCases.columns...)
170+
assert.Equal(t, len(testCases.expectedStmts), len(stmts))
171+
for _, stmt := range stmts {
172+
assert.Contains(t, testCases.expectedStmts, stmt)
173+
}
174+
})
175+
}
176+
}

pkg/migrations/op_add_column.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func (w ColumnSQLWriter) Write(col Column) (string, error) {
285285
sql += fmt.Sprintf(" DEFAULT %s", d)
286286
}
287287
if col.References != nil {
288-
onDelete := "NO ACTION"
288+
onDelete := string(ForeignKeyReferenceOnDeleteNOACTION)
289289
if col.References.OnDelete != "" {
290290
onDelete = strings.ToUpper(string(col.References.OnDelete))
291291
}

pkg/migrations/op_add_column_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ func TestAddForeignKeyColumn(t *testing.T) {
626626
Name: "fk_users_id",
627627
Table: "users",
628628
Column: "id",
629-
OnDelete: "CASCADE",
629+
OnDelete: migrations.ForeignKeyReferenceOnDeleteCASCADE,
630630
},
631631
},
632632
},

pkg/migrations/op_create_constraint.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
7070
return table, o.addUniqueIndex(ctx, conn)
7171
case OpCreateConstraintTypeCheck:
7272
return table, o.addCheckConstraint(ctx, conn)
73+
case OpCreateConstraintTypeForeignKey:
74+
return table, o.addForeignKeyConstraint(ctx, conn)
7375
}
7476

7577
return table, nil
@@ -97,6 +99,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra
9799
if err != nil {
98100
return err
99101
}
102+
case OpCreateConstraintTypeForeignKey:
103+
fkOp := &OpSetForeignKey{
104+
Table: o.Table,
105+
References: ForeignKeyReference{
106+
Name: o.Name,
107+
},
108+
}
109+
err := fkOp.Complete(ctx, conn, tr, s)
110+
if err != nil {
111+
return err
112+
}
100113
}
101114

102115
// remove old columns
@@ -198,6 +211,22 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
198211
if o.Check == nil || *o.Check == "" {
199212
return FieldRequiredError{Name: "check"}
200213
}
214+
case OpCreateConstraintTypeForeignKey:
215+
if o.References == nil {
216+
return FieldRequiredError{Name: "references"}
217+
}
218+
table := s.GetTable(o.References.Table)
219+
if table == nil {
220+
return TableDoesNotExistError{Name: o.References.Table}
221+
}
222+
for _, col := range o.References.Columns {
223+
if table.GetColumn(col) == nil {
224+
return ColumnDoesNotExistError{
225+
Table: o.References.Table,
226+
Name: col,
227+
}
228+
}
229+
}
201230
}
202231

203232
return nil
@@ -223,6 +252,25 @@ func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB)
223252
return err
224253
}
225254

255+
func (o *OpCreateConstraint) addForeignKeyConstraint(ctx context.Context, conn db.DB) error {
256+
onDelete := "NO ACTION"
257+
if o.References.OnDelete != "" {
258+
onDelete = strings.ToUpper(string(o.References.OnDelete))
259+
}
260+
261+
_, err := conn.ExecContext(ctx,
262+
fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s NOT VALID",
263+
pq.QuoteIdentifier(o.Table),
264+
pq.QuoteIdentifier(o.Name),
265+
strings.Join(quotedTemporaryNames(o.Columns), ","),
266+
pq.QuoteIdentifier(o.References.Table),
267+
strings.Join(quoteColumnNames(o.References.Columns), ","),
268+
onDelete,
269+
))
270+
271+
return err
272+
}
273+
226274
func quotedTemporaryNames(columns []string) []string {
227275
names := make([]string, len(columns))
228276
for i, col := range columns {

0 commit comments

Comments
 (0)