Skip to content

Commit 747f9d4

Browse files
authored
Fix previous_version: Take inferred migrations into account (#631)
Add `includeInferred` parameter to function `previous_version`. If it's `false`, it will behave exactly as it did before this PR. It will be called with `false` in `Complete`, to ignore inferred migrations when dropping the version schema, because inferred migrations do not have version schemas. If it's `true`, `inferred` migrations will be taken into account when finding the previous version. It will be called with `true` when finding the previous version for `Rollback`. When inferred migrations were not taken into account, `Rollback` was causing a segmentation violation because of nil pointer for new columns/tables added by SQL. This PR also modifies the test `TestNoVersionSchemaForRawSQLMigrationsOptionIsRespected` so that we will be testing `PreviousVersion` with both options. Fixes: #623
1 parent cf50739 commit 747f9d4

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

pkg/roll/execute.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func (m *Roll) Complete(ctx context.Context) error {
155155

156156
// Drop the old schema
157157
if !m.disableVersionSchemas && (!migration.ContainsRawSQLOperation() || !m.noVersionSchemaForRawSQL) {
158-
prevVersion, err := m.state.PreviousVersion(ctx, m.schema)
158+
prevVersion, err := m.state.PreviousVersion(ctx, m.schema, false)
159159
if err != nil {
160160
return fmt.Errorf("unable to get name of previous version: %w", err)
161161
}
@@ -241,7 +241,7 @@ func (m *Roll) Rollback(ctx context.Context) error {
241241
}
242242

243243
// get the name of the previous version of the schema
244-
previousVersion, err := m.state.PreviousVersion(ctx, m.schema)
244+
previousVersion, err := m.state.PreviousVersion(ctx, m.schema, true)
245245
if err != nil {
246246
return fmt.Errorf("unable to get name of previous version: %w", err)
247247
}

pkg/roll/execute_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,20 @@ func TestNoVersionSchemaForRawSQLMigrationsOptionIsRespected(t *testing.T) {
185185
err = mig.Start(ctx, &migrations.Migration{Name: "03_create_table", Operations: migrations.Operations{createTableOp("table3")}})
186186
require.NoError(t, err)
187187

188-
// The previous version is migration 01 because there is no version schema
189-
// for migration 02 due to the `WithNoVersionSchemaForRawSQL` option
190-
prevVersion, err := st.PreviousVersion(ctx, "public")
188+
// The previous version is migration 01 if raw SQL migrations are ignored
189+
prevVersion, err := st.PreviousVersion(ctx, "public", false)
191190
require.NoError(t, err)
192191
require.NotNil(t, prevVersion)
193192
assert.Equal(t, "01_create_table", *prevVersion)
194193

194+
// The previous version is migration 02 (inferred) but there is no version schema
195+
// for migration 02 due to the `WithNoVersionSchemaForRawSQL` option
196+
prevVersion, err = st.PreviousVersion(ctx, "public", true)
197+
require.NoError(t, err)
198+
require.NotNil(t, prevVersion)
199+
assert.Equal(t, "02_create_table", *prevVersion)
200+
assert.False(t, schemaExists(t, db, roll.VersionedSchemaName(schema, "02_create_table")))
201+
195202
// Complete the third migration
196203
err = mig.Complete(ctx)
197204
require.NoError(t, err)

pkg/state/init.sql

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ STABLE;
7676
-- Get the name of the previous version of the schema, or NULL if there is none.
7777
-- This ignores previous versions for which no version schema exists, such as
7878
-- versions corresponding to inferred migrations.
79-
CREATE OR REPLACE FUNCTION placeholder.previous_version (schemaname name)
79+
CREATE OR REPLACE FUNCTION placeholder.previous_version (schemaname name, includeInferred boolean)
8080
RETURNS text
8181
AS $$
8282
WITH RECURSIVE ancestors AS (
@@ -107,10 +107,17 @@ CREATE OR REPLACE FUNCTION placeholder.previous_version (schemaname name)
107107
a.name
108108
FROM
109109
ancestors a
110-
JOIN information_schema.schemata s ON s.schema_name = schemaname || '_' || a.name
111110
WHERE
112-
migration_type = 'pgroll'
113-
AND a.depth > 0
111+
a.depth > 0
112+
AND (includeInferred
113+
OR (a.migration_type = 'pgroll'
114+
AND EXISTS (
115+
SELECT
116+
s.schema_name
117+
FROM
118+
information_schema.schemata s
119+
WHERE
120+
s.schema_name = schemaname || '_' || a.name)))
114121
ORDER BY
115122
a.depth ASC
116123
LIMIT 1;

pkg/state/state.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ func (s *State) LatestVersion(ctx context.Context, schema string) (*string, erro
136136
}
137137

138138
// PreviousVersion returns the name of the previous version schema
139-
func (s *State) PreviousVersion(ctx context.Context, schema string) (*string, error) {
139+
func (s *State) PreviousVersion(ctx context.Context, schema string, includeInferred bool) (*string, error) {
140140
var parent *string
141141
err := s.pgConn.QueryRowContext(ctx,
142-
fmt.Sprintf("SELECT %s.previous_version($1)", pq.QuoteIdentifier(s.schema)),
143-
schema).Scan(&parent)
142+
fmt.Sprintf("SELECT %s.previous_version($1, $2)", pq.QuoteIdentifier(s.schema)),
143+
schema, includeInferred).Scan(&parent)
144144
if err != nil {
145145
return nil, err
146146
}

0 commit comments

Comments
 (0)