Skip to content

Introduce database.AutoUpgradeSchema() #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions database/schema.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package database

import (
"context"
"database/sql"
goErrors "errors"
"fmt"
"github.com/pkg/errors"
)

// SchemaUpgrade represents a single upgrade step.
type SchemaUpgrade struct {
// Version specifies the target version as in the schema table.
Version string
// DDL aggregates one or more .sql files' contents.
DDL []string
}

// SchemaData summaries all available DDL for a database type.
type SchemaData struct {
// Schema aggregates one or more .sql files' contents.
Schema []string
// Upgrades aggregates all available upgrade steps in ascending order.
Upgrades []SchemaUpgrade
}

var ErrDbTypeNotUpgradable = goErrors.New("no schema supplied for given database type")

// AutoUpgradeSchema imports or upgrades the schema in db from schemaData.
func AutoUpgradeSchema(
ctx context.Context, db *DB, dbName, schemaTable, schemaTableVersionColumn, schemaTableTimestampColumn string,
schemaData map[string]SchemaData,
) error {
ourSchema, driverSupported := schemaData[db.DriverName()]
if !driverSupported {
return errors.Wrap(ErrDbTypeNotUpgradable, "can't upgrade schema")
}

err := db.QueryRowContext(
ctx, db.Rebind("SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=? AND TABLE_NAME=?"),
dbName, schemaTable,
).Scan(new(int8))

switch err {
case nil:
var currentVersion string

err := db.QueryRowContext(ctx, fmt.Sprintf(
"SELECT %s FROM %s ORDER BY %s DESC LIMIT 1",
schemaTableVersionColumn, schemaTable, schemaTableTimestampColumn,
)).Scan(&currentVersion)
if err != nil {
return errors.Wrap(err, "can't check schema version")
}

// If there's no upgrade step to the current version, it must be the first one, so apply all upgrades.
upgrades := ourSchema.Upgrades

for i, upgrade := range ourSchema.Upgrades {
if upgrade.Version == currentVersion {
// If there's an upgrade step to the current version, apply all subsequent ones.
upgrades = ourSchema.Upgrades[i+1:]
break
}
}

for _, upgrade := range upgrades {
if err := importSchema(ctx, db, upgrade.DDL); err != nil {
return errors.Wrap(err, "can't upgrade schema")
}
}

return nil
case sql.ErrNoRows:
return errors.Wrap(importSchema(ctx, db, ourSchema.Schema), "can't import schema")
default:
return errors.Wrap(err, "can't check schema existence")
}
}

// importSchema imports one or more .sql files' contents from schema into db.
func importSchema(ctx context.Context, db *DB, schema []string) error {
for _, ddls := range schema {
for _, ddl := range MysqlSplitStatements(ddls) {
if _, err := db.ExecContext(ctx, ddl); err != nil {
return err
}
}
}

return nil
}
Loading