From a2c2a1f75058cbc61c16e424401a9b63e3656ada Mon Sep 17 00:00:00 2001 From: Amy Gale Ruth Bowersox Date: Wed, 6 May 2026 22:19:08 -0600 Subject: [PATCH] cleanups to startup code and goroutine code --- database/base.go | 17 ++++++++------ email/sender.go | 3 +++ htmlcheck/dict_trie.go | 16 +++++-------- htmlcheck/dictionary.go | 10 ++++---- logging.go | 36 +++++++++++++++++------------ main.go | 31 +++++++++++++++++-------- ui/amsession.go | 51 +++++++++++++++++++---------------------- ui/dialog.go | 4 ++-- util/util.go | 12 ---------- 9 files changed, 92 insertions(+), 88 deletions(-) diff --git a/database/base.go b/database/base.go index b172e65..605fb9d 100644 --- a/database/base.go +++ b/database/base.go @@ -29,9 +29,9 @@ import ( // Error classifications const ( - classUnspecified = 0 - classNeedInstall = 1 - classNeedConvert = 2 + classUnspecified = iota // unspecified, barf + classNeedInstall // need to install the database + classNeedConvert // need to convert a Venice database ) // MySQL Errors @@ -210,11 +210,11 @@ func prepareDB() (string, error) { } // SetupDb sets up the database and associated items. -func SetupDb() (func(), error) { +func SetupDb() (string, func(), error) { exitfns := make([]func(), 0, 2) version, err := prepareDB() if err != nil { - return nil, err + return "X", nil, err } db, err := sqlx.Connect(config.GlobalComputedConfig.DatabaseDriver, buildMysqlDSN(false)) if err == nil { @@ -223,6 +223,7 @@ func SetupDb() (func(), error) { if err == nil { if g.Version != version { log.Warnf("!! database version %s does not match prepared version %s", g.Version, version) + version = g.Version } setupAdCache() setupUserCache() @@ -232,11 +233,11 @@ func SetupDb() (func(), error) { setupConferenceCache() exitfns = append(exitfns, setupAuditWriter()) exitfns = append(exitfns, setupIPBanSweep()) - log.Infof("SetupDb(): database version %s", g.Version) + log.Infof("SetupDb(): database version %s", version) } } slices.Reverse(exitfns) - return func() { + return version, func() { for _, f := range exitfns { f() } @@ -262,6 +263,8 @@ func transaction(ctx context.Context) (*sqlx.Tx, func() error, func()) { err = tx.Commit() if err == nil { live = false + } else { + log.Errorf("***COMMIT ERROR*** %v", err) } } return err diff --git a/email/sender.go b/email/sender.go index 1db6798..63a6eb0 100644 --- a/email/sender.go +++ b/email/sender.go @@ -234,6 +234,9 @@ func SetupMailSender() func() { emailRenderer.AddGlobal("AmsterdamVersion", config.AMSTERDAM_VERSION) emailRenderer.AddGlobal("AmsterdamCopyright", config.AMSTERDAM_COPYRIGHT) emailRenderer.AddGlobal("GlobalConfig", config.GlobalConfig) + emailRenderer.AddGlobal("PLSCOPE_COMMUNITY", database.PLSCOPE_COMMUNITY) + emailRenderer.AddGlobal("PLSCOPE_CONFERENCE", database.PLSCOPE_CONFERENCE) + emailRenderer.AddGlobal("PLSCOPE_TOPIC", database.PLSCOPE_TOPIC) // Start the recycler. messageRecycleBin = make(chan *amMessage, config.GlobalConfig.Tuning.Queues.EmailRecycle) diff --git a/htmlcheck/dict_trie.go b/htmlcheck/dict_trie.go index 104792d..f9cfbfd 100644 --- a/htmlcheck/dict_trie.go +++ b/htmlcheck/dict_trie.go @@ -1,6 +1,6 @@ /* * Amsterdam Web Communities System - * Copyright (c) 2025 Erbosoft Metaverse Design Solutions, All Rights Reserved + * Copyright (c) 2025-2026 Erbosoft Metaverse Design Solutions, All Rights Reserved * * This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this @@ -44,17 +44,17 @@ func (d *TrieDictionary) Size() int { // CheckWord returns true if a word is in the dictionary, false if not. func (d *TrieDictionary) CheckWord(word string) bool { d.mutex.Lock() - defer d.mutex.Unlock() _, rc := d.trie.Find(strings.ToLower(word)) + d.mutex.Unlock() return rc } // AddWord adds a new word to the dictionary. func (d *TrieDictionary) AddWord(word string) { d.mutex.Lock() - defer d.mutex.Unlock() d.trie.Add(strings.ToLower(word), true) d.count++ + d.mutex.Unlock() } // DelWord deletes a word from the dictionary. @@ -89,12 +89,8 @@ func loadDict(d *TrieDictionary, words []byte) { // LoadTrieDict creates a TrieDictionary from a byte array that represents a word list (one word per line). func LoadTrieDict(words []byte) *TrieDictionary { - rc := TrieDictionary{ - loaded: atomic.Bool{}, - trie: trie.New(), - count: 0, - } + rc := new(TrieDictionary{loaded: atomic.Bool{}, trie: trie.New(), count: 0}) rc.loaded.Store(false) - go loadDict(&rc, words) - return &rc + go loadDict(rc, words) + return rc } diff --git a/htmlcheck/dictionary.go b/htmlcheck/dictionary.go index 8a96740..2954ddb 100644 --- a/htmlcheck/dictionary.go +++ b/htmlcheck/dictionary.go @@ -55,10 +55,8 @@ func SetupDicts() { log.Errorf("failed to load external dictionary %s: %v", config.GlobalConfig.Posting.ExternalDictionary, err) } } - rw := spellingRewriter{ - dict: NewCompositeDict(dicts), - } - rewriterRegistry[rw.Name()] = &rw + rw := new(spellingRewriter{dict: NewCompositeDict(dicts)}) + rewriterRegistry[rw.Name()] = rw } // spellingRewriter is a rewriter that flags spelling errors. @@ -89,10 +87,10 @@ func (rw *spellingRewriter) Rewrite(ctx context.Context, data string, svc rewrit if rw.dict.CheckWord(data) { return nil } - return &markupData{ + return new(markupData{ beginMarkup: defaultBeginError, text: data, endMarkup: defaultEndError, rescan: false, - } + }) } diff --git a/logging.go b/logging.go index c1f165a..ae8c058 100644 --- a/logging.go +++ b/logging.go @@ -31,6 +31,12 @@ import ( log "github.com/sirupsen/logrus" ) +// DEFAULT_MAXLOG is the default maximum log file size (16 megabytes). +const DEFAULT_MAXLOG = 16 * 1024 * 1024 + +// LOG_ROTATE_INTERVAL is the interval, in seconds, at which we try to rotate the logfile. +const LOG_ROTATE_INTERVAL = 10 + /*---------------------------------------------------------------------------- * slog handler that outputs to Logrus *---------------------------------------------------------------------------- @@ -52,11 +58,7 @@ type SlogLogrusHandler struct { // NewSlogLogrusHandler creates a SlogLogrusHandler with base information. func NewSlogLogrusHandler() *SlogLogrusHandler { - rc := new(SlogLogrusHandler{ - fields: make(log.Fields), - groupPrefix: "", - }) - return rc + return new(SlogLogrusHandler{fields: make(log.Fields), groupPrefix: ""}) } // Enabled returns true if the specified log level is handled. @@ -81,20 +83,18 @@ func (h *SlogLogrusHandler) Handle(ctx context.Context, r slog.Record) error { // WithAttrs creates a new Handler from this one, with extra attributes. func (h *SlogLogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - newh := new(SlogLogrusHandler{fields: make(log.Fields)}) + newh := new(SlogLogrusHandler{fields: make(log.Fields), groupPrefix: h.groupPrefix}) maps.Copy(newh.fields, h.fields) for _, a := range attrs { newh.fields[a.Key] = a.Value.Any() } - newh.groupPrefix = h.groupPrefix return newh } // WithGroup creates a new Handler from this one, with an extra group prefix. func (h *SlogLogrusHandler) WithGroup(name string) slog.Handler { - newh := new(SlogLogrusHandler{fields: make(log.Fields)}) + newh := new(SlogLogrusHandler{fields: make(log.Fields), groupPrefix: h.groupPrefix + name + "."}) maps.Copy(newh.fields, h.fields) - newh.groupPrefix = h.groupPrefix + name + "." return newh } @@ -160,6 +160,7 @@ func (lf *amLogFile) Close() error { } // rotate closes the log file and moves it to a new name, shuffling the previously stored log files by the same amount. +// N.B.: We must be holding lf.mutex. func (lf *amLogFile) rotate() error { if lf.keep == 0 && lf.keepCompressed == 0 { return nil // degenerate case, keep the log file the same @@ -262,7 +263,9 @@ func (lf *amLogFile) tryRotate() { if lf.curSize >= lf.maxSize { err := lf.rotate() if err != nil { - //log.Error("log rotation failed") + log.SetOutput(os.Stderr) + log.Errorf("log rotation failed: %v", err) + log.SetOutput(lf) } } lf.mutex.Unlock() @@ -302,8 +305,7 @@ func (lf *amLogFile) open(path string) error { // logScanner is a goroutine that monitors the log file to see when it needs rotating. func logScanner(ctx context.Context, lf *amLogFile, done chan bool) { - d, _ := time.ParseDuration("10s") - t := time.NewTicker(d) + t := time.NewTicker(LOG_ROTATE_INTERVAL * time.Second) for { select { case <-ctx.Done(): @@ -319,8 +321,10 @@ func logScanner(ctx context.Context, lf *amLogFile, done chan bool) { // SetupLogging sets up the log file based on the configuration data. func SetupLogging() func() { loglevel, err := log.ParseLevel(config.GlobalComputedConfig.LogLevel) - if err != nil { + if err == nil { loglevel = log.ErrorLevel + } else { + log.Errorf("default log level not valid: %s (%v)", config.GlobalComputedConfig.LogLevel, err) } if config.GlobalComputedConfig.DebugMode && loglevel != log.TraceLevel { loglevel = log.DebugLevel @@ -333,7 +337,8 @@ func SetupLogging() func() { amlog := new(amLogFile) maxlog, err := humanize.ParseBytes(config.GlobalConfig.Logging.MaxLogSize) if err != nil { - maxlog = 16 * 1024 * 1024 // default to 16 megabytes + log.Errorf("invalid value for max log size: %s (%v)", config.GlobalConfig.Logging.MaxLogSize, err) + maxlog = DEFAULT_MAXLOG } amlog.maxSize = int64(maxlog) amlog.keep = config.GlobalConfig.Logging.KeepLogFiles @@ -344,13 +349,14 @@ func SetupLogging() func() { ctx, cancelfunc = context.WithCancel(context.Background()) done = make(chan bool) go logScanner(ctx, amlog, done) + } else { + log.Errorf("**** failed to open amlog: %v - logs will go to stdout", err) } } if logfile == nil { log.SetOutput(os.Stdout) } else { log.SetOutput(logfile) - } log.SetLevel(loglevel) diff --git a/main.go b/main.go index 650cc9a..d036b55 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "log/slog" + "net" "net/http" "os" "os/signal" @@ -40,9 +41,23 @@ import ( // READ_HEADER_TIMEOUT is the timeout value for reading headers in seconds. (Deliberately NOT configurable because this is a security issue) const READ_HEADER_TIMEOUT = 2 +// GRACEFUL_SHUTDOWN_TIMEOUT is the timeout value for a graceful shutdown. +const GRACEFUL_SHUTDOWN_TIMEOUT = 10 * time.Second + // GetAndPost is used to have functions that respond to both GET and POST on a URI. var GetAndPost = []string{http.MethodGet, http.MethodPost} +// myIPAddress returns the IP address of this computer. +func myIPAddress() net.IP { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + panic(err) + } + defer conn.Close() + localAddr := conn.LocalAddr().(*net.UDPAddr) + return localAddr.IP +} + // setupEcho creates, configures, and returns a new Echo instance. func setupEcho() *echo.Echo { e := echo.New() @@ -225,11 +240,15 @@ var SystemStartTime time.Time // main is Ye Olde Main Function. func main() { SystemStartTime = time.Now() + + // Determine my IP address. + myIP := myIPAddress() + // Configure the system. config.SetupConfig() closer := SetupLogging() defer closer() - closer, err := database.SetupDb() + dbVersion, closer, err := database.SetupDb() if err != nil { panic(fmt.Sprintf("Database open failure: %v", err)) } @@ -240,12 +259,6 @@ func main() { closer = ui.SetupUILayer() defer closer() - // Determine my IP address and the admin user. - myIP, err := util.MyIPAddress() - if err != nil { - panic(err) - } - // Set up to trap SIGINT/SIGTERM and shut down gracefully ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -262,7 +275,7 @@ func main() { // Audit the startup database.AmStoreAudit(database.AmNewAudit(database.AuditStartup, 0, myIP.String(), - fmt.Sprintf("version=%s", config.AMSTERDAM_VERSION))) + fmt.Sprintf("version=%s", config.AMSTERDAM_VERSION), fmt.Sprintf("database=%s", dbVersion))) defer func() { // Audit the shutdown database.AmStoreAudit(database.AmNewAudit(database.AuditShutdown, 0, myIP.String())) @@ -273,7 +286,7 @@ func main() { Address: config.GlobalComputedConfig.Listen, HideBanner: true, HidePort: true, - GracefulTimeout: 10 * time.Second, + GracefulTimeout: GRACEFUL_SHUTDOWN_TIMEOUT, OnShutdownError: func(err error) { log.Fatalf("error in shutting down the server: %v", err) }, diff --git a/ui/amsession.go b/ui/amsession.go index 86c8a8b..d9954d0 100644 --- a/ui/amsession.go +++ b/ui/amsession.go @@ -18,7 +18,6 @@ import ( "net/http" "slices" "sync" - "sync/atomic" "time" "git.erbosoft.com/amy/amsterdam/config" @@ -33,6 +32,12 @@ import ( be timed out as well as used to show the logged-in users. This is similar to the session support provided in J2EE servlets. */ +// DEFAULT_SESSION_EXPIRE is the default time in which sessions will expire. +const DEFAULT_SESSION_EXPIRE = 1 * time.Hour + +// The interval at which all sessions will be swept. +const SESSION_STORE_SWEEP_INTERVAL = 2 * time.Minute + // AmSessionOptions gives the options for the session. type AmSessionOptions struct { Path string @@ -244,11 +249,10 @@ func (sess *amSession) Hit() { // amSessionStore is the implementation structure for AmSessionStore. type amSessionStore struct { - mutex sync.RWMutex - sessions map[string]*amSession - maxEntries int - expiry time.Duration - sweepRunning atomic.Bool + mutex sync.RWMutex + sessions map[string]*amSession + maxEntries int + expiry time.Duration } // createAmSessionStore creates the session store. @@ -258,7 +262,6 @@ func createAmSessionStore(exp time.Duration) *amSessionStore { maxEntries: 0, expiry: exp, } - rc.sweepRunning.Store(true) return rc } @@ -339,9 +342,15 @@ func (st *amSessionStore) SessionInfo() (int, []string, int) { * tick - Channel that "pulses" periodically to run the task. * done - Channel we write to when we're done. */ -func (st *amSessionStore) sweep(tick <-chan time.Time, done chan bool) { - for range tick { - if st.sweepRunning.Load() { +func (st *amSessionStore) sweep(ctx context.Context, done chan bool) { + tkr := time.NewTicker(SESSION_STORE_SWEEP_INTERVAL) + for { + select { + case <-ctx.Done(): + tkr.Stop() + done <- true + return + case <-tkr.C: // phase 1 - identify expired sessions st.mutex.RLock() zap := make([]string, 0, len(st.sessions)) @@ -366,11 +375,8 @@ func (st *amSessionStore) sweep(tick <-chan time.Time, done chan bool) { } st.mutex.Unlock() } - } else { - break } } - done <- true } // sessionStore is the global session store. @@ -381,30 +387,21 @@ func setupSessionManager() func() { // get the time for the session to expire d, err := time.ParseDuration(config.GlobalConfig.Site.SessionExpire) if err != nil { - d, err = time.ParseDuration("1h") - if err != nil { - panic(err.Error()) - } + log.Errorf("invalid session timeout value: %s", config.GlobalConfig.Site.SessionExpire) + d = DEFAULT_SESSION_EXPIRE } // create session store sessionStore = createAmSessionStore(d) - // get the clock value to run sweeps - d, err = time.ParseDuration("1s") - if err != nil { - panic(err.Error()) - } - // set up the sweep runner - tkr := time.NewTicker(d) + ctx, cancel := context.WithCancel(context.Background()) done := make(chan bool) - go sessionStore.sweep(tkr.C, done) + go sessionStore.sweep(ctx, done) return func() { // stop the sweep runner - sessionStore.sweepRunning.Store(false) + cancel() <-done - tkr.Stop() } } diff --git a/ui/dialog.go b/ui/dialog.go index a4615a4..e060f3f 100644 --- a/ui/dialog.go +++ b/ui/dialog.go @@ -14,6 +14,7 @@ package ui import ( "embed" + "errors" "fmt" "io" "io/fs" @@ -120,8 +121,7 @@ func AmLoadDialog(name string) (*Dialog, error) { f, err = extDialogs.Open(fmt.Sprintf("%s.yaml", name)) if err != nil { f = nil - pe := err.(*fs.PathError) - if pe.Err == os.ErrInvalid || pe.Err == os.ErrNotExist { + if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrNotExist) { err = nil } } diff --git a/util/util.go b/util/util.go index e0e9d3c..0228de3 100644 --- a/util/util.go +++ b/util/util.go @@ -13,7 +13,6 @@ package util import ( - "net" "regexp" "strings" "time" @@ -172,17 +171,6 @@ func Map[A, B any](in []A, fn func(A) B) []B { return rc } -// MyIPAddress returns the local IP address of this machine. -func MyIPAddress() (net.IP, error) { - conn, err := net.Dial("udp", "8.8.8.8:80") - if err != nil { - return nil, err - } - defer conn.Close() - localAddr := conn.LocalAddr().(*net.UDPAddr) - return localAddr.IP, nil -} - // IIF is an "immediate-if" function returning its second argument if the first one is true, the third one if not. func IIF[A any](expr bool, v1, v2 A) A { if expr {