Skip to content

Commit c756988

Browse files
Convert SET DATA TYPE SQL to pgroll operation (#506)
Convert SQL statements of the form: ```sql ALTER TABLE foo ALTER COLUMN a [SET DATA] TYPE text ``` to the equivalent `pgroll` migration: ```json [ { "alter_column": { "column": "a", "down": "TODO: Implement SQL data migration", "table": "foo", "type": "text", "up": "TODO: Implement SQL data migration" } } ] ``` Part of #504
1 parent 84661eb commit c756988

File tree

5 files changed

+102
-44
lines changed

5 files changed

+102
-44
lines changed

pkg/sql2pgroll/alter_table.go

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
package sql2pgroll
44

55
import (
6+
"fmt"
7+
68
pgq "github.com/pganalyze/pg_query_go/v6"
79
"github.com/xataio/pgroll/pkg/migrations"
810
)
@@ -22,25 +24,50 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
2224
continue
2325
}
2426

25-
switch alterTableCmd.Subtype {
27+
var op migrations.Operation
28+
var err error
29+
switch alterTableCmd.GetSubtype() {
2630
case pgq.AlterTableType_AT_SetNotNull:
27-
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, true))
31+
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, true)
2832
case pgq.AlterTableType_AT_DropNotNull:
29-
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, false))
33+
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false)
34+
case pgq.AlterTableType_AT_AlterColumnType:
35+
op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd)
36+
}
37+
38+
if err != nil {
39+
return nil, err
3040
}
41+
42+
ops = append(ops, op)
3143
}
3244

3345
return ops, nil
3446
}
3547

36-
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) migrations.Operation {
48+
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) {
3749
return &migrations.OpAlterColumn{
3850
Table: stmt.GetRelation().GetRelname(),
3951
Column: cmd.GetName(),
4052
Nullable: ptr(!notNull),
4153
Up: PlaceHolderSQL,
4254
Down: PlaceHolderSQL,
55+
}, nil
56+
}
57+
58+
func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
59+
node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef)
60+
if !ok {
61+
return nil, fmt.Errorf("expected column definition, got %T", cmd.GetDef().Node)
4362
}
63+
64+
return &migrations.OpAlterColumn{
65+
Table: stmt.GetRelation().GetRelname(),
66+
Column: cmd.GetName(),
67+
Type: ptr(convertTypeName(node.ColumnDef.GetTypeName())),
68+
Up: PlaceHolderSQL,
69+
Down: PlaceHolderSQL,
70+
}, nil
4471
}
4572

4673
func ptr[T any](x T) *T {

pkg/sql2pgroll/alter_table_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
2727
sql: "ALTER TABLE foo ALTER COLUMN a DROP NOT NULL",
2828
expectedOp: expect.AlterTableOp2,
2929
},
30+
{
31+
sql: "ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text",
32+
expectedOp: expect.AlterTableOp3,
33+
},
34+
{
35+
sql: "ALTER TABLE foo ALTER COLUMN a TYPE text",
36+
expectedOp: expect.AlterTableOp3,
37+
},
3038
}
3139

3240
for _, tc := range tests {

pkg/sql2pgroll/create_table.go

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
package sql2pgroll
44

55
import (
6-
"fmt"
7-
"strings"
8-
96
pgq "github.com/pganalyze/pg_query_go/v6"
107
"github.com/xataio/pgroll/pkg/migrations"
118
)
@@ -26,42 +23,8 @@ func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
2623
}
2724

2825
func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
29-
ignoredTypeParts := map[string]bool{
30-
"pg_catalog": true,
31-
}
32-
33-
// Build the type name, including any schema qualifiers
34-
typeParts := make([]string, 0, len(col.GetTypeName().Names))
35-
for _, node := range col.GetTypeName().Names {
36-
typePart := node.GetString_().GetSval()
37-
if _, ok := ignoredTypeParts[typePart]; ok {
38-
continue
39-
}
40-
typeParts = append(typeParts, typePart)
41-
}
42-
43-
// Build the type modifiers, such as precision and scale for numeric types
44-
var typeMods []string
45-
for _, node := range col.GetTypeName().Typmods {
46-
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
47-
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
48-
}
49-
}
50-
var typeModifier string
51-
if len(typeMods) > 0 {
52-
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
53-
}
54-
55-
// Build the array bounds for array types
56-
var arrayBounds string
57-
for _, node := range col.GetTypeName().ArrayBounds {
58-
bound := node.GetInteger().GetIval()
59-
if bound == -1 {
60-
arrayBounds = "[]"
61-
} else {
62-
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
63-
}
64-
}
26+
// Convert the column type
27+
typeString := convertTypeName(col.TypeName)
6528

6629
// Determine column nullability, uniqueness, and primary key status
6730
var notNull, unique, pk bool
@@ -81,7 +44,7 @@ func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
8144

8245
return migrations.Column{
8346
Name: col.Colname,
84-
Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds,
47+
Type: typeString,
8548
Nullable: !notNull,
8649
Unique: unique,
8750
Default: defaultValue,

pkg/sql2pgroll/expect/alter_table.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ var AlterTableOp2 = &migrations.OpAlterColumn{
2323
Down: sql2pgroll.PlaceHolderSQL,
2424
}
2525

26+
var AlterTableOp3 = &migrations.OpAlterColumn{
27+
Table: "foo",
28+
Column: "a",
29+
Type: ptr("text"),
30+
Up: sql2pgroll.PlaceHolderSQL,
31+
Down: sql2pgroll.PlaceHolderSQL,
32+
}
33+
2634
func ptr[T any](v T) *T {
2735
return &v
2836
}

pkg/sql2pgroll/typename.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll
4+
5+
import (
6+
"fmt"
7+
"strings"
8+
9+
pgq "github.com/pganalyze/pg_query_go/v6"
10+
)
11+
12+
// convertTypeName converts a TypeName node to a string.
13+
func convertTypeName(typeName *pgq.TypeName) string {
14+
ignoredTypeParts := map[string]bool{
15+
"pg_catalog": true,
16+
}
17+
18+
// Build the type name, including any schema qualifiers
19+
typeParts := make([]string, 0, len(typeName.Names))
20+
for _, node := range typeName.Names {
21+
typePart := node.GetString_().GetSval()
22+
if _, ok := ignoredTypeParts[typePart]; ok {
23+
continue
24+
}
25+
typeParts = append(typeParts, typePart)
26+
}
27+
28+
// Build the type modifiers, such as precision and scale for numeric types
29+
var typeMods []string
30+
for _, node := range typeName.Typmods {
31+
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
32+
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
33+
}
34+
}
35+
var typeModifier string
36+
if len(typeMods) > 0 {
37+
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
38+
}
39+
40+
// Build the array bounds for array types
41+
var arrayBounds string
42+
for _, node := range typeName.ArrayBounds {
43+
bound := node.GetInteger().GetIval()
44+
if bound == -1 {
45+
arrayBounds = "[]"
46+
} else {
47+
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
48+
}
49+
}
50+
51+
return strings.Join(typeParts, ".") + typeModifier + arrayBounds
52+
}

0 commit comments

Comments
 (0)