Skip to content

Commit 12ae369

Browse files
authored
Add support for composite keys in create table statements (#413)
This PR adds support for setting a composite key for a table. From now on it is possible to set `pk` to `true` in multiple columns in `create_table`. The create table statement is translated to the following format: ```sql CREATE TABLE my_table ( id SERIAL, code VARCHAR(255), count INTEGER, PRIMARY KEY (id, code) ); ```
1 parent 515dd54 commit 12ae369

File tree

5 files changed

+191
-38
lines changed

5 files changed

+191
-38
lines changed

examples/01_create_tables.json

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,28 @@
4242
}
4343
]
4444
}
45+
},
46+
{
47+
"create_table": {
48+
"name": "sellers",
49+
"columns": [
50+
{
51+
"name": "name",
52+
"type": "varchar(255)",
53+
"pk": true
54+
},
55+
{
56+
"name": "zip",
57+
"type": "integer",
58+
"pk": true
59+
},
60+
{
61+
"name": "description",
62+
"type": "varchar(255)",
63+
"nullable": true
64+
}
65+
]
66+
}
4567
}
4668
]
4769
}

pkg/migrations/op_add_column.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
203203
o.Column.Check = nil
204204

205205
o.Column.Name = TemporaryName(o.Column.Name)
206-
colSQL, err := ColumnToSQL(o.Column, tr)
206+
columnWriter := ColumnSQLWriter{WithPK: true, Transformer: tr}
207+
colSQL, err := columnWriter.Write(o.Column)
207208
if err != nil {
208209
return err
209210
}
@@ -243,3 +244,51 @@ func NotNullConstraintName(columnName string) string {
243244
func IsNotNullConstraintName(name string) bool {
244245
return strings.HasPrefix(name, "_pgroll_check_not_null_")
245246
}
247+
248+
// ColumnSQLWriter writes a column to SQL
249+
// It can optionally include the primary key constraint
250+
// When creating a table, the primary key constraint is not added to the column definition
251+
type ColumnSQLWriter struct {
252+
WithPK bool
253+
Transformer SQLTransformer
254+
}
255+
256+
func (w ColumnSQLWriter) Write(col Column) (string, error) {
257+
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)
258+
259+
if w.WithPK && col.IsPrimaryKey() {
260+
sql += " PRIMARY KEY"
261+
}
262+
263+
if col.IsUnique() {
264+
sql += " UNIQUE"
265+
}
266+
if !col.IsNullable() {
267+
sql += " NOT NULL"
268+
}
269+
if col.Default != nil {
270+
d, err := w.Transformer.TransformSQL(*col.Default)
271+
if err != nil {
272+
return "", err
273+
}
274+
sql += fmt.Sprintf(" DEFAULT %s", d)
275+
}
276+
if col.References != nil {
277+
onDelete := "NO ACTION"
278+
if col.References.OnDelete != "" {
279+
onDelete = strings.ToUpper(string(col.References.OnDelete))
280+
}
281+
282+
sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s) ON DELETE %s",
283+
pq.QuoteIdentifier(col.References.Name),
284+
pq.QuoteIdentifier(col.References.Table),
285+
pq.QuoteIdentifier(col.References.Column),
286+
onDelete)
287+
}
288+
if col.Check != nil {
289+
sql += fmt.Sprintf(" CONSTRAINT %s CHECK (%s)",
290+
pq.QuoteIdentifier(col.Check.Name),
291+
col.Check.Constraint)
292+
}
293+
return sql, nil
294+
}

pkg/migrations/op_common_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ func ColumnMustNotHaveComment(t *testing.T, db *sql.DB, schema, table, column st
178178
}
179179
}
180180

