diff --git a/database/base.go b/database/base.go index 9c6403e..e5df5a7 100644 --- a/database/base.go +++ b/database/base.go @@ -13,10 +13,13 @@ package database import ( "context" - _ "embed" + "embed" "errors" "fmt" + "io/fs" + "regexp" "slices" + "strings" "git.erbosoft.com/amy/amsterdam/config" "github.com/go-sql-driver/mysql" @@ -41,6 +44,9 @@ var installScriptMySQL string //go:embed mysql-convert.sql var convertScriptMySQL string +//go:embed mysql-migrate/* +var migrationsMySQL embed.FS + // amdb is the reference to the Amsterdam database. var amdb *sqlx.DB @@ -75,6 +81,12 @@ func databaseVersionNumber(db *sqlx.DB) (string, error) { return ver, err } +// setDatabaseVersionNumber resets the version number in the database. +func setDatabaseVersionNumber(db *sqlx.DB, version string) error { + _, err := db.Exec("UPDATE globals SET version = ?", version) + return err +} + // databaseInstallScript returns the install script for the database. func databaseInstallScript() (string, error) { switch config.GlobalComputedConfig.DatabaseDriver { @@ -95,6 +107,41 @@ func databaseConvertScript() (string, error) { } } +// databaseMigrationScripts returns the migration scripts to apply to the database. +func databaseMigrationScripts(version string) (fs.FS, []string, error) { + var myfs fs.FS + var err error + switch config.GlobalComputedConfig.DatabaseDriver { + case "mysql": + myfs, err = fs.Sub(migrationsMySQL, "mysql-migrate") + default: + err = fmt.Errorf("No migration scripts for database driver: %s", config.GlobalComputedConfig.DatabaseDriver) + } + if err != nil { + return nil, make([]string, 0), err + } + rdfs := myfs.(fs.ReadDirFS) + dents, err := rdfs.ReadDir("/") + if err != nil { + return nil, make([]string, 0), err + } + rc := make([]string, 0, len(dents)) + for _, d := range dents { + s := strings.TrimSuffix(d.Name(), ".sql") + m, err := regexp.Match(`\d{10}`, []byte(s)) + if err != nil { + return nil, make([]string, 0), err + } + if m && s > version { + rc = append(rc, d.Name()) + } + } + if len(rc) > 1 { + slices.Sort(rc) + } + return myfs, rc, nil +} + // prepareDB prepares the database if it's not yet been loaded. func prepareDB() (string, error) { dsn := buildMysqlDSN(true) @@ -134,7 +181,27 @@ func prepareDB() (string, error) { return "", err } } - // TODO: apply migration scripts + scriptfs, scripts, err := databaseMigrationScripts(version) + if err == nil { + log.Infof("%d migration script(s) to be applied", len(scripts)) + rffs := scriptfs.(fs.ReadFileFS) + for _, script := range scripts { + log.Infof("applying migration script: %s", script) + var data []byte + data, err = rffs.ReadFile(script) + if err != nil { + return version, fmt.Errorf("Unable to read migration script %s: %w", script, err) + } + _, err = db.Exec(string(data)) + if err != nil { + return version, fmt.Errorf("Unable to apply migration script %s: %w", script, err) + } + err = setDatabaseVersionNumber(db, strings.TrimSuffix(script, ".sql")) + if err != nil { + break + } + } + } return version, err } diff --git a/database/mysql-migrate/README.txt b/database/mysql-migrate/README.txt new file mode 100644 index 0000000..cf60da9 --- /dev/null +++ b/database/mysql-migrate/README.txt @@ -0,0 +1,2 @@ +Migration scripts for MySQL go in this directory. Name them as YYYYMMDDNN.sql, where YYYYMMDD is the +current date and NN is a two-digit sequence number beginning at 01. \ No newline at end of file