Skip to content

Commit 8a05c25

Browse files
committed
Introduce database.AutoUpgradeSchema()
1 parent 0d0428f commit 8a05c25

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

database/schema.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package database
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
goErrors "errors"
7+
"fmt"
8+
"github.com/pkg/errors"
9+
)
10+
11+
// SchemaUpgrade represents a single upgrade step.
12+
type SchemaUpgrade struct {
13+
// Version specifies the target version as in the schema table.
14+
Version string
15+
// DDL aggregates one or more .sql files' contents.
16+
DDL []string
17+
}
18+
19+
// SchemaData summaries all available DDL for a database type.
20+
type SchemaData struct {
21+
// Schema aggregates one or more .sql files' contents.
22+
Schema []string
23+
// Upgrades aggregates all available upgrade steps in ascending order.
24+
Upgrades []SchemaUpgrade
25+
}
26+
27+
var ErrDbTypeNotUpgradable = goErrors.New("no schema supplied for given database type")
28+
29+
// AutoUpgradeSchema imports or upgrades the schema in db from schemaData.
30+
func AutoUpgradeSchema(
31+
ctx context.Context, db *DB, dbName, schemaTable, schemaTableVersionColumn, schemaTableTimestampColumn string,
32+
schemaData map[string]SchemaData,
33+
) error {
34+
ourSchema, driverSupported := schemaData[db.DriverName()]
35+
if !driverSupported {
36+
return errors.Wrap(ErrDbTypeNotUpgradable, "can't upgrade schema")
37+
}
38+
39+
err := db.QueryRowContext(
40+
ctx, db.Rebind("SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=? AND TABLE_NAME=?"),
41+
dbName, schemaTable,
42+
).Scan(new(int8))
43+
44+
switch err {
45+
case nil:
46+
var currentVersion string
47+
48+
err := db.QueryRowContext(ctx, fmt.Sprintf(
49+
"SELECT %s FROM %s ORDER BY %s DESC LIMIT 1",
50+
schemaTableVersionColumn, schemaTable, schemaTableTimestampColumn,
51+
)).Scan(&currentVersion)
52+
if err != nil {
53+
return errors.Wrap(err, "can't check schema version")
54+
}
55+
56+
// If there's no upgrade step to the current version, it must be the first one, so apply all upgrades.
57+
upgrades := ourSchema.Upgrades
58+
59+
for i, upgrade := range ourSchema.Upgrades {
60+
if upgrade.Version == currentVersion {
61+
// If there's an upgrade step to the current version, apply all subsequent ones.
62+
upgrades = ourSchema.Upgrades[i+1:]
63+
break
64+
}
65+
}
66+
67+
for _, upgrade := range upgrades {
68+
if err := importSchema(ctx, db, upgrade.DDL); err != nil {
69+
return errors.Wrap(err, "can't upgrade schema")
70+
}
71+
}
72+
73+
return nil
74+
case sql.ErrNoRows:
75+
return errors.Wrap(importSchema(ctx, db, ourSchema.Schema), "can't import schema")
76+
default:
77+
return errors.Wrap(err, "can't check schema existence")
78+
}
79+
}
80+
81+
// importSchema imports one or more .sql files' contents from schema into db.
82+
func importSchema(ctx context.Context, db *DB, schema []string) error {
83+
for _, ddls := range schema {
84+
for _, ddl := range MysqlSplitStatements(ddls) {
85+
if _, err := db.ExecContext(ctx, ddl); err != nil {
86+
return err
87+
}
88+
}
89+
}
90+
91+
return nil
92+
}

0 commit comments

Comments
 (0)