Skip to content

Commit db8123b

Browse files
Add support for alter_table set NOT NULL operations with rename_table operations (#606)
Ensure that multi-operation migrations combining `alter_column` `SET NOT NULL` and `rename_table` operations work as expected. ```json { "name": "06_multi_operation", "operations": [ { "rename_table": { "from": "items", "to": "products" } }, { "alter_column": { "table": "products", "column": "name", "nullable": false, "up": "SELECT CASE WHEN name IS NULL THEN 'anonymous' ELSE name END", "down": "name || '_from_down_trigger'" } } ] } ``` This migration renames a table and then sets a column's `nullability` on the renamed table. Previously the migration would fail as the `alter_column` operation was unaware of the changes made by the preceding operation. Part of #239
1 parent 3657bb5 commit db8123b

File tree

5 files changed

+125
-6
lines changed

5 files changed

+125
-6
lines changed

pkg/migrations/op_alter_column.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri
3939
Columns: table.Columns,
4040
SchemaName: s.Name,
4141
LatestSchema: latestSchema,
42-
TableName: o.Table,
42+
TableName: table.Name,
4343
PhysicalColumn: TemporaryName(o.Column),
4444
SQL: o.upSQLForOperations(ops),
4545
})
@@ -61,7 +61,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri
6161
Columns: table.Columns,
6262
LatestSchema: latestSchema,
6363
SchemaName: s.Name,
64-
TableName: o.Table,
64+
TableName: table.Name,
6565
PhysicalColumn: o.Column,
6666
SQL: o.downSQLForOperations(ops),
6767
})
@@ -115,9 +115,10 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransfor
115115
}
116116

117117
func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
118-
ops := o.subOperations()
118+
table := s.GetTable(o.Table)
119119

120120
// Perform any operation specific rollback steps
121+
ops := o.subOperations()
121122
for _, ops := range ops {
122123
if err := ops.Rollback(ctx, conn, tr, nil); err != nil {
123124
return err
@@ -126,7 +127,7 @@ func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransfor
126127

127128
// Drop the new column
128129
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
129-
pq.QuoteIdentifier(o.Table),
130+
pq.QuoteIdentifier(table.Name),
130131
pq.QuoteIdentifier(TemporaryName(o.Column)),
131132
))
132133
if err != nil {

pkg/migrations/op_create_table.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, latestSchema stri
5959
}
6060

6161
func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
62-
// No-op
62+
// Update the in-memory schema representation with the new table
63+
o.updateSchema(s)
64+
6365
return nil
6466
}
6567

pkg/migrations/op_rename_table.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransfor
2222
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
2323
pq.QuoteIdentifier(o.From),
2424
pq.QuoteIdentifier(o.To)))
25+
26+
// Rename the table in the virtual schema so that the `Complete` methods
27+
// of subsequent operations in the same migration can find it.
28+
s.RenameTable(o.From, o.To)
29+
30+
// Update the physical name of the table in the virtual schema now that it
31+
// has really been renamed.
32+
table := s.GetTable(o.To)
33+
table.Name = o.To
34+
2535
return err
2636
}
2737

pkg/migrations/op_set_notnull.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, latestSchema strin
2424
table := s.GetTable(o.Table)
2525

2626
// Add an unchecked NOT NULL constraint to the new column.
27-
if err := addNotNullConstraint(ctx, conn, o.Table, o.Column, TemporaryName(o.Column)); err != nil {
27+
if err := addNotNullConstraint(ctx, conn, table.Name, o.Column, TemporaryName(o.Column)); err != nil {
2828
return nil, fmt.Errorf("failed to add not null constraint: %w", err)
2929
}
3030

pkg/migrations/op_set_notnull_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,112 @@ func TestSetNotNull(t *testing.T) {
531531
})
532532
}
533533

534+
func TestSetNotNullInMultiOperationMigrations(t *testing.T) {
535+
t.Parallel()
536+
537+
ExecuteTests(t, TestCases{
538+
{
539+
name: "rename table, set not null",
540+
migrations: []migrations.Migration{
541+
{
542+
Name: "01_create_table",
543+
Operations: migrations.Operations{
544+
&migrations.OpCreateTable{
545+
Name: "items",
546+
Columns: []migrations.Column{
547+
{
548+
Name: "id",
549+
Type: "int",
550+
Pk: true,
551+
},
552+
{
553+
Name: "name",
554+
Type: "varchar(255)",
555+
Nullable: true,
556+
},
557+
},
558+
},
559+
},
560+
},
561+
{
562+
Name: "02_multi_operation",
563+
Operations: migrations.Operations{
564+
&migrations.OpRenameTable{
565+
From: "items",
566+
To: "products",
567+
},
568+
&migrations.OpAlterColumn{
569+
Table: "products",
570+
Column: "name",
571+
Nullable: ptr(false),
572+
Up: "SELECT CASE WHEN name IS NULL THEN 'unknown' ELSE name END",
573+
Down: "name",
574+
},
575+
},
576+
},
577+
},
578+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
579+
// Can insert a row into the new schema that meets the constraint
580+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
581+
"id": "1",
582+
"name": "apple",
583+
})
584+
585+
// Can't insert a row into the new schema that violates the constraint
586+
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
587+
"id": "2",
588+
}, testutils.CheckViolationErrorCode)
589+
590+
// Can insert a row into the old schema that violates the constraint
591+
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
592+
"id": "2",
593+
})
594+
595+
// The new view has the expected rows
596+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
597+
assert.Equal(t, []map[string]any{
598+
{"id": 1, "name": "apple"},
599+
{"id": 2, "name": "unknown"},
600+
}, rows)
601+
602+
// The old view has the expected rows
603+
rows = MustSelect(t, db, schema, "01_create_table", "items")
604+
assert.Equal(t, []map[string]any{
605+
{"id": 1, "name": "apple"},
606+
{"id": 2, "name": nil},
607+
}, rows)
608+
},
609+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
610+
// The table has been cleaned up
611+
TableMustBeCleanedUp(t, db, schema, "items", "name")
612+
},
613+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
614+
// Can insert a row into the new schema that meets the constraint
615+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
616+
"id": "3",
617+
"name": "carrot",
618+
})
619+
620+
// Can't insert a row into the new schema that violates the constraint
621+
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
622+
"id": "3",
623+
}, testutils.NotNullViolationErrorCode)
624+
625+
// The new view has the expected rows
626+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
627+
assert.Equal(t, []map[string]any{
628+
{"id": 1, "name": "apple"},
629+
{"id": 2, "name": "unknown"},
630+
{"id": 3, "name": "carrot"},
631+
}, rows)
632+
633+
// The table has been cleaned up
634+
TableMustBeCleanedUp(t, db, schema, "products", "name")
635+
},
636+
},
637+
})
638+
}
639+
534640
func TestSetNotNullValidation(t *testing.T) {
535641
t.Parallel()
536642

0 commit comments

Comments
 (0)