181+
func ColumnMustBePK(t *testing.T, db *sql.DB, schema, table, column string) {
182+
t.Helper()
183+
if !columnMustBePK(t, db, schema, table, column) {
184+
t.Fatalf("Expected column %q to be primary key", column)
185+
}
186+
}
187+
181188
func TableMustHaveComment(t *testing.T, db *sql.DB, schema, table, expectedComment string) {
182189
t.Helper()
183190
if !tableHasComment(t, db, schema, table, expectedComment) {
@@ -526,6 +533,28 @@ func columnHasComment(t *testing.T, db *sql.DB, schema, table, column string, ex
526533
return actualComment != nil && *expectedComment == *actualComment
527534
}
528535

536+
func columnMustBePK(t *testing.T, db *sql.DB, schema, table, column string) bool {
537+
t.Helper()
538+
539+
var exists bool
540+
err := db.QueryRow(fmt.Sprintf(`
541+
SELECT EXISTS (
542+
SELECT a.attname
543+
FROM pg_index i
544+
JOIN pg_attribute a ON a.attrelid = i.indrelid
545+
AND a.attnum = ANY(i.indkey)
546+
WHERE i.indrelid = %[1]s::regclass AND i.indisprimary AND a.attname = %[2]s
547+
)`,
548+
pq.QuoteLiteral(fmt.Sprintf("%s.%s", schema, table)),
549+
pq.QuoteLiteral(column)),
550+
).Scan(&exists)
551+
if err != nil {
552+
t.Fatal(err)
553+
}
554+
555+
return exists
556+
}
557+
529558
func tableHasComment(t *testing.T, db *sql.DB, schema, table, expectedComment string) bool {
530559
t.Helper()
531560

pkg/migrations/op_create_table.go

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -113,55 +113,25 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {
113113

114114
func columnsToSQL(cols []Column, tr SQLTransformer) (string, error) {
115115
var sql string
116+
var primaryKeys []string
117+
columnWriter := ColumnSQLWriter{WithPK: false, Transformer: tr}
116118
for i, col := range cols {
117119
if i > 0 {
118120
sql += ", "
119121
}
120-
colSQL, err := ColumnToSQL(col, tr)
122+
colSQL, err := columnWriter.Write(col)
121123
if err != nil {
122124
return "", err
123125
}
124126
sql += colSQL
125-
}
126-
return sql, nil
127-
}
128127

129-
// ColumnToSQL generates the SQL for a column definition.
130-
func ColumnToSQL(col Column, tr SQLTransformer) (string, error) {
131-
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)
132-
133-
if col.IsPrimaryKey() {
134-
sql += " PRIMARY KEY"
135-
}
136-
if col.IsUnique() {
137-
sql += " UNIQUE"
138-
}
139-
if !col.IsNullable() {
140-
sql += " NOT NULL"
141-
}
142-
if col.Default != nil {
143-
d, err := tr.TransformSQL(*col.Default)
144-
if err != nil {
145-
return "", err
128+
if col.IsPrimaryKey() {
129+
primaryKeys = append(primaryKeys, pq.QuoteIdentifier(col.Name))
146130
}
147-
sql += fmt.Sprintf(" DEFAULT %s", d)
148131
}
149-
if col.References != nil {
150-
onDelete := "NO ACTION"
151-
if col.References.OnDelete != "" {
152-
onDelete = strings.ToUpper(string(col.References.OnDelete))
153-
}
154132

155-
sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s) ON DELETE %s",
156-
pq.QuoteIdentifier(col.References.Name),
157-
pq.QuoteIdentifier(col.References.Table),
158-
pq.QuoteIdentifier(col.References.Column),
159-
onDelete)
160-
}
161-
if col.Check != nil {
162-
sql += fmt.Sprintf(" CONSTRAINT %s CHECK (%s)",
163-
pq.QuoteIdentifier(col.Check.Name),
164-
col.Check.Constraint)
133+
if len(primaryKeys) > 0 {
134+
sql += fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", "))
165135
}
166136
return sql, nil
167137
}

pkg/migrations/op_create_table_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,89 @@ func TestCreateTable(t *testing.T) {
7676
}, rows)
7777
},
7878
},
79+
{
80+
name: "create table with composite key",
81+
migrations: []migrations.Migration{
82+
{
83+
Name: "01_create_table",
84+
Operations: migrations.Operations{
85+
&migrations.OpCreateTable{
86+
Name: "users",
87+
Columns: []migrations.Column{
88+
{
89+
Name: "id",
90+
Type: "serial",
91+
Pk: ptr(true),
92+
},
93+
{
94+
Name: "rand",
95+
Type: "varchar(255)",
96+
Pk: ptr(true),
97+
},
98+
{
99+
Name: "name",
100+
Type: "varchar(255)",
101+
Unique: ptr(true),
102+
},
103+
},
104+
},
105+
},
106+
},
107+
},
108+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
109+
// The new view exists in the new version schema.
110+
ViewMustExist(t, db, schema, "01_create_table", "users")
111+
112+
// Data can be inserted into the new view.
113+
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
114+
"rand": "123",
115+
"name": "Alice",
116+
})
117+
// New record with same keys cannot be inserted.
118+
MustNotInsert(t, db, schema, "01_create_table", "users", map[string]string{
119+
"id": "1",
120+
"rand": "123",
121+
"name": "Malice",
122+
}, testutils.UniqueViolationErrorCode)
123+
124+
// Data can be retrieved from the new view.
125+
rows := MustSelect(t, db, schema, "01_create_table", "users")
126+
assert.Equal(t, []map[string]any{
127+
{"id": 1, "rand": "123", "name": "Alice"},
128+
}, rows)
129+
},
130+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
131+
// The underlying table has been dropped.
132+
TableMustNotExist(t, db, schema, "users")
133+
},
134+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
135+
// The view still exists
136+
ViewMustExist(t, db, schema, "01_create_table", "users")
137+
138+
// The columns are still primary keys.
139+
ColumnMustBePK(t, db, schema, "users", "id")
140+
ColumnMustBePK(t, db, schema, "users", "rand")
141+
142+
// Data can be inserted into the new view.
143+
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
144+
"rand": "123",
145+
"name": "Alice",
146+
})
147+
148+
// New record with same keys cannot be inserted.
149+
MustNotInsert(t, db, schema, "01_create_table", "users", map[string]string{
150+
"id": "1",
151+
"rand": "123",
152+
"name": "Malice",
153+
}, testutils.UniqueViolationErrorCode)
154+
155+
// Data can be retrieved from the new view.
156+
rows := MustSelect(t, db, schema, "01_create_table", "users")
157+
assert.Equal(t, []map[string]any{
158+
{"id": 1, "rand": "123", "name": "Alice"},
159+
}, rows)
160+
},
161+
},
79162
{
80163
name: "create table with foreign key with default ON DELETE NO ACTION",
81164
migrations: []migrations.Migration{

0 commit comments

Comments
 (0)