touched up the database code to use transactions where necessary
This commit is contained in:
+114
-32
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"git.erbosoft.com/amy/amsterdam/util"
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/klauspost/lctime"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/text/language"
|
||||
@@ -37,12 +38,12 @@ type UserPrefs struct {
|
||||
|
||||
// ReadLocale reads the locale out of the prefs, adjusting for Go use.
|
||||
func (p *UserPrefs) ReadLocale() string {
|
||||
return strings.Replace(p.LocaleID, "_", "-", -1)
|
||||
return strings.ReplaceAll(p.LocaleID, "_", "-")
|
||||
}
|
||||
|
||||
// WriteLocale writes the locale into the prefs, adjusting for backward compatibility.
|
||||
func (p *UserPrefs) WriteLocale(loc string) {
|
||||
p.LocaleID = strings.Replace(loc, "-", "_", -1)
|
||||
p.LocaleID = strings.ReplaceAll(loc, "-", "_")
|
||||
}
|
||||
|
||||
// Clone duplicates the user preferences.
|
||||
@@ -122,9 +123,9 @@ type User struct {
|
||||
|
||||
// UserProperties represents a property entry for a user.
|
||||
type UserProperties struct {
|
||||
Uid int32 `db:"uid"`
|
||||
Index int32 `db:"ndx"`
|
||||
Data *string `db:"data"`
|
||||
Uid int32 `db:"uid"` // UID of user
|
||||
Index int32 `db:"ndx"` // index of property
|
||||
Data *string `db:"data"` // property data
|
||||
}
|
||||
|
||||
// User property indexes defined.
|
||||
@@ -239,6 +240,13 @@ func (u *User) ConfirmEMailAddress(confnum int32, remoteIP string) error {
|
||||
defer func() {
|
||||
AmStoreAudit(ar)
|
||||
}()
|
||||
success := false
|
||||
tx := amdb.MustBegin()
|
||||
defer func() {
|
||||
if !success {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugf("ConfirmEMailAddress for UID %d", u.Uid)
|
||||
u.Mutex.Lock()
|
||||
@@ -252,14 +260,18 @@ func (u *User) ConfirmEMailAddress(confnum int32, remoteIP string) error {
|
||||
ar = AmNewAudit(AuditVerifyEmailFail, u.Uid, remoteIP, "Invalid confirmation number")
|
||||
return errors.New("confirmation number is incorrect. Please try again")
|
||||
}
|
||||
_, err := amdb.Exec("UPDATE users SET verify_email = 1, base_lvl = ? WHERE uid = ?",
|
||||
_, err := tx.Exec("UPDATE users SET verify_email = 1, base_lvl = ? WHERE uid = ?",
|
||||
AmDefaultRole("Global.AfterVerify").Level(), u.Uid)
|
||||
if err == nil {
|
||||
u.VerifyEMail = true
|
||||
u.BaseLevel = AmDefaultRole("Global.AfterVerify").Level()
|
||||
err = AmAutoJoinCommunities(u)
|
||||
err = AmAutoJoinCommunities(tx, u)
|
||||
if err == nil {
|
||||
ar = AmNewAudit(AuditVerifyEmailOK, u.Uid, remoteIP)
|
||||
err = tx.Commit()
|
||||
if err == nil {
|
||||
success = true
|
||||
ar = AmNewAudit(AuditVerifyEmailOK, u.Uid, remoteIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
@@ -399,16 +411,50 @@ func AmGetUser(uid int32) (*User, error) {
|
||||
return rc.(*User), err
|
||||
}
|
||||
|
||||
/* AmGetUserByName returns a reference to the specified user.
|
||||
/* AmGetUserTx returns a reference to the specified user inside a transaction.
|
||||
* Parameters:
|
||||
* name - The username of the user.
|
||||
* tx - The transaction we're in.
|
||||
* uid - The UID of the user.
|
||||
* Returns:
|
||||
* Pointer to User containing user data, or nil
|
||||
* Standard Go error status
|
||||
*/
|
||||
func AmGetUserByName(name string) (*User, error) {
|
||||
func AmGetUserTx(tx *sqlx.Tx, uid int32) (*User, error) {
|
||||
var err error = nil
|
||||
getUserMutex.Lock()
|
||||
defer getUserMutex.Unlock()
|
||||
rc, ok := userCache.Get(uid)
|
||||
if !ok {
|
||||
var dbdata []User
|
||||
err = tx.Select(&dbdata, "SELECT * from users WHERE uid = ?", uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbdata) > 1 {
|
||||
return nil, fmt.Errorf("AmGetUser(%d): too many responses(%d)", uid, len(dbdata))
|
||||
}
|
||||
rc = &(dbdata[0])
|
||||
userCache.Add(uid, rc)
|
||||
}
|
||||
return rc.(*User), err
|
||||
}
|
||||
|
||||
/* AmGetUserByName returns a reference to the specified user.
|
||||
* Parameters:
|
||||
* name - The username of the user.
|
||||
* tx - If this is not nil, use this transaction.
|
||||
* Returns:
|
||||
* Pointer to User containing user data, or nil
|
||||
* Standard Go error status
|
||||
*/
|
||||
func AmGetUserByName(name string, tx *sqlx.Tx) (*User, error) {
|
||||
var dbdata []User
|
||||
err := amdb.Select(&dbdata, "SELECT * FROM users WHERE username = ?", name)
|
||||
var err error
|
||||
if tx != nil {
|
||||
err = tx.Select(&dbdata, "SELECT * FROM users WHERE username = ?", name)
|
||||
} else {
|
||||
err = amdb.Select(&dbdata, "SELECT * FROM users WHERE username = ?", name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -416,12 +462,12 @@ func AmGetUserByName(name string) (*User, error) {
|
||||
return nil, fmt.Errorf("AmGetUserByName(\"%s\"): too many responses(%d)", name, len(dbdata))
|
||||
}
|
||||
getUserMutex.Lock()
|
||||
defer getUserMutex.Unlock()
|
||||
rc, ok := userCache.Get(dbdata[0].Uid)
|
||||
if !ok {
|
||||
rc = &(dbdata[0])
|
||||
userCache.Add(dbdata[0].Uid, rc)
|
||||
}
|
||||
getUserMutex.Unlock()
|
||||
return rc.(*User), nil
|
||||
}
|
||||
|
||||
@@ -485,11 +531,11 @@ func hashPassword(password string) string {
|
||||
}
|
||||
|
||||
// touchUser updates the last access time for the user.
|
||||
func touchUser(user *User) {
|
||||
func touchUser(tx *sqlx.Tx, user *User) {
|
||||
user.Mutex.Lock()
|
||||
defer user.Mutex.Unlock()
|
||||
moment := time.Now().UTC()
|
||||
_, _ = amdb.Exec("UPDATE user SET lastaccess = ? WHERE uid = ?", moment, user.Uid)
|
||||
tx.Exec("UPDATE user SET lastaccess = ? WHERE uid = ?", moment, user.Uid)
|
||||
user.LastAccess = &moment
|
||||
}
|
||||
|
||||
@@ -508,8 +554,15 @@ func AmAuthenticateUser(name string, password string, remoteIP string) (*User, e
|
||||
defer func() {
|
||||
AmStoreAudit(ar)
|
||||
}()
|
||||
success := false
|
||||
tx := amdb.MustBegin()
|
||||
defer func() {
|
||||
if !success {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
user, err := AmGetUserByName(name)
|
||||
user, err := AmGetUserByName(name, tx)
|
||||
if err != nil {
|
||||
log.Error("...user not found")
|
||||
ar = AmNewAudit(AuditLoginFail, 0, remoteIP, fmt.Sprintf("Bad username: %s", name))
|
||||
@@ -532,7 +585,12 @@ func AmAuthenticateUser(name string, password string, remoteIP string) (*User, e
|
||||
return nil, errors.New("the password you have specified is incorrect; please try again")
|
||||
}
|
||||
log.Debug("...authenticated")
|
||||
touchUser(user)
|
||||
touchUser(tx, user)
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
success = true
|
||||
ar = AmNewAudit(AuditLoginOK, user.Uid, remoteIP)
|
||||
return user, nil
|
||||
}
|
||||
@@ -574,13 +632,20 @@ func AmAuthenticateUserByToken(authString string, remoteIP string) (*User, error
|
||||
defer func() {
|
||||
AmStoreAudit(ar)
|
||||
}()
|
||||
success := false
|
||||
tx := amdb.MustBegin()
|
||||
defer func() {
|
||||
if !success {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
uid, token, err := crackAuthString(authString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authString not valid, ignored: %v", err)
|
||||
}
|
||||
var user *User
|
||||
user, err = AmGetUser(uid)
|
||||
user, err = AmGetUserTx(tx, uid)
|
||||
if err != nil {
|
||||
log.Error("...user not found")
|
||||
ar = AmNewAudit(AuditLoginFail, 0, remoteIP, fmt.Sprintf("Bad uid: %d", uid))
|
||||
@@ -603,7 +668,12 @@ func AmAuthenticateUserByToken(authString string, remoteIP string) (*User, error
|
||||
return nil, errors.New("token mismatch")
|
||||
}
|
||||
log.Debug("...authenticated")
|
||||
touchUser(user)
|
||||
touchUser(tx, user)
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
success = true
|
||||
ar = AmNewAudit(AuditLoginOK, user.Uid, remoteIP)
|
||||
return user, nil
|
||||
}
|
||||
@@ -624,17 +694,24 @@ func AmCreateNewUser(username string, password string, reminder string, dob *tim
|
||||
defer func() {
|
||||
AmStoreAudit(ar)
|
||||
}()
|
||||
|
||||
anon, _ := getAnonUserID()
|
||||
success := false
|
||||
tx := amdb.MustBegin()
|
||||
defer func() {
|
||||
if !success {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
unlock := true
|
||||
amdb.Exec("LOCK TABLES users WRITE, userprefs WRITE, propuser WRITE, commmember WRITE, sideboxes WRITE, confhotlist WRITE;")
|
||||
tx.Exec("LOCK TABLES users WRITE, userprefs WRITE, propuser WRITE, commmember WRITE, sideboxes WRITE, confhotlist WRITE;")
|
||||
defer func() {
|
||||
if unlock {
|
||||
amdb.Exec("UNLOCK TABLES;")
|
||||
tx.Exec("UNLOCK TABLES;")
|
||||
}
|
||||
}()
|
||||
|
||||
// Test if the user name is already taken.
|
||||
rs, err := amdb.Query("SELECT uid FROM users WHERE username = ?", username)
|
||||
rs, err := tx.Query("SELECT uid FROM users WHERE username = ?", username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if rs.Next() {
|
||||
@@ -643,7 +720,7 @@ func AmCreateNewUser(username string, password string, reminder string, dob *tim
|
||||
}
|
||||
|
||||
// Insert the user record.
|
||||
_, err2 := amdb.Exec(`INSERT INTO users (username, passhash, verify_email, lockout, email_confnum,
|
||||
_, err2 := tx.Exec(`INSERT INTO users (username, passhash, verify_email, lockout, email_confnum,
|
||||
base_lvl, created, lastaccess, passreminder, description, dob) VALUES (?, ?, 0, 0, ?, ?, NOW(), NOW(), ?, '', ?)`,
|
||||
username, hashPassword(password), util.GenerateRandomConfirmationNumber(), AmDefaultRole("Global.NewUser").Level(),
|
||||
reminder, dob)
|
||||
@@ -651,49 +728,54 @@ func AmCreateNewUser(username string, password string, reminder string, dob *tim
|
||||
return nil, err2
|
||||
}
|
||||
// Read back the user, which also puts it in the cache.
|
||||
user, err3 := AmGetUserByName(username)
|
||||
user, err3 := AmGetUserByName(username, tx)
|
||||
if err3 != nil {
|
||||
return nil, err3
|
||||
}
|
||||
log.Debugf("...created new user \"%s\" with UID %d", username, user.Uid)
|
||||
|
||||
// add user preferences
|
||||
_, err = amdb.Exec("INSERT INTO userprefs (uid) VALUES (?)", user.Uid)
|
||||
_, err = tx.Exec("INSERT INTO userprefs (uid) VALUES (?)", user.Uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// add user properties
|
||||
props := make([]UserProperties, 0)
|
||||
anon, _ := getAnonUserID()
|
||||
err = amdb.Select(&props, "SELECT * FROM propuser WHERE uid = ?", anon)
|
||||
err = tx.Select(&props, "SELECT * FROM propuser WHERE uid = ?", anon)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, p := range props {
|
||||
_, err := amdb.Exec("INSERT INTO propuser (uid, ndx, data) VALUES (?, ?, ?)", user.Uid, p.Index, p.Data)
|
||||
_, err := tx.Exec("INSERT INTO propuser (uid, ndx, data) VALUES (?, ?, ?)", user.Uid, p.Index, p.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// add user sideboxes
|
||||
err = copySideboxes(user.Uid, anon)
|
||||
err = copySideboxes(tx, user.Uid, anon)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amdb.Exec("UNLOCK TABLES;")
|
||||
tx.Exec("UNLOCK TABLES;")
|
||||
unlock = false
|
||||
|
||||
// auto-join communities
|
||||
err = AmAutoJoinCommunities(user)
|
||||
err = AmAutoJoinCommunities(tx, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: copy conference hotlists
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
success = true
|
||||
|
||||
// operation was a success - add an audit record
|
||||
ar = AmNewAudit(AuditAccountCreated, user.Uid, remoteIP)
|
||||
return user, nil
|
||||
|
||||
Reference in New Issue
Block a user