diff --git a/database/base.go b/database/base.go index 5258ed8..f9162cd 100644 --- a/database/base.go +++ b/database/base.go @@ -10,11 +10,13 @@ package database import ( + "context" "slices" "git.erbosoft.com/amy/amsterdam/config" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" + log "github.com/sirupsen/logrus" ) // amdb is the reference to the Amsterdam database. @@ -42,3 +44,36 @@ func SetupDb() (func(), error) { amdb.Close() }, err } + +/* transaction starts a transaction and returns functions for commit and rollback. The rollback + * function can be immediately deferred; if commit is called successfully, rollback becomes a no-op. + * Parameters: + * ctx - Standard Go error status. + * Returns: + * The sqlx transaction object + * The commit function (no parameters, returns error) + * The rollback function (no parameters or return) + */ +func transaction(ctx context.Context) (*sqlx.Tx, func() error, func()) { + tx := amdb.MustBeginTx(ctx, nil) + live := true + fCom := func() error { + var err error = nil + if live { + err = tx.Commit() + if err == nil { + live = false + } + } + return err + } + fRoll := func() { + if live { + if err := tx.Rollback(); err != nil { + log.Errorf("***ROLLBACK ERROR*** %v", err) + } + live = false + } + } + return tx, fCom, fRoll +} diff --git a/database/community.go b/database/community.go index cccd998..e7444a5 100644 --- a/database/community.go +++ b/database/community.go @@ -336,13 +336,9 @@ func (c *Community) ListMembers(ctx context.Context, field int, oper int, term s * Standard Go error status. */ func (c *Community) SetMembership(ctx context.Context, u *User, level uint16, locked bool, personUID int32, ipaddr string) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() + if level == 0 { res, err := tx.ExecContext(ctx, "DELETE FROM commmember WHERE commid = ? AND uid = ?", c.Id, u.Uid) if err != nil { @@ -382,11 +378,14 @@ func (c *Community) SetMembership(ctx context.Context, u *User, level uint16, lo return err } } - if err := c.TouchUpdateTx(ctx, tx); err == nil { - AmStoreAudit(AmNewCommAudit(AuditCommunitySetMembership, personUID, c.Id, ipaddr, fmt.Sprintf("cid=%d", c.Id), - fmt.Sprintf("uid=%d", u.Uid), fmt.Sprintf("level=%d", level))) + var err error + if err = c.TouchUpdateTx(ctx, tx); err == nil { + if err = commit(); err == nil { + AmStoreAudit(AmNewCommAudit(AuditCommunitySetMembership, personUID, c.Id, ipaddr, fmt.Sprintf("cid=%d", c.Id), + fmt.Sprintf("uid=%d", u.Uid), fmt.Sprintf("level=%d", level))) + } } - return nil + return err } /* TestPermission is shorthand that tests if a user has a permission with respect to the community. @@ -541,13 +540,13 @@ func (c *Community) TouchUpdateTx(ctx context.Context, tx *sqlx.Tx) error { // TouchUpdateTx updates the last access and last update times of the community. func (c *Community) TouchUpdate(ctx context.Context) error { - tx := amdb.MustBegin() + tx, commit, rollback := transaction(ctx) err := c.TouchUpdateTx(ctx, tx) if err != nil { - err = tx.Commit() + err = commit() } if err != nil { - tx.Rollback() + rollback() } return err } @@ -854,13 +853,8 @@ func AmSetCommunityProperty(ctx context.Context, cid int32, ndx int32, val *stri */ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid int32, language *string, synopsis *string, rules *string, joinkey *string, hideDirectory bool, hideSearch bool, remoteIP string) (*Community, error) { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // validate alias does not already exist row := tx.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias) @@ -907,10 +901,9 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i return nil, err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true // operation was a success - add an audit record AmStoreAudit(AmNewCommAudit(AuditCommunityCreate, hostUid, comm.Id, remoteIP, fmt.Sprintf("id=%d", comm.Id), diff --git a/database/conference.go b/database/conference.go index 336e935..1f9cb1e 100644 --- a/database/conference.go +++ b/database/conference.go @@ -588,13 +588,8 @@ func (c *Conference) Fixseen(ctx context.Context, u *User) error { if u.IsAnon { return nil } - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Get a count of topics beforehand. row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM topics WHERE confid = ?", c.ConfId) @@ -634,10 +629,9 @@ func (c *Conference) Fixseen(ctx context.Context, u *User) error { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } @@ -657,13 +651,8 @@ func (c *Conference) GetCustomBlocks(ctx context.Context) (string, string, error // SetCustomBlocks sets the custom HTML blocks for this conference. func (c *Conference) SetCustomBlocks(ctx context.Context, topBlock, bottomBlock string) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM confcustom WHERE confid = ?", c.ConfId) ct := 0 err := row.Scan(&ct) @@ -678,10 +667,9 @@ func (c *Conference) SetCustomBlocks(ctx context.Context, topBlock, bottomBlock if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } @@ -815,7 +803,7 @@ func (c *Conference) Stats(ctx context.Context) (int, int, error) { // backgroundPurgeConference purges out all the conference information in the background. func backgroundPurgeConference(ctx context.Context, confid int32) error { // Purge out auxiliary conference tables first. - tx := amdb.MustBegin() + tx, commit, rollback := transaction(ctx) _, err := tx.ExecContext(ctx, "DELETE FROM confmember WHERE confid = ?", confid) if err != nil { log.Warnf("backgroundPurgeConference(%d): failed purging confmember: %v", confid, err) @@ -836,9 +824,9 @@ func backgroundPurgeConference(ctx context.Context, confid int32) error { if err != nil { log.Warnf("backgroundPurgeConference(%d): failed purging confcustom: %v", confid, err) } - err = tx.Commit() + err = commit() if err != nil { - tx.Rollback() + rollback() return err } @@ -851,13 +839,13 @@ func backgroundPurgeConference(ctx context.Context, confid int32) error { // Erase each topic in turn by calling two of the "delete topic" internal functions. for _, topicId := range topicIds { - tx := amdb.MustBegin() + tx, commit, rollback := transaction(ctx) err = eraseTopicRecords(ctx, tx, topicId) if err == nil { - err = tx.Commit() + err = commit() } if err != nil { - tx.Rollback() + rollback() return err } err = backgroundPurgeTopic(ctx, topicId) @@ -870,13 +858,8 @@ func backgroundPurgeConference(ctx context.Context, confid int32) error { // Delete unlinks this conference from the community, deleting it entirely if the last link is gone. func (c *Conference) Delete(ctx context.Context, comm *Community, u *User, ipaddr string, background *util.WorkerPool) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() getConferenceMutex.Lock() defer getConferenceMutex.Unlock() @@ -904,10 +887,9 @@ func (c *Conference) Delete(ctx context.Context, comm *Community, u *User, ipadd return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true if refCount == 0 { // kick the conference out of the cache @@ -1154,13 +1136,8 @@ func AmSetConferenceProperty(ctx context.Context, confid int32, ndx int32, val * // AmReorderConferences reorders two conferences by sequence number. func AmReorderConferences(ctx context.Context, cid int32, seq1, seq2 int16) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() _, err := tx.ExecContext(ctx, "UPDATE commtoconf SET sequence = -1 WHERE commid = ? AND sequence = ?", cid, seq1) if err == nil { _, err = tx.ExecContext(ctx, "UPDATE commtoconf SET sequence = ? WHERE commid = ? AND sequence = ?", seq1, cid, seq2) @@ -1171,10 +1148,9 @@ func AmReorderConferences(ctx context.Context, cid int32, seq1, seq2 int16) erro if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } @@ -1214,13 +1190,8 @@ func AmCreateConference(ctx context.Context, comm *Community, name, alias, descr newConf.CreateLevel = AmDefaultRole("Conference.Create.Public").Level() } - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() getConferenceMutex.Lock() defer getConferenceMutex.Unlock() @@ -1279,10 +1250,9 @@ func AmCreateConference(ctx context.Context, comm *Community, name, alias, descr return nil, err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true // Add the new conference to the cache, and create our audit record. conferenceCache.Add(rc[0].ConfId, &(rc[0])) diff --git a/database/hotlist.go b/database/hotlist.go index d92dc25..d188f6f 100644 --- a/database/hotlist.go +++ b/database/hotlist.go @@ -49,35 +49,24 @@ func AmCopyConferenceHotlist(ctx context.Context, from, to *User) error { return err } - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() for _, hl := range hotlist { if _, err = tx.ExecContext(ctx, "INSERT INTO confhotlist (uid, sequence, commid, confid) VALUES (?, ?, ?, ?)", to.Uid, hl.Sequence, hl.CommId, hl.ConfId); err != nil { return err } } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } // AmReorderHotlist exchanges the position of two items on the user's hotlist. func AmReorderHotlist(ctx context.Context, u *User, seq1, seq2 int16) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() _, err := tx.ExecContext(ctx, "UPDATE confhotlist SET sequence = -1 WHERE uid = ? AND sequence = ?", u.Uid, seq1) if err == nil { @@ -89,22 +78,16 @@ func AmReorderHotlist(ctx context.Context, u *User, seq1, seq2 int16) error { if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } // AmRemoveEntryFromHotlist removes an entry from the user's hotlist. func AmRemoveEntryFromHotlist(ctx context.Context, u *User, seq int16) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() _, err := tx.ExecContext(ctx, "DELETE FROM confhotlist WHERE uid = ? AND sequence = ?", u.Uid, seq) if err == nil { @@ -113,22 +96,16 @@ func AmRemoveEntryFromHotlist(ctx context.Context, u *User, seq int16) error { if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } // AmAppendToHotlist adds a community/conference ID to the end of the user's hotlist. func AmAppendToHotlist(ctx context.Context, u *User, commid, confid int32) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() var newseq int16 row := tx.QueryRowContext(ctx, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid) @@ -150,10 +127,9 @@ func AmAppendToHotlist(ctx context.Context, u *User, commid, confid int32) error if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } diff --git a/database/post.go b/database/post.go index ec01c2f..ae2f0be 100644 --- a/database/post.go +++ b/database/post.go @@ -294,13 +294,8 @@ func (p *PostHeader) Scribble(ctx context.Context, u *User, comm *Community, ipa return errors.New("cannot scribble an already-scribbled post") } - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Scribble on the post header. scribblePseud := "(Scribbled)" // FUTURE: configurable option _, err := tx.ExecContext(ctx, "UPDATE posts SET linecount = 0, hidden = 0, scribble_uid = ?, scribble_date = NOW(), pseud = ? WHERE postid = ?", u.Uid, scribblePseud, p.PostId) @@ -327,10 +322,9 @@ func (p *PostHeader) Scribble(ctx context.Context, u *User, comm *Community, ipa return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true // Patch fields in the post header var newLines int32 = 0 @@ -348,13 +342,8 @@ func (p *PostHeader) Scribble(ctx context.Context, u *User, comm *Community, ipa // Nuke causes a post to be nuked (deleted entirely from the topic). func (p *PostHeader) Nuke(ctx context.Context, u *User, comm *Community, ipaddr string) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Delete all the references to this post. _, err := tx.ExecContext(ctx, "DELETE FROM posts WHERE postid = ?", p.PostId) @@ -394,10 +383,9 @@ func (p *PostHeader) Nuke(ctx context.Context, u *User, comm *Community, ipaddr return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true AmStoreAudit(AmNewCommAudit(AuditConferenceNukeMessage, u.Uid, comm.Id, ipaddr, fmt.Sprintf("post=%d", p.PostId))) return nil } @@ -408,13 +396,8 @@ func (p *PostHeader) Publish(ctx context.Context, comm *Community, publisher *Us return errors.New("cannot publish scribbled post") } - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Check if we were already published. row := tx.QueryRowContext(ctx, "SELECT by_uid FROM postpublish WHERE postid = ?", p.PostId) @@ -431,10 +414,9 @@ func (p *PostHeader) Publish(ctx context.Context, comm *Community, publisher *Us comm.Id, p.PostId, publisher.Uid); err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true AmStoreAudit(AmNewAudit(AuditPublishToFrontPage, publisher.Uid, ipaddr, fmt.Sprintf("comm=%d,post=%d", comm.Id, p.PostId))) return nil } @@ -463,13 +445,8 @@ func (p *PostHeader) MoveTo(ctx context.Context, target *Topic, u *User, comm *C return err } - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Adjust post record in the database to make it part of the new topic. _, err = tx.ExecContext(ctx, "UPDATE posts SET parent = 0, topicid = ?, num = ? WHERE postid = ?", target.TopicId, target.TopMessage+1, p.PostId) @@ -511,10 +488,9 @@ func (p *PostHeader) MoveTo(ctx context.Context, target *Topic, u *User, comm *C return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true // Now patch the data structures we have. p.Parent = 0 @@ -589,13 +565,8 @@ func AmGetPostRange(ctx context.Context, topic *Topic, first, last int32) ([]*Po */ func AmNewPost(ctx context.Context, conf *Conference, topic *Topic, user *User, pseud string, post string, postLines int32, comm *Community, ipaddr string) (*PostHeader, error) { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Add the post header information. rs, err := tx.ExecContext(ctx, "INSERT INTO posts (topicid, num, linecount, creator_uid, posted, pseud) VALUES (?, ?, ?, ?, NOW(), ?)", @@ -644,10 +615,9 @@ func AmNewPost(ctx context.Context, conf *Conference, topic *Topic, user *User, return nil, err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true // create audit record AmStoreAudit(AmNewCommAudit(AuditConferencePostMessage, user.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId), diff --git a/database/sidebox.go b/database/sidebox.go index 5e962d6..73654ab 100644 --- a/database/sidebox.go +++ b/database/sidebox.go @@ -66,13 +66,8 @@ func AmGetSideboxes(ctx context.Context, uid int32) ([]*Sidebox, error) { // AmReorderSideboxes changes the position of two sideboxes on the user's list. func AmReorderSideboxes(ctx context.Context, uid int32, seq1, seq2 int32) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() _, err := tx.ExecContext(ctx, "UPDATE sideboxes SET sequence = -1 WHERE uid = ? AND sequence = ?", uid, seq1) if err == nil { @@ -84,22 +79,16 @@ func AmReorderSideboxes(ctx context.Context, uid int32, seq1, seq2 int32) error if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } // AmRemoveSidebox removes a sidebox from the user configuration. func AmRemoveSidebox(ctx context.Context, uid int32, boxid int32) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Get the old sequence number. row := tx.QueryRowContext(ctx, "SELECT sequence FROM sideboxes WHERE uid = ? AND boxid = ?", uid, boxid) @@ -118,22 +107,16 @@ func AmRemoveSidebox(ctx context.Context, uid int32, boxid int32) error { if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } // AmAppendSidebox appends a new sidebox to the existing user's configuration. func AmAppendSidebox(ctx context.Context, uid int32, boxid int32, param *string) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() row := tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM sideboxes WHERE uid = ?", uid) var topseq int32 @@ -147,9 +130,8 @@ func AmAppendSidebox(ctx context.Context, uid int32, boxid int32, param *string) if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } diff --git a/database/topic.go b/database/topic.go index d6bccfb..c6484c8 100644 --- a/database/topic.go +++ b/database/topic.go @@ -372,13 +372,8 @@ func (t *Topic) GetActiveUserEMailAddrs(ctx context.Context, userSelect, dayLimi // backgroundPurgeTopic removes all posts from a topic that's been deleted. func backgroundPurgeTopic(ctx context.Context, topicid int32) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Get some stats on the posts we have to remove. row := tx.QueryRowContext(ctx, "SELECT MAX(postid) FROM posts WHERE topicid = ?", topicid) @@ -408,10 +403,9 @@ func backgroundPurgeTopic(ctx context.Context, topicid int32) error { if err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true return nil } @@ -429,13 +423,8 @@ func eraseTopicRecords(ctx context.Context, tx *sqlx.Tx, topicid int32) error { // Delete deletes this topic. func (t *Topic) Delete(ctx context.Context, u *User, comm *Community, ipaddr string, background *util.WorkerPool) error { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() conf, err := AmGetConference(ctx, t.ConfId) if err != nil { @@ -450,10 +439,9 @@ func (t *Topic) Delete(ctx context.Context, u *User, comm *Community, ipaddr str if err = conf.TouchUpdate(ctx, tx, time.Now()); err != nil { return err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return err } - success = true // create audit record AmStoreAudit(AmNewCommAudit(AuditConferenceDeleteTopic, u.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId), @@ -732,13 +720,8 @@ func AmListTopics(ctx context.Context, confid int32, uid int32, viewOption int, */ func AmNewTopic(ctx context.Context, conf *Conference, user *User, title string, zeroPostPseud string, zeroPost string, zeroPostLines int32, comm *Community, ipaddr string) (*Topic, error) { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Insert the new topic into the database. conf.Mutex.Lock() @@ -800,10 +783,9 @@ func AmNewTopic(ctx context.Context, conf *Conference, user *User, title string, return nil, err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true // create audit record AmStoreAudit(AmNewCommAudit(AuditConferenceCreateTopic, user.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId), diff --git a/database/user.go b/database/user.go index 5a930c0..a566246 100644 --- a/database/user.go +++ b/database/user.go @@ -593,13 +593,8 @@ func touchUser(ctx context.Context, tx *sqlx.Tx, user *User) { */ func AmAuthenticateUser(ctx context.Context, name string, password string, remoteIP string) (*User, error) { log.Debugf("AmAuthenticateUser() authenticating user %s...", name) - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() user, err := AmGetUserByName(ctx, name, tx) if err != nil { @@ -631,10 +626,9 @@ func AmAuthenticateUser(ctx context.Context, name string, password string, remot } log.Debug("...authenticated") touchUser(ctx, tx, user) - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true AmStoreAudit(AmNewAudit(AuditLoginOK, user.Uid, remoteIP)) return user, nil } @@ -673,13 +667,8 @@ func crackAuthString(authString string) (int32, string, error) { * Standard Go error status. */ func AmAuthenticateUserByToken(ctx context.Context, authString string, remoteIP string) (*User, error) { - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() uid, token, err := crackAuthString(authString) if err != nil { @@ -710,10 +699,9 @@ func AmAuthenticateUserByToken(ctx context.Context, authString string, remoteIP } log.Debug("...authenticated") touchUser(ctx, tx, user) - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true AmStoreAudit(AmNewAudit(AuditLoginOK, user.Uid, remoteIP)) return user, nil } @@ -732,13 +720,8 @@ func AmAuthenticateUserByToken(ctx context.Context, authString string, remoteIP */ func AmCreateNewUser(ctx context.Context, username string, password string, reminder string, dob *time.Time, remoteIP string) (*User, error) { anon, _ := AmGetAnonUser(ctx) - success := false - tx := amdb.MustBegin() - defer func() { - if !success { - tx.Rollback() - } - }() + tx, commit, rollback := transaction(ctx) + defer rollback() // Test if the user name is already taken. row := tx.QueryRowContext(ctx, "SELECT uid FROM users WHERE username = ?", username) @@ -787,10 +770,9 @@ func AmCreateNewUser(ctx context.Context, username string, password string, remi return nil, err } - if err = tx.Commit(); err != nil { + if err = commit(); err != nil { return nil, err } - success = true // auto-join communities if err = AmAutoJoinCommunities(ctx, user); err != nil {