279 lines
7.5 KiB
Go
279 lines
7.5 KiB
Go
/*
|
|
* Amsterdam Web Communities System
|
|
* Copyright (c) 2025-2026 Erbosoft Metaverse Design Solutions, All Rights Reserved
|
|
*
|
|
* This Source Code Form is subject to the terms of the Mozilla Public
|
|
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
*
|
|
* SPDX-License-Identifier: MPL-2.0
|
|
*/
|
|
// Package database contains database management and storage logic.
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"regexp"
|
|
"slices"
|
|
"strings"
|
|
|
|
"git.erbosoft.com/amy/amsterdam/config"
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/jmoiron/sqlx"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Classifications produced by classifyGetError for a failed read of the
// database version; prepareDB chooses its recovery action from these.
const (
	classUnspecified = iota // any other failure: reported to the caller as-is
	classNeedInstall        // globals table missing: run the install script
	classNeedConvert        // version column missing: run the Venice-to-Amsterdam conversion script
)
|
|
|
|
// MySQL Errors
|
|
var errMySQLNoTable = &mysql.MySQLError{Number: 1146}
|
|
var errMySQLNoColumn = &mysql.MySQLError{Number: 1054}
|
|
|
|
//go:embed mysql-install.sql
|
|
var installScriptMySQL string
|
|
|
|
//go:embed mysql-convert.sql
|
|
var convertScriptMySQL string
|
|
|
|
//go:embed mysql-migrate/*
|
|
var migrationsMySQL embed.FS
|
|
|
|
// amdb is the reference to the Amsterdam database.
|
|
var amdb *sqlx.DB
|
|
|
|
// buildMysqlDSN builds the MySQL DSN for the driver.
|
|
func buildMysqlDSN(multiStatement bool) string {
|
|
rc := fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true&loc=UTC",
|
|
config.GlobalComputedConfig.DatabaseUser,
|
|
config.GlobalComputedConfig.DatabasePassword,
|
|
config.GlobalComputedConfig.DatabaseHost,
|
|
config.GlobalComputedConfig.DatabaseName)
|
|
if multiStatement {
|
|
rc += "&multiStatements=true"
|
|
}
|
|
return rc
|
|
}
|
|
|
|
// classifyGetError classifies errors returns from the original get of the version number.
|
|
func classifyGetError(err error) int {
|
|
if errors.Is(err, errMySQLNoTable) {
|
|
return classNeedInstall
|
|
}
|
|
if errors.Is(err, errMySQLNoColumn) {
|
|
return classNeedConvert
|
|
}
|
|
return classUnspecified
|
|
}
|
|
|
|
// databaseVersionNumber reads the version number from the database.
|
|
func databaseVersionNumber(db *sqlx.DB) (string, error) {
|
|
ver := ""
|
|
err := db.Get(&ver, "SELECT version FROM globals")
|
|
return ver, err
|
|
}
|
|
|
|
// setDatabaseVersionNumber resets the version number in the database.
|
|
func setDatabaseVersionNumber(db *sqlx.DB, version string) error {
|
|
_, err := db.Exec("UPDATE globals SET version = ?", version)
|
|
return err
|
|
}
|
|
|
|
// databaseInstallScript returns the install script for the database.
|
|
func databaseInstallScript() (string, error) {
|
|
switch config.GlobalComputedConfig.DatabaseDriver {
|
|
case "mysql":
|
|
return installScriptMySQL, nil
|
|
default:
|
|
return "", fmt.Errorf("No install script for database driver: %s", config.GlobalComputedConfig.DatabaseDriver)
|
|
}
|
|
}
|
|
|
|
// databaseConvertScript returns the script to convert a Venice database to Amsterdam.
|
|
func databaseConvertScript() (string, error) {
|
|
switch config.GlobalComputedConfig.DatabaseDriver {
|
|
case "mysql":
|
|
return convertScriptMySQL, nil
|
|
default: // N.B.: Not to be implemented for any database type besides MySQL!
|
|
return "", fmt.Errorf("No conversion script for database driver: %s", config.GlobalComputedConfig.DatabaseDriver)
|
|
}
|
|
}
|
|
|
|
// databaseMigrationScripts returns the migration scripts to apply to the database.
|
|
func databaseMigrationScripts(version string) (fs.FS, string, []string, error) {
|
|
var myfs fs.FS
|
|
var dirname string = ""
|
|
var err error = nil
|
|
switch config.GlobalComputedConfig.DatabaseDriver {
|
|
case "mysql":
|
|
myfs = migrationsMySQL
|
|
dirname = "mysql-migrate"
|
|
default:
|
|
err = fmt.Errorf("No migration scripts for database driver: %s", config.GlobalComputedConfig.DatabaseDriver)
|
|
}
|
|
if err != nil {
|
|
return nil, "", make([]string, 0), err
|
|
}
|
|
rdfs := myfs.(fs.ReadDirFS)
|
|
dents, err := rdfs.ReadDir(dirname)
|
|
if err != nil {
|
|
return nil, "", make([]string, 0), err
|
|
}
|
|
rc := make([]string, 0, len(dents))
|
|
for _, d := range dents {
|
|
s := strings.TrimSuffix(d.Name(), ".sql")
|
|
m, err := regexp.Match(`\d{10}`, []byte(s))
|
|
if err != nil {
|
|
return nil, "", make([]string, 0), err
|
|
}
|
|
if m && s > version {
|
|
rc = append(rc, d.Name())
|
|
}
|
|
}
|
|
if len(rc) > 1 {
|
|
slices.Sort(rc)
|
|
}
|
|
return myfs, dirname, rc, nil
|
|
}
|
|
|
|
// prepareDB prepares the database if it's not yet been loaded.
|
|
func prepareDB() (string, error) {
|
|
dsn := buildMysqlDSN(true)
|
|
log.Debugf("dsn=%s", dsn)
|
|
db, err := sqlx.Connect(config.GlobalComputedConfig.DatabaseDriver, dsn)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer db.Close()
|
|
version, err := databaseVersionNumber(db)
|
|
if err != nil {
|
|
switch classifyGetError(err) {
|
|
case classUnspecified:
|
|
log.Errorf("*** cannot get version number: %v (%T)", err, err)
|
|
return version, err
|
|
case classNeedInstall:
|
|
installScript, err := databaseInstallScript()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
_, err = db.Exec(installScript)
|
|
if err != nil {
|
|
return "", fmt.Errorf("Failure of install script: %w", err)
|
|
}
|
|
case classNeedConvert:
|
|
convertScript, err := databaseConvertScript()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
_, err = db.Exec(convertScript)
|
|
if err != nil {
|
|
return "", fmt.Errorf("Failure of conversion script: %w", err)
|
|
}
|
|
}
|
|
version, err = databaseVersionNumber(db)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
scriptfs, dirname, scripts, err := databaseMigrationScripts(version)
|
|
if err == nil {
|
|
log.Infof("%d migration script(s) to be applied", len(scripts))
|
|
rffs := scriptfs.(fs.ReadFileFS)
|
|
for _, script := range scripts {
|
|
log.Infof("applying migration script: %s", script)
|
|
var data []byte
|
|
data, err = rffs.ReadFile(fmt.Sprintf("%s/%s", dirname, script))
|
|
if err != nil {
|
|
return version, fmt.Errorf("Unable to read migration script %s: %w", script, err)
|
|
}
|
|
_, err = db.Exec(string(data))
|
|
if err != nil {
|
|
return version, fmt.Errorf("Unable to apply migration script %s: %w", script, err)
|
|
}
|
|
nv := strings.TrimSuffix(script, ".sql")
|
|
err = setDatabaseVersionNumber(db, nv)
|
|
if err != nil {
|
|
break
|
|
}
|
|
version = nv
|
|
}
|
|
}
|
|
return version, err
|
|
}
|
|
|
|
// SetupDb sets up the database and associated items.
|
|
func SetupDb() (func(), error) {
|
|
exitfns := make([]func(), 0, 2)
|
|
version, err := prepareDB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
db, err := sqlx.Connect(config.GlobalComputedConfig.DatabaseDriver, buildMysqlDSN(false))
|
|
if err == nil {
|
|
amdb = db
|
|
g, err := AmGlobals(context.Background())
|
|
if err == nil {
|
|
if g.Version != version {
|
|
log.Warnf("!! database version %s does not match prepared version %s", g.Version, version)
|
|
}
|
|
setupAdCache()
|
|
setupUserCache()
|
|
setupContactsCache()
|
|
setupCommunityCache()
|
|
setupServicesCache()
|
|
setupConferenceCache()
|
|
exitfns = append(exitfns, setupAuditWriter())
|
|
exitfns = append(exitfns, setupIPBanSweep())
|
|
log.Infof("SetupDb(): database version %s", g.Version)
|
|
}
|
|
}
|
|
slices.Reverse(exitfns)
|
|
return func() {
|
|
for _, f := range exitfns {
|
|
f()
|
|
}
|
|
amdb.Close()
|
|
}, err
|
|
}
|
|
|
|
/* transaction starts a transaction and returns functions for commit and rollback. The rollback
|
|
* function can be immediately deferred; if commit is called successfully, rollback becomes a no-op.
|
|
* Parameters:
|
|
* ctx - Standard Go error status.
|
|
* Returns:
|
|
* The sqlx transaction object
|
|
* The commit function (no parameters, returns error)
|
|
* The rollback function (no parameters or return)
|
|
*/
|
|
func transaction(ctx context.Context) (*sqlx.Tx, func() error, func()) {
|
|
tx := amdb.MustBeginTx(ctx, nil)
|
|
live := true
|
|
fCom := func() error {
|
|
var err error = nil
|
|
if live {
|
|
err = tx.Commit()
|
|
if err == nil {
|
|
live = false
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
fRoll := func() {
|
|
if live {
|
|
if err := tx.Rollback(); err != nil {
|
|
log.Errorf("***ROLLBACK ERROR*** %v", err)
|
|
}
|
|
live = false
|
|
}
|
|
}
|
|
return tx, fCom, fRoll
|
|
}
|