From 41e72a787dbce71a0629c64587f669700ca3a125 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?No=C3=A9mi=20V=C3=A1nyi?= Date: Thu, 19 Jun 2025 18:40:22 +0200 Subject: [PATCH] Move TriggerConfig to backfill package --- pkg/backfill/trigger.go | 37 +++++++++++++++++ pkg/migrations/op_add_column.go | 10 ++--- pkg/migrations/op_add_column_test.go | 41 ++++++++++--------- pkg/migrations/op_alter_column.go | 18 ++++---- pkg/migrations/op_change_type_test.go | 9 ++-- pkg/migrations/op_common_test.go | 8 ++-- pkg/migrations/op_create_constraint.go | 16 ++++---- pkg/migrations/op_drop_column.go | 10 ++--- pkg/migrations/op_drop_column_test.go | 9 ++-- pkg/migrations/op_drop_constraint.go | 18 ++++---- .../op_drop_multicolumn_constraint.go | 16 ++++---- pkg/migrations/op_set_check_test.go | 9 ++-- pkg/migrations/trigger.go | 41 +++---------------- pkg/migrations/trigger_test.go | 18 ++++---- 14 files changed, 135 insertions(+), 125 deletions(-) create mode 100644 pkg/backfill/trigger.go diff --git a/pkg/backfill/trigger.go b/pkg/backfill/trigger.go new file mode 100644 index 000000000..eadc80027 --- /dev/null +++ b/pkg/backfill/trigger.go @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 + +package backfill + +import ( + "github.com/xataio/pgroll/pkg/schema" +) + +type TriggerDirection string + +const ( + TriggerDirectionUp TriggerDirection = "up" + TriggerDirectionDown TriggerDirection = "down" +) + +type TriggerConfig struct { + Name string + Direction TriggerDirection + Columns map[string]*schema.Column + SchemaName string + TableName string + PhysicalColumn string + LatestSchema string + SQL string + NeedsBackfillColumn string +} + +// TriggerFunctionName returns the name of the trigger function +// for a given table and column. +func TriggerFunctionName(tableName, columnName string) string { + return "_pgroll_trigger_" + tableName + "_" + columnName +} + +// TriggerName returns the name of the trigger for a given table and column. +func TriggerName(tableName, columnName string) string { + return TriggerFunctionName(tableName, columnName) +} diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index 483bf77bb..9a5949ccc 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -103,9 +103,9 @@ func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSch var tableToBackfill *schema.Table if o.Up != "" { err := NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, o.Column.Name), - Direction: TriggerDirectionUp, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, o.Column.Name), + Direction: backfill.TriggerDirectionUp, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -149,7 +149,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *sch return err } - err = NewDropFunctionAction(conn, TriggerFunctionName(o.Table, o.Column.Name)).Execute(ctx) + err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column.Name)).Execute(ctx) if err != nil { return err } @@ -224,7 +224,7 @@ func (o *OpAddColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *sch return err } - err = NewDropFunctionAction(conn, TriggerFunctionName(o.Table, o.Column.Name)).Execute(ctx) + err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column.Name)).Execute(ctx) if err != nil { return err } diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 6269ead1e..b7050a05b 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" ) @@ -858,11 +859,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { @@ -874,11 +875,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, res) // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, }, @@ -940,11 +941,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { @@ -956,11 +957,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, res) // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, }, @@ -1022,11 +1023,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { @@ -1038,11 +1039,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, res) // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, }, @@ -1106,11 +1107,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { @@ -1122,11 +1123,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, res) // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, }, @@ -1197,11 +1198,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { @@ -1214,11 +1215,11 @@ func TestAddColumnWithUpSql(t *testing.T) { }, res) // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("products", "description") + triggerFnName := backfill.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("products", "description") + triggerName := backfill.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) }, }, diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index 1a2900add..a77a95b23 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -40,9 +40,9 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, latestS // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. err := NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, o.Column), - Direction: TriggerDirectionUp, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, o.Column), + Direction: backfill.TriggerDirectionUp, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -66,9 +66,9 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, latestS // Add a trigger to copy values from the new column to the old. err = NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, TemporaryName(o.Column)), - Direction: TriggerDirectionDown, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, TemporaryName(o.Column)), + Direction: backfill.TriggerDirectionDown, Columns: table.Columns, LatestSchema: latestSchema, SchemaName: s.Name, @@ -114,7 +114,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *s } // Remove the up and down function and trigger - err = NewDropFunctionAction(conn, TriggerFunctionName(o.Table, o.Column), TriggerFunctionName(o.Table, TemporaryName(o.Column))).Execute(ctx) + err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column), backfill.TriggerFunctionName(o.Table, TemporaryName(o.Column))).Execute(ctx) if err != nil { return err } @@ -170,8 +170,8 @@ func (o *OpAlterColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *s // Remove the up and down functions and triggers if err := NewDropFunctionAction( conn, - TriggerFunctionName(o.Table, o.Column), - TriggerFunctionName(o.Table, TemporaryName(o.Column)), + backfill.TriggerFunctionName(o.Table, o.Column), + backfill.TriggerFunctionName(o.Table, TemporaryName(o.Column)), ).Execute(ctx); err != nil { return err } diff --git a/pkg/migrations/op_change_type_test.go b/pkg/migrations/op_change_type_test.go index 124a02d15..7142d1caa 100644 --- a/pkg/migrations/op_change_type_test.go +++ b/pkg/migrations/op_change_type_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) @@ -131,14 +132,14 @@ func TestChangeColumnType(t *testing.T) { }, rows) // The up function no longer exists. - FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("reviews", "rating")) + FunctionMustNotExist(t, db, schema, backfill.TriggerFunctionName("reviews", "rating")) // The down function no longer exists. - FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating"))) + FunctionMustNotExist(t, db, schema, backfill.TriggerFunctionName("reviews", migrations.TemporaryName("rating"))) // The up trigger no longer exists. - TriggerMustNotExist(t, db, schema, "reviews", migrations.TriggerName("reviews", "rating")) + TriggerMustNotExist(t, db, schema, "reviews", backfill.TriggerName("reviews", "rating")) // The down trigger no longer exists. - TriggerMustNotExist(t, db, schema, "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating"))) + TriggerMustNotExist(t, db, schema, "reviews", backfill.TriggerName("reviews", migrations.TemporaryName("rating"))) }, }, { diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index cedeaf3c5..0a266fbf2 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -981,14 +981,14 @@ func TableMustBeCleanedUp(t *testing.T, db *sql.DB, schema, table string, column ColumnMustNotExist(t, db, schema, table, backfill.CNeedsBackfillColumn) // The up function for the column no longer exists. - FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName(table, column)) + FunctionMustNotExist(t, db, schema, backfill.TriggerFunctionName(table, column)) // The down function for the column no longer exists. - FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName(table, migrations.TemporaryName(column))) + FunctionMustNotExist(t, db, schema, backfill.TriggerFunctionName(table, migrations.TemporaryName(column))) // The up trigger for the column no longer exists. - TriggerMustNotExist(t, db, schema, table, migrations.TriggerName(table, column)) + TriggerMustNotExist(t, db, schema, table, backfill.TriggerName(table, column)) // The down trigger for the column no longer exists. - TriggerMustNotExist(t, db, schema, table, migrations.TriggerName(table, migrations.TemporaryName(column))) + TriggerMustNotExist(t, db, schema, table, backfill.TriggerName(table, migrations.TemporaryName(column))) } } diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index 05d94600c..fb4cca0dd 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -45,9 +45,9 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la for _, colName := range o.Columns { upSQL := o.Up[colName] err := NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, colName), - Direction: TriggerDirectionUp, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, colName), + Direction: backfill.TriggerDirectionUp, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -71,9 +71,9 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la downSQL := o.Down[colName] err = NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, TemporaryName(colName)), - Direction: TriggerDirectionDown, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, TemporaryName(colName)), + Direction: backfill.TriggerDirectionDown, Columns: table.Columns, LatestSchema: latestSchema, SchemaName: s.Name, @@ -205,8 +205,8 @@ func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, func (o *OpCreateConstraint) removeTriggers(ctx context.Context, conn db.DB) error { dropFuncs := make([]string, 0, len(o.Columns)*2) for _, column := range o.Columns { - dropFuncs = append(dropFuncs, TriggerFunctionName(o.Table, column)) - dropFuncs = append(dropFuncs, TriggerFunctionName(o.Table, TemporaryName(column))) + dropFuncs = append(dropFuncs, backfill.TriggerFunctionName(o.Table, column)) + dropFuncs = append(dropFuncs, backfill.TriggerFunctionName(o.Table, TemporaryName(column))) } return NewDropFunctionAction(conn, dropFuncs...).Execute(ctx) } diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index ee1e9d823..c0be0179a 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -20,9 +20,9 @@ func (o *OpDropColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSc if o.Down != "" { err := NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, o.Column), - Direction: TriggerDirectionDown, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, o.Column), + Direction: backfill.TriggerDirectionDown, Columns: s.GetTable(o.Table).Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -58,7 +58,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *sc return err } - err = NewDropFunctionAction(conn, TriggerFunctionName(o.Table, o.Column)).Execute(ctx) + err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column)).Execute(ctx) if err != nil { return err } @@ -77,7 +77,7 @@ func (o *OpDropColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *sc table := s.GetTable(o.Table) - err := NewDropFunctionAction(conn, TriggerFunctionName(o.Table, o.Column)).Execute(ctx) + err := NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column)).Execute(ctx) if err != nil { return err } diff --git a/pkg/migrations/op_drop_column_test.go b/pkg/migrations/op_drop_column_test.go index 022cac40c..dd0a8ffc7 100644 --- a/pkg/migrations/op_drop_column_test.go +++ b/pkg/migrations/op_drop_column_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) @@ -78,11 +79,11 @@ func TestDropColumnWithDownSQL(t *testing.T) { }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("users", "name") + triggerFnName := backfill.TriggerFunctionName("users", "name") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("users", "name") + triggerName := backfill.TriggerName("users", "name") TriggerMustNotExist(t, db, schema, "users", triggerName) }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { @@ -90,11 +91,11 @@ func TestDropColumnWithDownSQL(t *testing.T) { ColumnMustNotExist(t, db, schema, "users", "name") // The trigger function has been dropped. - triggerFnName := migrations.TriggerFunctionName("users", "name") + triggerFnName := backfill.TriggerFunctionName("users", "name") FunctionMustNotExist(t, db, schema, triggerFnName) // The trigger has been dropped. - triggerName := migrations.TriggerName("users", "name") + triggerName := backfill.TriggerName("users", "name") TriggerMustNotExist(t, db, schema, "users", triggerName) // Inserting into the view in the new version schema should succeed. diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index 77ea39d80..69df0b3f6 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -38,9 +38,9 @@ func (o *OpDropConstraint) Start(ctx context.Context, l Logger, conn db.DB, late // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. err := NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, column.Name), - Direction: TriggerDirectionUp, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, column.Name), + Direction: backfill.TriggerDirectionUp, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -62,9 +62,9 @@ func (o *OpDropConstraint) Start(ctx context.Context, l Logger, conn db.DB, late // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. err = NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, TemporaryName(column.Name)), - Direction: TriggerDirectionDown, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, TemporaryName(column.Name)), + Direction: backfill.TriggerDirectionDown, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -87,7 +87,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s column := table.GetColumn(table.GetConstraintColumns(o.Name)[0]) // Remove the up and down function and trigger - err := NewDropFunctionAction(conn, TriggerFunctionName(o.Table, column.Name), TriggerFunctionName(o.Table, TemporaryName(column.Name))).Execute(ctx) + err := NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, column.Name), backfill.TriggerFunctionName(o.Table, TemporaryName(column.Name))).Execute(ctx) if err != nil { return err } @@ -134,8 +134,8 @@ func (o *OpDropConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s // Remove the up and down functions and triggers if err := NewDropFunctionAction( conn, - TriggerFunctionName(o.Table, columnName), - TriggerFunctionName(o.Table, TemporaryName(columnName)), + backfill.TriggerFunctionName(o.Table, columnName), + backfill.TriggerFunctionName(o.Table, TemporaryName(columnName)), ).Execute(ctx); err != nil { return err } diff --git a/pkg/migrations/op_drop_multicolumn_constraint.go b/pkg/migrations/op_drop_multicolumn_constraint.go index 7984fc2b5..b07eeb47b 100644 --- a/pkg/migrations/op_drop_multicolumn_constraint.go +++ b/pkg/migrations/op_drop_multicolumn_constraint.go @@ -52,9 +52,9 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, l Logger, conn for _, columnName := range table.GetConstraintColumns(o.Name) { // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. err := NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, columnName), - Direction: TriggerDirectionUp, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, columnName), + Direction: backfill.TriggerDirectionUp, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -78,9 +78,9 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, l Logger, conn // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. err = NewCreateTriggerAction(conn, - triggerConfig{ - Name: TriggerName(o.Table, TemporaryName(columnName)), - Direction: TriggerDirectionDown, + backfill.TriggerConfig{ + Name: backfill.TriggerName(o.Table, TemporaryName(columnName)), + Direction: backfill.TriggerDirectionDown, Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, @@ -104,7 +104,7 @@ func (o *OpDropMultiColumnConstraint) Complete(ctx context.Context, l Logger, co for _, columnName := range table.GetConstraintColumns(o.Name) { // Remove the up and down function and trigger - err := NewDropFunctionAction(conn, TriggerFunctionName(o.Table, columnName), TriggerFunctionName(o.Table, TemporaryName(columnName))).Execute(ctx) + err := NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, columnName), backfill.TriggerFunctionName(o.Table, TemporaryName(columnName))).Execute(ctx) if err != nil { return err } @@ -148,7 +148,7 @@ func (o *OpDropMultiColumnConstraint) Rollback(ctx context.Context, l Logger, co } // Remove the up and down function and trigger - err = NewDropFunctionAction(conn, TriggerFunctionName(o.Table, columnName), TriggerFunctionName(o.Table, TemporaryName(columnName))).Execute(ctx) + err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, columnName), backfill.TriggerFunctionName(o.Table, TemporaryName(columnName))).Execute(ctx) if err != nil { return err } diff --git a/pkg/migrations/op_set_check_test.go b/pkg/migrations/op_set_check_test.go index f2d8a0305..5de7e77f9 100644 --- a/pkg/migrations/op_set_check_test.go +++ b/pkg/migrations/op_set_check_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" ) @@ -128,14 +129,14 @@ func TestSetCheckConstraint(t *testing.T) { }, rows) // The up function no longer exists. - FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("posts", "title")) + FunctionMustNotExist(t, db, schema, backfill.TriggerFunctionName("posts", "title")) // The down function no longer exists. - FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("posts", migrations.TemporaryName("title"))) + FunctionMustNotExist(t, db, schema, backfill.TriggerFunctionName("posts", migrations.TemporaryName("title"))) // The up trigger no longer exists. - TriggerMustNotExist(t, db, schema, "posts", migrations.TriggerName("posts", "title")) + TriggerMustNotExist(t, db, schema, "posts", backfill.TriggerName("posts", "title")) // The down trigger no longer exists. - TriggerMustNotExist(t, db, schema, "posts", migrations.TriggerName("posts", migrations.TemporaryName("title"))) + TriggerMustNotExist(t, db, schema, "posts", backfill.TriggerName("posts", migrations.TemporaryName("title"))) }, }, { diff --git a/pkg/migrations/trigger.go b/pkg/migrations/trigger.go index b096f1e46..e64b7cbfa 100644 --- a/pkg/migrations/trigger.go +++ b/pkg/migrations/trigger.go @@ -14,34 +14,14 @@ import ( "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/migrations/templates" - "github.com/xataio/pgroll/pkg/schema" ) -type TriggerDirection string - -const ( - TriggerDirectionUp TriggerDirection = "up" - TriggerDirectionDown TriggerDirection = "down" -) - -type triggerConfig struct { - Name string - Direction TriggerDirection - Columns map[string]*schema.Column - SchemaName string - TableName string - PhysicalColumn string - LatestSchema string - SQL string - NeedsBackfillColumn string -} - type createTriggerAction struct { conn db.DB - cfg triggerConfig + cfg backfill.TriggerConfig } -func NewCreateTriggerAction(conn db.DB, cfg triggerConfig) DBAction { +func NewCreateTriggerAction(conn db.DB, cfg backfill.TriggerConfig) DBAction { return &createTriggerAction{ conn: conn, cfg: cfg, @@ -85,26 +65,15 @@ func (a *createTriggerAction) Execute(ctx context.Context) error { }) } -func buildFunction(cfg triggerConfig) (string, error) { +func buildFunction(cfg backfill.TriggerConfig) (string, error) { return executeTemplate("function", templates.Function, cfg) } -func buildTrigger(cfg triggerConfig) (string, error) { +func buildTrigger(cfg backfill.TriggerConfig) (string, error) { return executeTemplate("trigger", templates.Trigger, cfg) } -// TriggerFunctionName returns the name of the trigger function -// for a given table and column. -func TriggerFunctionName(tableName, columnName string) string { - return "_pgroll_trigger_" + tableName + "_" + columnName -} - -// TriggerName returns the name of the trigger for a given table and column. -func TriggerName(tableName, columnName string) string { - return TriggerFunctionName(tableName, columnName) -} - -func executeTemplate(name, content string, cfg triggerConfig) (string, error) { +func executeTemplate(name, content string, cfg backfill.TriggerConfig) (string, error) { tmpl := template.Must(template. New(name). Funcs(template.FuncMap{ diff --git a/pkg/migrations/trigger_test.go b/pkg/migrations/trigger_test.go index bd063e06c..727357f9c 100644 --- a/pkg/migrations/trigger_test.go +++ b/pkg/migrations/trigger_test.go @@ -13,14 +13,14 @@ import ( func TestBuildFunction(t *testing.T) { testCases := []struct { name string - config triggerConfig + config backfill.TriggerConfig expected string }{ { name: "simple up trigger", - config: triggerConfig{ + config: backfill.TriggerConfig{ Name: "triggerName", - Direction: TriggerDirectionUp, + Direction: backfill.TriggerDirectionUp, Columns: map[string]*schema.Column{ "id": {Name: "id", Type: "int"}, "username": {Name: "username", Type: "text"}, @@ -61,9 +61,9 @@ func TestBuildFunction(t *testing.T) { }, { name: "simple down trigger", - config: triggerConfig{ + config: backfill.TriggerConfig{ Name: "triggerName", - Direction: TriggerDirectionDown, + Direction: backfill.TriggerDirectionDown, Columns: map[string]*schema.Column{ "id": {Name: "id", Type: "int"}, "username": {Name: "username", Type: "text"}, @@ -104,9 +104,9 @@ func TestBuildFunction(t *testing.T) { }, { name: "down trigger with aliased column", - config: triggerConfig{ + config: backfill.TriggerConfig{ Name: "triggerName", - Direction: TriggerDirectionDown, + Direction: backfill.TriggerDirectionDown, Columns: map[string]*schema.Column{ "id": {Name: "id", Type: "int"}, "username": {Name: "username", Type: "text"}, @@ -163,12 +163,12 @@ func TestBuildFunction(t *testing.T) { func TestBuildTrigger(t *testing.T) { testCases := []struct { name string - config triggerConfig + config backfill.TriggerConfig expected string }{ { name: "trigger", - config: triggerConfig{ + config: backfill.TriggerConfig{ Name: "triggerName", TableName: "reviews", },