diff --git a/database/audit.go b/database/audit.go index c5a3cc3..92b49bd 100644 --- a/database/audit.go +++ b/database/audit.go @@ -18,16 +18,16 @@ import ( // AuditRecord holds an audit record instance. type AuditRecord struct { - Record int64 `db:"record"` - OnDate time.Time `db:"on_date"` - Event int32 `db:"event"` - Uid int32 `db:"uid"` - CommId int32 `db:"commid"` - IP *string `db:"ip"` - Data1 *string `db:"data1"` - Data2 *string `db:"data2"` - Data3 *string `db:"data3"` - Data4 *string `db:"data4"` + Record int64 `db:"record"` // audit record ID + OnDate time.Time `db:"on_date"` // timestamp of event + Event int32 `db:"event"` // ID of the event + Uid int32 `db:"uid"` // user that performed the event + CommId int32 `db:"commid"` // community associated with the event + IP *string `db:"ip"` // IP address associated with the event + Data1 *string `db:"data1"` // first data parameter + Data2 *string `db:"data2"` // second data parameter + Data3 *string `db:"data3"` // third data parameter + Data4 *string `db:"data4"` // fourth data parameter } // These are the audit record types. @@ -134,7 +134,7 @@ func auditWriter(workChan chan *AuditRecord, doneChan chan bool) { for ar := range workChan { err := ar.Store() if err != nil { - log.Errorf("dropped audit record on the floor: %v", err) + log.Errorf("dropped audit record (%+v) on the floor: %v", *ar, err) } } doneChan <- true diff --git a/database/community.go b/database/community.go index 78764ee..086999b 100644 --- a/database/community.go +++ b/database/community.go @@ -21,43 +21,44 @@ import ( "git.erbosoft.com/amy/amsterdam/util" lru "github.com/hashicorp/golang-lru" + "github.com/jmoiron/sqlx" "golang.org/x/text/language" ) // Community struct contains the high level data for a community. type Community struct { Mutex sync.RWMutex - Id int32 `db:"commid"` - CreateDate time.Time `db:"createdate"` - LastAccess *time.Time `db:"lastaccess"` - LastUpdate *time.Time `db:"lastupdate"` - ReadLevel uint16 `db:"read_lvl"` - WriteLevel uint16 `db:"write_lvl"` - CreateLevel uint16 `db:"create_lvl"` - DeleteLevel uint16 `db:"delete_lvl"` - JoinLevel uint16 `db:"join_lvl"` - ContactId int32 `db:"contactid"` - HostUid *int32 `db:"host_uid"` - CategoryId int32 `db:"catid"` - HideFromDirectory bool `db:"hide_dir"` - HideFromSearch bool `db:"hide_search"` - MembersOnly bool `db:"membersonly"` - IsAdmin bool `db:"is_admin"` - InitFeature int16 `db:"init_ftr"` - Name string `db:"commname"` - Language *string `db:"language"` - Synopsis *string `db:"synopsis"` - Rules *string `dd:"rules"` - JoinKey *string `db:"joinkey"` - Alias string `db:"alias"` + Id int32 `db:"commid"` // ID of the community + CreateDate time.Time `db:"createdate"` // timestamp for community creation + LastAccess *time.Time `db:"lastaccess"` // timestamp for last access + LastUpdate *time.Time `db:"lastupdate"` // timestamp for last update + ReadLevel uint16 `db:"read_lvl"` // level required to read + WriteLevel uint16 `db:"write_lvl"` // level required to write (change community attributes) + CreateLevel uint16 `db:"create_lvl"` // level required to create subobjects + DeleteLevel uint16 `db:"delete_lvl"` // level required to delete community + JoinLevel uint16 `db:"join_lvl"` // level required to join + ContactId int32 `db:"contactid"` // community's information as a contact info record ID + HostUid *int32 `db:"host_uid"` // UID of the host + CategoryId int32 `db:"catid"` // category ID for community + HideFromDirectory bool `db:"hide_dir"` // if set, community is hidden from the directory + HideFromSearch bool `db:"hide_search"` // if set, this community is hidden from search + MembersOnly bool `db:"membersonly"` // is this community open to members only? + IsAdmin bool `db:"is_admin"` // set if this is the admin community + InitFeature int16 `db:"init_ftr"` // initial feature? + Name string `db:"commname"` // community name + Language *string `db:"language"` // primary language of community, ISO format + Synopsis *string `db:"synopsis"` // community synopsis + Rules *string `dd:"rules"` // rules (kinda short) + JoinKey *string `db:"joinkey"` // join key (password) to join community + Alias string `db:"alias"` // community alias flags *util.OptionSet } // CommunityProperties represents a property entry for a community. type CommunityProperties struct { - Cid int32 `db:"cid"` - Index int32 `db:"ndx"` - Data *string `db:"data"` + Cid int32 `db:"cid"` // community ID + Index int32 `db:"ndx"` // property index + Data *string `db:"data"` // property value } // Community property indexes defined. @@ -225,7 +226,7 @@ func (c *Community) Membership(u *User) (bool, bool, uint16, error) { return false, false, uint16(0), err } -// MemberCount returns the number of members in the community, quietly. +// MemberCount returns the number of members in the community. func (c *Community) MemberCount(hidden bool) (int, error) { var rs *sql.Rows var err error @@ -346,21 +347,28 @@ func (c *Community) ListMembers(field int, oper int, term string, offset int, ma * Standard Go error status. */ func (c *Community) SetMembership(u *User, level uint16, locked bool, personUID int32, ipaddr string) error { + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() if level == 0 { - res, err := amdb.Exec("DELETE FROM commmember WHERE commid = ? AND uid = ?", c.Id, u.Uid) + res, err := tx.Exec("DELETE FROM commmember WHERE commid = ? AND uid = ?", c.Id, u.Uid) if err != nil { return err } stuffMembership(c.Id, u.Uid, false, false, 0) ra, err := res.RowsAffected() if err == nil && ra > 0 { - err = AmOnUserLeaveCommunityServices(c, u) + err = AmOnUserLeaveCommunityServices(tx, c, u) if err != nil { return err } } } else { - rs, err := amdb.Query("SELECT granted_lvl, locked FROM commmember WHERE commid = ? AND uid = ?", c.Id, u.Uid) + rs, err := tx.Query("SELECT granted_lvl, locked FROM commmember WHERE commid = ? AND uid = ?", c.Id, u.Uid) if err != nil { return err } @@ -369,7 +377,7 @@ func (c *Community) SetMembership(u *User, level uint16, locked bool, personUID var lockStatus bool rs.Scan(&oldLevel, &lockStatus) if level != oldLevel || lockStatus != locked { - _, err := amdb.Exec("UPDATE commmember SET granted_lvl = ?, locked = ? WHERE commid = ? AND uid = ?", + _, err := tx.Exec("UPDATE commmember SET granted_lvl = ?, locked = ? WHERE commid = ? AND uid = ?", level, locked, c.Id, u.Uid) if err != nil { return err @@ -377,19 +385,19 @@ func (c *Community) SetMembership(u *User, level uint16, locked bool, personUID stuffMembership(c.Id, u.Uid, true, locked, level) } } else { - _, err := amdb.Exec("INSERT INTO commmember (commid, uid, granted_lvl, locked) VALUES (?, ?, ?, ?)", + _, err := tx.Exec("INSERT INTO commmember (commid, uid, granted_lvl, locked) VALUES (?, ?, ?, ?)", c.Id, u.Uid, level, locked) if err != nil { return err } stuffMembership(c.Id, u.Uid, true, locked, level) - err = AmOnUserJoinCommunityServices(c, u) + err = AmOnUserJoinCommunityServices(tx, c, u) if err != nil { return err } } } - err := c.TouchUpdate() + err := c.TouchUpdateTx(tx) if err == nil { ar := AmNewAudit(AuditCommunitySetMembership, personUID, ipaddr, fmt.Sprintf("cid=%d", c.Id), fmt.Sprintf("uid=%d", u.Uid), fmt.Sprintf("level=%d", level)) @@ -534,13 +542,13 @@ func (c *Community) Touch() error { return err } -// TouchUpdate updates the last access and last update times of the community. -func (c *Community) TouchUpdate() error { +// TouchUpdateTx updates the last access and last update times of the community. +func (c *Community) TouchUpdateTx(tx *sqlx.Tx) error { c.Mutex.Lock() defer c.Mutex.Unlock() - _, err := amdb.Exec("UPDATE communities SET lastaccess = NOW(), lastupdate = NOW() WHERE commid = ?", c.Id) + _, err := tx.Exec("UPDATE communities SET lastaccess = NOW(), lastupdate = NOW() WHERE commid = ?", c.Id) if err == nil { - rs, err := amdb.Query("SELECT lastaccess, lastupdate FROM communities WHERE commid = ?", c.Id) + rs, err := tx.Query("SELECT lastaccess, lastupdate FROM communities WHERE commid = ?", c.Id) if err != nil { rs.Next() var na, nu time.Time @@ -552,6 +560,19 @@ func (c *Community) TouchUpdate() error { return err } +// TouchUpdateTx updates the last access and last update times of the community. +func (c *Community) TouchUpdate() error { + tx := amdb.MustBegin() + err := c.TouchUpdateTx(tx) + if err != nil { + err = tx.Commit() + } + if err != nil { + tx.Rollback() + } + return err +} + /* AmGetCommunity returns a reference to the specified community. * Parameters: * id - The ID of the community. @@ -580,6 +601,35 @@ func AmGetCommunity(id int32) (*Community, error) { return rc.(*Community), nil } +/* AmGetCommunityTx returns a reference to the specified community, in a transaction. + * Parameters: + * tx - The transaction to use. + * id - The ID of the community. + * Returns: + * Pointer to Community containing community data, or nil + * Standard Go error status + */ +func AmGetCommunityTx(tx *sqlx.Tx, id int32) (*Community, error) { + getCommunityMutex.Lock() + defer getCommunityMutex.Unlock() + rc, ok := communityCache.Get(id) + if !ok { + var dbdata []Community + err := tx.Select(&dbdata, "SELECT * from communities WHERE commid = ?", id) + if err != nil { + return nil, err + } + if len(dbdata) == 0 { + return nil, fmt.Errorf("community with ID %d not found", id) + } else if len(dbdata) > 1 { + return nil, fmt.Errorf("AmGetCommunity(%d): too many responses(%d)", id, len(dbdata)) + } + rc = &(dbdata[0]) + communityCache.Add(id, rc) + } + return rc.(*Community), nil +} + /* AmGetCommunityByAlias returns a reference to the specified community. * Parameters: * alias - The alias for the community. @@ -601,6 +651,28 @@ func AmGetCommunityByAlias(alias string) (*Community, error) { return nil, err } +/* AmGetCommunityByAliasTx returns a reference to the specified community, within a transaction. + * Parameters: + * tx - The transaction to use. + * alias - The alias for the community. + * Returns: + * Pointer to Community containing community data, or nil + * Standard Go error status (nil if community not found) + */ +func AmGetCommunityByAliasTx(tx *sqlx.Tx, alias string) (*Community, error) { + rs, err := tx.Query("SELECT commid FROM communities WHERE alias = ?", alias) + if err == nil { + if rs.Next() { + var cid int32 + rs.Scan(&cid) + return AmGetCommunityTx(tx, cid) + } else { + return nil, nil + } + } + return nil, err +} + /* AmGetCommunityFromParam returns a reference to the specified community based on the parameter. * If the parameter is numeric, it's interpreted as a community ID. Otherwise, it's interpreted * as a community alias. @@ -676,20 +748,21 @@ func AmGetCommunityAccessLevel(uid int32, commid int32) (uint16, error) { /* AmAutoJoinCommunities joins the specified user to any communities they're not yet a part of. * Parameters: + * tx - The current transaction to be used for database access. * user - The user to be auto-joined to communities. * Returns: * Standard Go error status. */ -func AmAutoJoinCommunities(user *User) error { +func AmAutoJoinCommunities(tx *sqlx.Tx, user *User) error { // get list of current communities var current []int32 = make([]int32, 0) - err := amdb.Select(¤t, "SELECT commid FROM commmember WHERE uid = ?", user.Uid) + err := tx.Select(¤t, "SELECT commid FROM commmember WHERE uid = ?", user.Uid) if err != nil { return err } // look for candidate communities - rows, err := amdb.Queryx(`SELECT m.commid, m.locked FROM users u, communities c, commmember m + rows, err := tx.Queryx(`SELECT m.commid, m.locked FROM users u, communities c, commmember m WHERE m.uid = u.uid AND m.commid = c.commid AND u.is_anon = 1 AND c.join_lvl <= ?`, user.BaseLevel) if err == nil { defer rows.Close() @@ -699,7 +772,7 @@ func AmAutoJoinCommunities(user *User) error { var lock bool rows.Scan(&cid, &lock) if !slices.Contains(current, cid) { - _, err = amdb.Exec("INSERT INTO commmember (commid, uid, granted_lvl, locked) VALUES (?, ?, ?, ?)", + _, err = tx.Exec("INSERT INTO commmember (commid, uid, granted_lvl, locked) VALUES (?, ?, ?, ?)", cid, user.Uid, grantLevel, lock) if err != nil { break @@ -806,9 +879,16 @@ func AmCreateCommunity(name string, alias string, hostUid int32, language *strin defer func() { AmStoreAudit(ar) }() + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() // validate alias does not already exist - rs, err := amdb.Query("SELECT commid FROM communities WHERE alias = ?", alias) + rs, err := tx.Query("SELECT commid FROM communities WHERE alias = ?", alias) if err != nil { return nil, err } @@ -817,7 +897,7 @@ func AmCreateCommunity(name string, alias string, hostUid int32, language *strin } // establish the community record - _, err = amdb.Exec(`INSERT INTO communities (createdate, lastaccess, lastupdate, read_lvl, write_lvl, + _, err = tx.Exec(`INSERT INTO communities (createdate, lastaccess, lastupdate, read_lvl, write_lvl, create_lvl, delete_lvl, join_lvl, host_uid, hide_dir, hide_search, commname, language, synopsis, rules, joinkey, alias) VALUES (NOW(), NOW(), NOW(), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, AmRoleList("Community.Read").Default().Level(), AmRoleList("Community.Write").Default().Level(), @@ -829,7 +909,7 @@ func AmCreateCommunity(name string, alias string, hostUid int32, language *strin } // Read back the community, which also puts it in the cache. - comm, err := AmGetCommunityByAlias(alias) + comm, err := AmGetCommunityByAliasTx(tx, alias) if err != nil { return nil, err } else if comm == nil { @@ -838,7 +918,7 @@ func AmCreateCommunity(name string, alias string, hostUid int32, language *strin // Ensure the new host has host privileges in the community. The host's membership is "locked" so they // can't unjoin and leave the community hostless. - _, err = amdb.Exec("INSERT INTO commmember (commid, uid, granted_lvl, locked) VALUES (?, ?, ?, 1)", comm.Id, hostUid, + _, err = tx.Exec("INSERT INTO commmember (commid, uid, granted_lvl, locked) VALUES (?, ?, ?, 1)", comm.Id, hostUid, AmDefaultRole("Community.Creator").Level()) if err != nil { return nil, err @@ -846,11 +926,17 @@ func AmCreateCommunity(name string, alias string, hostUid int32, language *strin stuffMembership(comm.Id, hostUid, true, true, AmDefaultRole("Community.Creator").Level()) // Establish the community services. - err = AmEstablishCommunityServices(comm) + err = AmEstablishCommunityServices(tx, comm) if err != nil { return nil, err } + err = tx.Commit() + if err != nil { + return nil, err + } + success = true + // operation was a success - add an audit record ar = AmNewAudit(AuditCommunityCreate, hostUid, remoteIP, fmt.Sprintf("id=%d", comm.Id), fmt.Sprintf("name=%s", comm.Name), fmt.Sprintf("alias=%s", comm.Alias)) diff --git a/database/conference.go b/database/conference.go index 52a87c0..8de4521 100644 --- a/database/conference.go +++ b/database/conference.go @@ -16,6 +16,7 @@ import ( "time" lru "github.com/hashicorp/golang-lru" + "github.com/jmoiron/sqlx" ) // Conference struct is the top-level structure for a conference. @@ -170,7 +171,7 @@ func (c *Conference) Settings(u *User) (*ConferenceSettings, error) { } // TouchRead updates the "last posted" date/time in the conference for the user. -func (c *Conference) TouchRead(u *User) (*ConferenceSettings, error) { +func (c *Conference) TouchRead(tx *sqlx.Tx, u *User) (*ConferenceSettings, error) { cs, err := c.Settings(u) if err != nil { return nil, err @@ -181,10 +182,10 @@ func (c *Conference) TouchRead(u *User) (*ConferenceSettings, error) { if cerr != nil { return nil, cerr } - amdb.Exec("INSERT INTO confsettings (confid, uid, default_pseud, last_read) VALUES (?, ?, ?, NOW())", + _, err = tx.Exec("INSERT INTO confsettings (confid, uid, default_pseud, last_read) VALUES (?, ?, ?, NOW())", c.ConfId, u.Uid, ci.FullName(false)) } else { - _, err = amdb.Exec("UPDATE confsettings SET last_read = NOW() WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) + _, err = tx.Exec("UPDATE confsettings SET last_read = NOW() WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) } if err == nil { cs, err = c.Settings(u) // reread to get updated or inserted values @@ -197,7 +198,7 @@ func (c *Conference) TouchRead(u *User) (*ConferenceSettings, error) { } // TouchPost updates the "last posted" date/time in the conference for the user. -func (c *Conference) TouchPost(u *User, lastPost time.Time) (*ConferenceSettings, error) { +func (c *Conference) TouchPost(tx *sqlx.Tx, u *User, lastPost time.Time) (*ConferenceSettings, error) { cs, err := c.Settings(u) if err != nil { return nil, err @@ -216,10 +217,10 @@ func (c *Conference) TouchPost(u *User, lastPost time.Time) (*ConferenceSettings LastRead: &lastPost, LastPost: &lastPost, } - _, err = amdb.Exec("INSERT INTO confsettings (confid, uid, default_pseud, last_read, last_post) VALUES (?, ?, ?, ?, ?)", + _, err = tx.Exec("INSERT INTO confsettings (confid, uid, default_pseud, last_read, last_post) VALUES (?, ?, ?, ?, ?)", c.ConfId, u.Uid, defaultPseud, lastPost, lastPost) } else { - _, err = amdb.Exec("UPDATE confsettings SET last_post = ? WHERE confid = ? AND uid = ?", lastPost, c.ConfId, u.Uid) + _, err = tx.Exec("UPDATE confsettings SET last_post = ? WHERE confid = ? AND uid = ?", lastPost, c.ConfId, u.Uid) cs.LastPost = &lastPost } if err != nil { diff --git a/database/services.go b/database/services.go index 301c06e..2b48f7b 100644 --- a/database/services.go +++ b/database/services.go @@ -15,33 +15,34 @@ import ( "sync" lru "github.com/hashicorp/golang-lru" + "github.com/jmoiron/sqlx" "gopkg.in/yaml.v3" ) // ServiceVTable is a serioes of functions called for services on specific events. type ServiceVTable interface { - OnNewCommunity(*Community) error - OnDeleteCommunity(int32) error - OnUserJoinCommunity(*Community, *User) error - OnUserLeaveCommunity(*Community, *User) error + OnNewCommunity(*sqlx.Tx, *Community) error + OnDeleteCommunity(*sqlx.Tx, int32) error + OnUserJoinCommunity(*sqlx.Tx, *Community, *User) error + OnUserLeaveCommunity(*sqlx.Tx, *Community, *User) error } // emptyServiceVTable is a default ServiceVTable that does nothing. type emptyServiceVTable struct{} -func (*emptyServiceVTable) OnNewCommunity(*Community) error { +func (*emptyServiceVTable) OnNewCommunity(*sqlx.Tx, *Community) error { return nil } -func (*emptyServiceVTable) OnDeleteCommunity(int32) error { +func (*emptyServiceVTable) OnDeleteCommunity(*sqlx.Tx, int32) error { return nil } -func (*emptyServiceVTable) OnUserJoinCommunity(*Community, *User) error { +func (*emptyServiceVTable) OnUserJoinCommunity(*sqlx.Tx, *Community, *User) error { return nil } -func (*emptyServiceVTable) OnUserLeaveCommunity(*Community, *User) error { +func (*emptyServiceVTable) OnUserLeaveCommunity(*sqlx.Tx, *Community, *User) error { return nil } @@ -150,19 +151,50 @@ func AmGetCommunityServices(cid int32) ([]*ServiceDef, error) { return rc.([]*ServiceDef), nil } +/* AmGetCommunityServices returns all the community service definitions for a community, using a transaction. + * Parameters: + * tx - Transaction to be used. + * cid - Community ID to get services for. + * Returns: + * Array of ServiceDef pointers for the community's services. + * Standard Go error status. + */ +func AmGetCommunityServicesTx(tx *sqlx.Tx, cid int32) ([]*ServiceDef, error) { + servicesCacheMutex.Lock() + defer servicesCacheMutex.Unlock() + rc, ok := servicesCache.Get(cid) + if !ok { + rs, err := tx.Query("SELECT ftr_code FROM commftrs WHERE commid = ?", cid) + if err != nil { + return nil, err + } + dom := serviceRoot.byName["community"] + a := make([]*ServiceDef, 0, len(dom.Services)) + for rs.Next() { + var ndx int16 + rs.Scan(&ndx) + a = append(a, dom.byIndex[ndx]) + } + servicesCache.Add(cid, a) + rc = a + } + return rc.([]*ServiceDef), nil +} + /* AmEstablishCommunityServices establishes the service (feature) records for a new community, * and allows the services to establish themselves. * Parameters: + * tx - The transaction to use. * c - The new community. * Returns: * Standard Go error status. */ -func AmEstablishCommunityServices(c *Community) error { +func AmEstablishCommunityServices(tx *sqlx.Tx, c *Community) error { dom := serviceRoot.byName["community"] a := make([]*ServiceDef, 0, len(dom.Services)) for i, svc := range dom.Services { if svc.Default { - _, err := amdb.Exec("INSERT INTO commftrs (commid, ftr_code) VALUES (?, ?)", c.Id, svc.Index) + _, err := tx.Exec("INSERT INTO commftrs (commid, ftr_code) VALUES (?, ?)", c.Id, svc.Index) if err != nil { return err } @@ -173,7 +205,7 @@ func AmEstablishCommunityServices(c *Community) error { servicesCache.Add(c.Id, a) servicesCacheMutex.Unlock() for _, svc := range a { - err := svc.vtable.OnNewCommunity(c) + err := svc.vtable.OnNewCommunity(tx, c) if err != nil { return err } @@ -184,22 +216,23 @@ func AmEstablishCommunityServices(c *Community) error { /* AmDeleteCommunityServices cleans up all services associated with a community that has gone away, * and then cleans up the service records. * Parameters: + * tx - The transaction to use. * cid - The ID of the departing community. * Returns: * Standard Go error status. */ -func AmDeleteCommunityServices(cid int32) error { +func AmDeleteCommunityServices(tx *sqlx.Tx, cid int32) error { arr, err := AmGetCommunityServices(cid) if err == nil { for _, svc := range arr { - err = svc.vtable.OnDeleteCommunity(cid) + err = svc.vtable.OnDeleteCommunity(tx, cid) if err != nil { break } } } if err == nil { - _, err = amdb.Exec("DELETE FROM commftrs WHERE commid = ?", cid) + _, err = tx.Exec("DELETE FROM commftrs WHERE commid = ?", cid) servicesCacheMutex.Lock() servicesCache.Remove(cid) servicesCacheMutex.Unlock() @@ -209,16 +242,17 @@ func AmDeleteCommunityServices(cid int32) error { /* AmOnUserJoinCommunityServices gives services a chance to update themselves when a user joins a community. * Parameters: + * tx - The current database transaction. * c - The community that is being joined. * u - The user leaving that community. * Returns: * Standard Go error status. */ -func AmOnUserJoinCommunityServices(c *Community, u *User) error { - arr, err := AmGetCommunityServices(c.Id) +func AmOnUserJoinCommunityServices(tx *sqlx.Tx, c *Community, u *User) error { + arr, err := AmGetCommunityServicesTx(tx, c.Id) if err == nil { for _, svc := range arr { - err = svc.vtable.OnUserJoinCommunity(c, u) + err = svc.vtable.OnUserJoinCommunity(tx, c, u) if err != nil { break } @@ -229,16 +263,17 @@ func AmOnUserJoinCommunityServices(c *Community, u *User) error { /* AmOnUserLeaveCommunityServices gives services a chance to update themselves when a user leaves a community. * Parameters: + * tx - The current database transaction. * c - The community that is being left. * u - The user leaving that community. * Returns: * Standard Go error status. */ -func AmOnUserLeaveCommunityServices(c *Community, u *User) error { - arr, err := AmGetCommunityServices(c.Id) +func AmOnUserLeaveCommunityServices(tx *sqlx.Tx, c *Community, u *User) error { + arr, err := AmGetCommunityServicesTx(tx, c.Id) if err == nil { for _, svc := range arr { - err = svc.vtable.OnUserLeaveCommunity(c, u) + err = svc.vtable.OnUserLeaveCommunity(tx, c, u) if err != nil { break } diff --git a/database/sidebox.go b/database/sidebox.go index 18463d1..b058519 100644 --- a/database/sidebox.go +++ b/database/sidebox.go @@ -9,6 +9,8 @@ // The database package contains database management and storage logic. package database +import "github.com/jmoiron/sqlx" + type Sidebox struct { Uid int32 `db:"uid"` Boxid int32 `db:"boxid"` @@ -17,12 +19,12 @@ type Sidebox struct { } // copySideboxes copies sideboxes from one user to another. -func copySideboxes(toUid int32, fromUid int32) error { +func copySideboxes(tx *sqlx.Tx, toUid int32, fromUid int32) error { sbox := make([]Sidebox, 0, 3) - err := amdb.Select(&sbox, "SELECT * from sideboxes WHERE uid = ?", fromUid) + err := tx.Select(&sbox, "SELECT * from sideboxes WHERE uid = ?", fromUid) if err == nil { for _, sb := range sbox { - _, err := amdb.Exec("INSERT INTO sideboxes (uid, boxid, sequence, param) VALUES (?, ?, ?, ?)", toUid, sb.Boxid, sb.Sequence, sb.Param) + _, err := tx.Exec("INSERT INTO sideboxes (uid, boxid, sequence, param) VALUES (?, ?, ?, ?)", toUid, sb.Boxid, sb.Sequence, sb.Param) if err != nil { break } diff --git a/database/topic.go b/database/topic.go index ce82d36..77d10fc 100644 --- a/database/topic.go +++ b/database/topic.go @@ -14,6 +14,8 @@ import ( "fmt" "strings" "time" + + "github.com/jmoiron/sqlx" ) // Topic is the top-level structure detailing topics. @@ -97,6 +99,29 @@ func AmGetTopic(topicId int32) (*Topic, error) { return &(dbdata[0]), nil } +/* AmGetTopic retrieves a topic by ID, in a transaction. + * Parameters: + * tx - The transaction to use. + * topicId - ID of the topic to retrieve. + * Returns: + * The topic pointer, or nil. + * Standard Go error status. + */ +func AmGetTopicTx(tx *sqlx.Tx, topicId int32) (*Topic, error) { + var dbdata []Topic + err := tx.Select(&dbdata, "SELECT * FROM topics WHERE topicid = ?", topicId) + if err != nil { + return nil, err + } + if len(dbdata) == 0 { + return nil, fmt.Errorf("topic %d not found", topicId) + } + if len(dbdata) > 1 { + return nil, fmt.Errorf("AmGetTopic(%d): too many responses (%d)", topicId, len(dbdata)) + } + return &(dbdata[0]), nil +} + // View and sort constants for AmListTopics. const ( TopicViewAll = 0 // list all topics @@ -243,18 +268,25 @@ func AmNewTopic(conf *Conference, user *User, title string, zeroPostPseud string defer func() { AmStoreAudit(ar) }() + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() unlock := true - amdb.Exec("LOCK TABLES confs WRITE, topics WRITE, topicsettings WRITE, posts WRITE, postdata WRITE;") + tx.Exec("LOCK TABLES confs WRITE, topics WRITE, topicsettings WRITE, posts WRITE, postdata WRITE;") defer func() { if unlock { - amdb.Exec("UNLOCK TABLES;") + tx.Exec("UNLOCK TABLES;") } }() // Insert the new topic into the database. conf.Mutex.Lock() - rs, err := amdb.Exec("INSERT INTO topics (confid, num, creator_uid, createdate, lastupdate, name) VALUES (?, ?, ?, NOW(), NOW(), ?)", + rs, err := tx.Exec("INSERT INTO topics (confid, num, creator_uid, createdate, lastupdate, name) VALUES (?, ?, ?, NOW(), NOW(), ?)", conf.ConfId, conf.TopTopic+1, user.Uid, title) if err != nil { conf.Mutex.Unlock() @@ -267,14 +299,14 @@ func AmNewTopic(conf *Conference, user *User, title string, zeroPostPseud string return nil, err } // Get the topic. - topic, err := AmGetTopic(int32(xid)) + topic, err := AmGetTopicTx(tx, int32(xid)) if err != nil { conf.Mutex.Unlock() return nil, err } // Update the conference to set the last update and top topic. - _, err = amdb.Exec("UPDATE confs SET lastupdate = ?, top_topic = ? WHERE confid = ?", topic.CreateDate, conf.TopTopic+1, conf.ConfId) + _, err = tx.Exec("UPDATE confs SET lastupdate = ?, top_topic = ? WHERE confid = ?", topic.CreateDate, conf.TopTopic+1, conf.ConfId) if err != nil { conf.Mutex.Unlock() return nil, err @@ -284,7 +316,7 @@ func AmNewTopic(conf *Conference, user *User, title string, zeroPostPseud string conf.Mutex.Unlock() // Add the "header record" for the first post. - rs, err = amdb.Exec("INSERT INTO posts (topicid, num, linecount, creator_uid, posted, pseud) VALUES (?, 0, ?, ?, ?, ?)", + rs, err = tx.Exec("INSERT INTO posts (topicid, num, linecount, creator_uid, posted, pseud) VALUES (?, 0, ?, ?, ?, ?)", topic.TopicId, zeroPostLines, user.Uid, topic.CreateDate, zeroPostPseud) if err != nil { return nil, err @@ -294,27 +326,33 @@ func AmNewTopic(conf *Conference, user *User, title string, zeroPostPseud string return nil, err } // Add the post data. - _, err = amdb.Exec("INSERT INTO postdata (postid, data) VALUES (?, ?)", int32(xid), zeroPost) + _, err = tx.Exec("INSERT INTO postdata (postid, data) VALUES (?, ?)", int32(xid), zeroPost) if err != nil { return nil, err } // Add a new topic settings record for the user, too. - _, err = amdb.Exec("INSERT INTO topicsettings (topicid, uid, last_post) VALUES (?, ?, ?)", + _, err = tx.Exec("INSERT INTO topicsettings (topicid, uid, last_post) VALUES (?, ?, ?)", topic.TopicId, user.Uid, topic.CreateDate) if err != nil { return nil, err } - amdb.Exec("UNLOCK TABLES;") + tx.Exec("UNLOCK TABLES;") unlock = false // update the "last posted" date in the conference settings - _, err = conf.TouchPost(user, topic.CreateDate) + _, err = conf.TouchPost(tx, user, topic.CreateDate) if err != nil { return nil, err } + err = tx.Commit() + if err != nil { + return nil, err + } + success = true + // create audit record ar = AmNewAudit(AuditConferenceCreateTopic, user.Uid, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId), fmt.Sprintf("num=%d", topic.Number), fmt.Sprintf("name=%s", topic.Name)) diff --git a/database/user.go b/database/user.go index c86163a..37e62a3 100644 --- a/database/user.go +++ b/database/user.go @@ -22,6 +22,7 @@ import ( "git.erbosoft.com/amy/amsterdam/util" lru "github.com/hashicorp/golang-lru" + "github.com/jmoiron/sqlx" "github.com/klauspost/lctime" log "github.com/sirupsen/logrus" "golang.org/x/text/language" @@ -37,12 +38,12 @@ type UserPrefs struct { // ReadLocale reads the locale out of the prefs, adjusting for Go use. func (p *UserPrefs) ReadLocale() string { - return strings.Replace(p.LocaleID, "_", "-", -1) + return strings.ReplaceAll(p.LocaleID, "_", "-") } // WriteLocale writes the locale into the prefs, adjusting for backward compatibility. func (p *UserPrefs) WriteLocale(loc string) { - p.LocaleID = strings.Replace(loc, "-", "_", -1) + p.LocaleID = strings.ReplaceAll(loc, "-", "_") } // Clone duplicates the user preferences. @@ -122,9 +123,9 @@ type User struct { // UserProperties represents a property entry for a user. type UserProperties struct { - Uid int32 `db:"uid"` - Index int32 `db:"ndx"` - Data *string `db:"data"` + Uid int32 `db:"uid"` // UID of user + Index int32 `db:"ndx"` // index of property + Data *string `db:"data"` // property data } // User property indexes defined. @@ -239,6 +240,13 @@ func (u *User) ConfirmEMailAddress(confnum int32, remoteIP string) error { defer func() { AmStoreAudit(ar) }() + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() log.Debugf("ConfirmEMailAddress for UID %d", u.Uid) u.Mutex.Lock() @@ -252,14 +260,18 @@ func (u *User) ConfirmEMailAddress(confnum int32, remoteIP string) error { ar = AmNewAudit(AuditVerifyEmailFail, u.Uid, remoteIP, "Invalid confirmation number") return errors.New("confirmation number is incorrect. Please try again") } - _, err := amdb.Exec("UPDATE users SET verify_email = 1, base_lvl = ? WHERE uid = ?", + _, err := tx.Exec("UPDATE users SET verify_email = 1, base_lvl = ? WHERE uid = ?", AmDefaultRole("Global.AfterVerify").Level(), u.Uid) if err == nil { u.VerifyEMail = true u.BaseLevel = AmDefaultRole("Global.AfterVerify").Level() - err = AmAutoJoinCommunities(u) + err = AmAutoJoinCommunities(tx, u) if err == nil { - ar = AmNewAudit(AuditVerifyEmailOK, u.Uid, remoteIP) + err = tx.Commit() + if err == nil { + success = true + ar = AmNewAudit(AuditVerifyEmailOK, u.Uid, remoteIP) + } } } return err @@ -399,16 +411,50 @@ func AmGetUser(uid int32) (*User, error) { return rc.(*User), err } -/* AmGetUserByName returns a reference to the specified user. +/* AmGetUserTx returns a reference to the specified user inside a transaction. * Parameters: - * name - The username of the user. + * tx - The transaction we're in. + * uid - The UID of the user. * Returns: * Pointer to User containing user data, or nil * Standard Go error status */ -func AmGetUserByName(name string) (*User, error) { +func AmGetUserTx(tx *sqlx.Tx, uid int32) (*User, error) { + var err error = nil + getUserMutex.Lock() + defer getUserMutex.Unlock() + rc, ok := userCache.Get(uid) + if !ok { + var dbdata []User + err = tx.Select(&dbdata, "SELECT * from users WHERE uid = ?", uid) + if err != nil { + return nil, err + } + if len(dbdata) > 1 { + return nil, fmt.Errorf("AmGetUser(%d): too many responses(%d)", uid, len(dbdata)) + } + rc = &(dbdata[0]) + userCache.Add(uid, rc) + } + return rc.(*User), err +} + +/* AmGetUserByName returns a reference to the specified user. + * Parameters: + * name - The username of the user. + * tx - If this is not nil, use this transaction. + * Returns: + * Pointer to User containing user data, or nil + * Standard Go error status + */ +func AmGetUserByName(name string, tx *sqlx.Tx) (*User, error) { var dbdata []User - err := amdb.Select(&dbdata, "SELECT * FROM users WHERE username = ?", name) + var err error + if tx != nil { + err = tx.Select(&dbdata, "SELECT * FROM users WHERE username = ?", name) + } else { + err = amdb.Select(&dbdata, "SELECT * FROM users WHERE username = ?", name) + } if err != nil { return nil, err } @@ -416,12 +462,12 @@ func AmGetUserByName(name string) (*User, error) { return nil, fmt.Errorf("AmGetUserByName(\"%s\"): too many responses(%d)", name, len(dbdata)) } getUserMutex.Lock() - defer getUserMutex.Unlock() rc, ok := userCache.Get(dbdata[0].Uid) if !ok { rc = &(dbdata[0]) userCache.Add(dbdata[0].Uid, rc) } + getUserMutex.Unlock() return rc.(*User), nil } @@ -485,11 +531,11 @@ func hashPassword(password string) string { } // touchUser updates the last access time for the user. -func touchUser(user *User) { +func touchUser(tx *sqlx.Tx, user *User) { user.Mutex.Lock() defer user.Mutex.Unlock() moment := time.Now().UTC() - _, _ = amdb.Exec("UPDATE user SET lastaccess = ? WHERE uid = ?", moment, user.Uid) + tx.Exec("UPDATE user SET lastaccess = ? WHERE uid = ?", moment, user.Uid) user.LastAccess = &moment } @@ -508,8 +554,15 @@ func AmAuthenticateUser(name string, password string, remoteIP string) (*User, e defer func() { AmStoreAudit(ar) }() + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() - user, err := AmGetUserByName(name) + user, err := AmGetUserByName(name, tx) if err != nil { log.Error("...user not found") ar = AmNewAudit(AuditLoginFail, 0, remoteIP, fmt.Sprintf("Bad username: %s", name)) @@ -532,7 +585,12 @@ func AmAuthenticateUser(name string, password string, remoteIP string) (*User, e return nil, errors.New("the password you have specified is incorrect; please try again") } log.Debug("...authenticated") - touchUser(user) + touchUser(tx, user) + err = tx.Commit() + if err != nil { + return nil, err + } + success = true ar = AmNewAudit(AuditLoginOK, user.Uid, remoteIP) return user, nil } @@ -574,13 +632,20 @@ func AmAuthenticateUserByToken(authString string, remoteIP string) (*User, error defer func() { AmStoreAudit(ar) }() + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() uid, token, err := crackAuthString(authString) if err != nil { return nil, fmt.Errorf("authString not valid, ignored: %v", err) } var user *User - user, err = AmGetUser(uid) + user, err = AmGetUserTx(tx, uid) if err != nil { log.Error("...user not found") ar = AmNewAudit(AuditLoginFail, 0, remoteIP, fmt.Sprintf("Bad uid: %d", uid)) @@ -603,7 +668,12 @@ func AmAuthenticateUserByToken(authString string, remoteIP string) (*User, error return nil, errors.New("token mismatch") } log.Debug("...authenticated") - touchUser(user) + touchUser(tx, user) + err = tx.Commit() + if err != nil { + return nil, err + } + success = true ar = AmNewAudit(AuditLoginOK, user.Uid, remoteIP) return user, nil } @@ -624,17 +694,24 @@ func AmCreateNewUser(username string, password string, reminder string, dob *tim defer func() { AmStoreAudit(ar) }() - + anon, _ := getAnonUserID() + success := false + tx := amdb.MustBegin() + defer func() { + if !success { + tx.Rollback() + } + }() unlock := true - amdb.Exec("LOCK TABLES users WRITE, userprefs WRITE, propuser WRITE, commmember WRITE, sideboxes WRITE, confhotlist WRITE;") + tx.Exec("LOCK TABLES users WRITE, userprefs WRITE, propuser WRITE, commmember WRITE, sideboxes WRITE, confhotlist WRITE;") defer func() { if unlock { - amdb.Exec("UNLOCK TABLES;") + tx.Exec("UNLOCK TABLES;") } }() // Test if the user name is already taken. - rs, err := amdb.Query("SELECT uid FROM users WHERE username = ?", username) + rs, err := tx.Query("SELECT uid FROM users WHERE username = ?", username) if err != nil { return nil, err } else if rs.Next() { @@ -643,7 +720,7 @@ func AmCreateNewUser(username string, password string, reminder string, dob *tim } // Insert the user record. - _, err2 := amdb.Exec(`INSERT INTO users (username, passhash, verify_email, lockout, email_confnum, + _, err2 := tx.Exec(`INSERT INTO users (username, passhash, verify_email, lockout, email_confnum, base_lvl, created, lastaccess, passreminder, description, dob) VALUES (?, ?, 0, 0, ?, ?, NOW(), NOW(), ?, '', ?)`, username, hashPassword(password), util.GenerateRandomConfirmationNumber(), AmDefaultRole("Global.NewUser").Level(), reminder, dob) @@ -651,49 +728,54 @@ func AmCreateNewUser(username string, password string, reminder string, dob *tim return nil, err2 } // Read back the user, which also puts it in the cache. - user, err3 := AmGetUserByName(username) + user, err3 := AmGetUserByName(username, tx) if err3 != nil { return nil, err3 } log.Debugf("...created new user \"%s\" with UID %d", username, user.Uid) // add user preferences - _, err = amdb.Exec("INSERT INTO userprefs (uid) VALUES (?)", user.Uid) + _, err = tx.Exec("INSERT INTO userprefs (uid) VALUES (?)", user.Uid) if err != nil { return nil, err } // add user properties props := make([]UserProperties, 0) - anon, _ := getAnonUserID() - err = amdb.Select(&props, "SELECT * FROM propuser WHERE uid = ?", anon) + err = tx.Select(&props, "SELECT * FROM propuser WHERE uid = ?", anon) if err != nil { return nil, err } for _, p := range props { - _, err := amdb.Exec("INSERT INTO propuser (uid, ndx, data) VALUES (?, ?, ?)", user.Uid, p.Index, p.Data) + _, err := tx.Exec("INSERT INTO propuser (uid, ndx, data) VALUES (?, ?, ?)", user.Uid, p.Index, p.Data) if err != nil { return nil, err } } // add user sideboxes - err = copySideboxes(user.Uid, anon) + err = copySideboxes(tx, user.Uid, anon) if err != nil { return nil, err } - amdb.Exec("UNLOCK TABLES;") + tx.Exec("UNLOCK TABLES;") unlock = false // auto-join communities - err = AmAutoJoinCommunities(user) + err = AmAutoJoinCommunities(tx, user) if err != nil { return nil, err } // TODO: copy conference hotlists + err = tx.Commit() + if err != nil { + return nil, err + } + success = true + // operation was a success - add an audit record ar = AmNewAudit(AuditAccountCreated, user.Uid, remoteIP) return user, nil diff --git a/htmlcheck/rewriter.go b/htmlcheck/rewriter.go index 63cbffa..69edeb0 100644 --- a/htmlcheck/rewriter.go +++ b/htmlcheck/rewriter.go @@ -215,7 +215,7 @@ func (rw *userLinkRewriter) Rewrite(data string, svc rewriterServices) *markupDa return nil } - user, err := database.AmGetUserByName(data) + user, err := database.AmGetUserByName(data, nil) if err != nil || user == nil { return nil } diff --git a/login.go b/login.go index 4e8cc22..135d9c5 100644 --- a/login.go +++ b/login.go @@ -80,7 +80,7 @@ func Login(ctxt ui.AmContext) (string, any, error) { return dlg.RenderError(ctxt, "User name not specified.") } if action == "remind" { // Password Reminder button pressed - user, uerr := database.AmGetUserByName(username) + user, uerr := database.AmGetUserByName(username, nil) if uerr == nil { var ci *database.ContactInfo ci, uerr = user.ContactInfo() diff --git a/userdata.go b/userdata.go index 4a4caa9..30bfe9f 100644 --- a/userdata.go +++ b/userdata.go @@ -334,7 +334,7 @@ func ShowProfile(ctxt ui.AmContext) (string, any, error) { } // Gather the info on the current user. - user, err := database.AmGetUserByName(ctxt.URLParam("uname")) + user, err := database.AmGetUserByName(ctxt.URLParam("uname"), nil) if err != nil { ctxt.SetRC(http.StatusNotFound) return ui.ErrorPage(ctxt, err)