further optimized database calls to replace all single-value QueryRowContext calls with GetContext

This commit is contained in:
2026-02-23 18:21:38 -07:00
parent 4113ba2fca
commit 220e1ecc98
14 changed files with 233 additions and 290 deletions
+6 -8
View File
@@ -11,7 +11,6 @@ package database
import ( import (
"context" "context"
"database/sql"
_ "embed" _ "embed"
"fmt" "fmt"
"time" "time"
@@ -240,14 +239,13 @@ func AmStoreAudit(rec *AuditRecord) {
// AmListAuditRecords lists a section of the audit records. // AmListAuditRecords lists a section of the audit records.
func AmListAuditRecords(ctx context.Context, comm *Community, offset, max int) ([]AuditRecord, int, error) { func AmListAuditRecords(ctx context.Context, comm *Community, offset, max int) ([]AuditRecord, int, error) {
var row *sql.Row var err error
if comm != nil {
row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM audit WHERE commid = ?", comm.Id)
} else {
row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM audit")
}
var count int var count int
err := row.Scan(&count) if comm != nil {
err = amdb.GetContext(ctx, &count, "SELECT COUNT(*) FROM audit WHERE commid = ?", comm.Id)
} else {
err = amdb.GetContext(ctx, &count, "SELECT COUNT(*) FROM audit")
}
if err != nil { if err != nil {
return nil, -1, err return nil, -1, err
} }
+4 -6
View File
@@ -65,9 +65,8 @@ func loadCategories(ctx context.Context) error {
categoryMutex.Lock() categoryMutex.Lock()
defer categoryMutex.Unlock() defer categoryMutex.Unlock()
if allCategories == nil { if allCategories == nil {
row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM refcategory") var ncats int
var ncats int32 if err := amdb.GetContext(ctx, &ncats, "SELECT COUNT(*) FROM refcategory"); err != nil {
if err := row.Scan(&ncats); err != nil {
return err return err
} }
allCategories = make([]Category, 0, ncats) allCategories = make([]Category, 0, ncats)
@@ -226,10 +225,9 @@ func AmSearchCategories(ctx context.Context, oper int, term string, offset int,
queryString.WriteString(" AND hide_search = 0") queryString.WriteString(" AND hide_search = 0")
} }
q := queryString.String() q := queryString.String()
row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM refcategory WHERE "+q)
var total int var total int
if err = row.Scan(&total); err != nil { if err = amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM refcategory WHERE "+q); err != nil {
return nil, total, err return nil, -1, err
} }
if total == 0 { if total == 0 {
return make([]*Category, 0), 0, nil return make([]*Category, 0), 0, nil
+67 -70
View File
@@ -28,6 +28,9 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
) )
// ErrNoCommunity is an error returned for "no community found" errors
var ErrNoCommunity error = errors.New("no such community")
// Community struct contains the high level data for a community. // Community struct contains the high level data for a community.
type Community struct { type Community struct {
Mutex sync.RWMutex Mutex sync.RWMutex
@@ -221,18 +224,17 @@ func (c *Community) Membership(ctx context.Context, u *User) (bool, bool, uint16
// MemberCount returns the number of members in the community. // MemberCount returns the number of members in the community.
func (c *Community) MemberCount(ctx context.Context, hidden bool) (int, error) { func (c *Community) MemberCount(ctx context.Context, hidden bool) (int, error) {
var row *sql.Row
if hidden {
row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM commmember WHERE commid = ?", c.Id)
} else {
row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM commmember WHERE commid = ? AND hidden = 0", c.Id)
}
var rc int var rc int
if err := row.Scan(&rc); err == nil { var err error
return rc, nil if hidden {
err = amdb.GetContext(ctx, &rc, "SELECT COUNT(*) FROM commmember WHERE commid = ?", c.Id)
} else { } else {
return -1, err err = amdb.GetContext(ctx, &rc, "SELECT COUNT(*) FROM commmember WHERE commid = ? AND hidden = 0", c.Id)
} }
if err == nil {
return rc, nil
}
return -1, err
} }
/* ListMembers lists or searches for community members matching certain criteria. /* ListMembers lists or searches for community members matching certain criteria.
@@ -294,12 +296,11 @@ func (c *Community) ListMembers(ctx context.Context, field int, oper int, term s
query.WriteString(" AND m.hidden = 0") query.WriteString(" AND m.hidden = 0")
} }
q := query.String() q := query.String()
row := amdb.QueryRowContext(ctx, `SELECT COUNT(*) FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid
AND u.contactid = c.contactid`+q, c.Id)
var total int var total int
var err error var err error
var rs *sql.Rows var rs *sql.Rows
if err = row.Scan(&total); err == nil { if err = amdb.GetContext(ctx, &total, `SELECT COUNT(*) FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid
AND u.contactid = c.contactid`+q, c.Id); err == nil {
if offset > 0 { if offset > 0 {
rs, err = amdb.QueryContext(ctx, `SELECT m.uid FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid rs, err = amdb.QueryContext(ctx, `SELECT m.uid FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid
AND u.contactid = c.contactid`+q+" ORDER BY u.username LIMIT ? OFFSET ?", c.Id, max, offset) AND u.contactid = c.contactid`+q+" ORDER BY u.username LIMIT ? OFFSET ?", c.Id, max, offset)
@@ -519,8 +520,7 @@ func (c *Community) SetProfileData(ctx context.Context, name string, alias strin
c.CreateLevel = create_lvl c.CreateLevel = create_lvl
c.DeleteLevel = delete_lvl c.DeleteLevel = delete_lvl
c.JoinLevel = join_lvl c.JoinLevel = join_lvl
row := amdb.QueryRowContext(ctx, "SELECT lastupdate FROM communities WHERE commid = ?", c.Id) err2 := amdb.GetContext(ctx, &(c.LastUpdate), "SELECT lastupdate FROM communities WHERE commid = ?", c.Id)
err2 := row.Scan(&(c.LastUpdate))
if err2 != nil { if err2 != nil {
log.Errorf("SetProfileData scan error: %v", err2) log.Errorf("SetProfileData scan error: %v", err2)
} }
@@ -545,9 +545,8 @@ func (c *Community) Touch(ctx context.Context) error {
defer c.Mutex.Unlock() defer c.Mutex.Unlock()
_, err := amdb.ExecContext(ctx, "UPDATE communities SET lastaccess = NOW() WHERE commid = ?", c.Id) _, err := amdb.ExecContext(ctx, "UPDATE communities SET lastaccess = NOW() WHERE commid = ?", c.Id)
if err == nil { if err == nil {
row := amdb.QueryRowContext(ctx, "SELECT lastaccess FROM communities WHERE commid = ?", c.Id)
var na time.Time var na time.Time
if err = row.Scan(&na); err == nil { if err = amdb.GetContext(ctx, &na, "SELECT lastaccess FROM communities WHERE commid = ?", c.Id); err == nil {
c.LastAccess = &na c.LastAccess = &na
} }
} }
@@ -603,16 +602,19 @@ func (c *Community) GetMemberEMailAddrs(ctx context.Context) ([]string, error) {
func AmGetCommunity(ctx context.Context, id int32) (*Community, error) { func AmGetCommunity(ctx context.Context, id int32) (*Community, error) {
getCommunityMutex.Lock() getCommunityMutex.Lock()
defer getCommunityMutex.Unlock() defer getCommunityMutex.Unlock()
rc, ok := communityCache.Get(id) if rc, ok := communityCache.Get(id); ok {
if !ok {
var newcomm Community
if err := amdb.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id); err != nil {
return nil, err
}
rc = &newcomm
communityCache.Add(id, rc)
}
return rc.(*Community), nil return rc.(*Community), nil
}
var newcomm Community
err := amdb.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id)
switch err {
case nil:
communityCache.Add(id, &newcomm)
return &newcomm, nil
case sql.ErrNoRows:
return nil, ErrNoCommunity
}
return nil, err
} }
/* AmGetCommunityTx returns a reference to the specified community, in a transaction. /* AmGetCommunityTx returns a reference to the specified community, in a transaction.
@@ -627,16 +629,19 @@ func AmGetCommunity(ctx context.Context, id int32) (*Community, error) {
func AmGetCommunityTx(ctx context.Context, tx *sqlx.Tx, id int32) (*Community, error) { func AmGetCommunityTx(ctx context.Context, tx *sqlx.Tx, id int32) (*Community, error) {
getCommunityMutex.Lock() getCommunityMutex.Lock()
defer getCommunityMutex.Unlock() defer getCommunityMutex.Unlock()
rc, ok := communityCache.Get(id) if rc, ok := communityCache.Get(id); ok {
if !ok {
var newcomm Community
if err := tx.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id); err != nil {
return nil, err
}
rc = &newcomm
communityCache.Add(id, rc)
}
return rc.(*Community), nil return rc.(*Community), nil
}
var newcomm Community
err := tx.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id)
switch err {
case nil:
communityCache.Add(id, &newcomm)
return &newcomm, nil
case sql.ErrNoRows:
return nil, ErrNoCommunity
}
return nil, err
} }
/* AmGetCommunityByAlias returns a reference to the specified community. /* AmGetCommunityByAlias returns a reference to the specified community.
@@ -648,13 +653,11 @@ func AmGetCommunityTx(ctx context.Context, tx *sqlx.Tx, id int32) (*Community, e
* Standard Go error status (nil if community not found) * Standard Go error status (nil if community not found)
*/ */
func AmGetCommunityByAlias(ctx context.Context, alias string) (*Community, error) { func AmGetCommunityByAlias(ctx context.Context, alias string) (*Community, error) {
row := amdb.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias)
var cid int32 var cid int32
err := row.Scan(&cid) if err := amdb.GetContext(ctx, &cid, "SELECT commid FROM communities WHERE alias = ?", alias); err != nil {
if err == nil {
return AmGetCommunity(ctx, cid)
}
return nil, err return nil, err
}
return AmGetCommunity(ctx, cid)
} }
/* AmGetCommunityByAliasTx returns a reference to the specified community, within a transaction. /* AmGetCommunityByAliasTx returns a reference to the specified community, within a transaction.
@@ -667,13 +670,11 @@ func AmGetCommunityByAlias(ctx context.Context, alias string) (*Community, error
* Standard Go error status (nil if community not found) * Standard Go error status (nil if community not found)
*/ */
func AmGetCommunityByAliasTx(ctx context.Context, tx *sqlx.Tx, alias string) (*Community, error) { func AmGetCommunityByAliasTx(ctx context.Context, tx *sqlx.Tx, alias string) (*Community, error) {
row := tx.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias)
var cid int32 var cid int32
err := row.Scan(&cid) if err := tx.GetContext(ctx, &cid, "SELECT commid FROM communities WHERE alias = ?", alias); err != nil {
if err == nil {
return AmGetCommunityTx(ctx, tx, cid)
}
return nil, err return nil, err
}
return AmGetCommunityTx(ctx, tx, cid)
} }
/* AmGetCommunityFromParam returns a reference to the specified community based on the parameter. /* AmGetCommunityFromParam returns a reference to the specified community based on the parameter.
@@ -793,20 +794,22 @@ func AmAutoJoinCommunities(ctx context.Context, user *User) error {
// internalGetCommProp is a helper used by the community property functions. // internalGetCommProp is a helper used by the community property functions.
func internalGetCommProp(ctx context.Context, cid int32, ndx int32) (*CommunityProperties, error) { func internalGetCommProp(ctx context.Context, cid int32, ndx int32) (*CommunityProperties, error) {
var err error = nil
key := fmt.Sprintf("%d:%d", cid, ndx) key := fmt.Sprintf("%d:%d", cid, ndx)
getCommunityPropMutex.Lock() getCommunityPropMutex.Lock()
defer getCommunityPropMutex.Unlock() defer getCommunityPropMutex.Unlock()
rc, ok := communityPropCache.Get(key) if rc, ok := communityPropCache.Get(key); ok {
if !ok {
var prop CommunityProperties
if err = amdb.GetContext(ctx, &prop, "SELECT * from propcomm WHERE cid = ? AND ndx = ?", cid, ndx); err != nil {
return nil, err
}
rc = &prop
communityPropCache.Add(key, rc)
}
return rc.(*CommunityProperties), nil return rc.(*CommunityProperties), nil
}
var prop CommunityProperties
err := amdb.GetContext(ctx, &prop, "SELECT * from propcomm WHERE cid = ? AND ndx = ?", cid, ndx)
switch err {
case nil:
communityPropCache.Add(key, &prop)
return &prop, nil
case sql.ErrNoRows:
return nil, nil
}
return nil, err
} }
/* AmGetCommunityProperty retrieves the value of a community property. /* AmGetCommunityProperty retrieves the value of a community property.
@@ -882,10 +885,8 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i
defer rollback() defer rollback()
// validate alias does not already exist // validate alias does not already exist
row := tx.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias)
var tmpcid int32 var tmpcid int32
err := row.Scan(&tmpcid) if err := tx.GetContext(ctx, &tmpcid, "SELECT commid FROM communities WHERE alias = ?", alias); err != sql.ErrNoRows {
if err != sql.ErrNoRows {
if err == nil { if err == nil {
err = errors.New("a community with that alias already exists") err = errors.New("a community with that alias already exists")
} }
@@ -893,7 +894,7 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i
} }
// establish the community record // establish the community record
_, err = tx.ExecContext(ctx, `INSERT INTO communities (createdate, lastaccess, lastupdate, read_lvl, write_lvl, _, err := tx.ExecContext(ctx, `INSERT INTO communities (createdate, lastaccess, lastupdate, read_lvl, write_lvl,
create_lvl, delete_lvl, join_lvl, host_uid, hide_dir, hide_search, commname, language, create_lvl, delete_lvl, join_lvl, host_uid, hide_dir, hide_search, commname, language,
synopsis, rules, joinkey, alias) VALUES (NOW(), NOW(), NOW(), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, synopsis, rules, joinkey, alias) VALUES (NOW(), NOW(), NOW(), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
AmRoleList("Community.Read").Default().Level(), AmRoleList("Community.Write").Default().Level(), AmRoleList("Community.Read").Default().Level(), AmRoleList("Community.Write").Default().Level(),
@@ -950,14 +951,12 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i
*/ */
func AmGetCommunitiesForCategory(ctx context.Context, catid int32, offset int, max int, showAll bool) ([]*Community, int, error) { func AmGetCommunitiesForCategory(ctx context.Context, catid int32, offset int, max int, showAll bool) ([]*Community, int, error) {
var err error var err error
var row *sql.Row
if showAll {
row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM communities WHERE catid = ?", catid)
} else {
row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM communities WHERE catid = ? AND hide_dir = 0", catid)
}
var total int var total int
err = row.Scan(&total) if showAll {
err = amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM communities WHERE catid = ?", catid)
} else {
err = amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM communities WHERE catid = ? AND hide_dir = 0", catid)
}
if err != nil || total == 0 { if err != nil || total == 0 {
return make([]*Community, 0), 0, err // short-circuit return return make([]*Community, 0), 0, err // short-circuit return
} }
@@ -1043,9 +1042,8 @@ func AmSearchCommunities(ctx context.Context, field int, oper int, term string,
queryPortion.WriteString(" AND hide_search = 0") queryPortion.WriteString(" AND hide_search = 0")
} }
q := queryPortion.String() q := queryPortion.String()
row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM communities "+q)
var total int var total int
err := row.Scan(&total) err := amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM communities "+q)
if err != nil || total == 0 { if err != nil || total == 0 {
return make([]*Community, 0), 0, err // short-circuit return return make([]*Community, 0), 0, err // short-circuit return
} }
@@ -1123,10 +1121,9 @@ func AmSearchCommunityMembers(ctx context.Context, c *Community, field int, oper
return nil, -1, errors.New("invalid operator selector") return nil, -1, errors.New("invalid operator selector")
} }
q := queryPortion.String() q := queryPortion.String()
row := amdb.QueryRowContext(ctx, `SELECT COUNT(*) FROM users u, contacts c, commmember m WHERE u.contactid = c.contactid AND u.uid = m.uid
AND m.commid = ? AND u.is_anon = 0 AND `+q, c.Id)
var total int var total int
err := row.Scan(&total) err := amdb.GetContext(ctx, &total, `SELECT COUNT(*) FROM users u, contacts c, commmember m WHERE u.contactid = c.contactid AND u.uid = m.uid
AND m.commid = ? AND u.is_anon = 0 AND `+q, c.Id)
if err != nil { if err != nil {
return nil, -1, err return nil, -1, err
} }
+54 -74
View File
@@ -26,6 +26,9 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// ErrNoConference is an error thrown when a conference is not found.
var ErrNoConference error = errors.New("no such conference")
// Conference struct is the top-level structure for a conference. // Conference struct is the top-level structure for a conference.
type Conference struct { type Conference struct {
Mutex sync.Mutex Mutex sync.Mutex
@@ -172,17 +175,14 @@ func (c *Conference) AliasesQ(ctx context.Context) []string {
// AddAlias adds an alias to the conference. // AddAlias adds an alias to the conference.
func (c *Conference) AddAlias(ctx context.Context, alias string, u *User, comm *Community, ipaddr string) error { func (c *Conference) AddAlias(ctx context.Context, alias string, u *User, comm *Community, ipaddr string) error {
row := amdb.QueryRowContext(ctx, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias)
tmp := "" tmp := ""
err := row.Scan(&tmp) if err := amdb.GetContext(ctx, &tmp, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias); err != sql.ErrNoRows {
if err != sql.ErrNoRows {
if err == nil { if err == nil {
return fmt.Errorf("the alias '%s' is already in use by another conference", alias) return fmt.Errorf("the alias '%s' is already in use by another conference", alias)
} }
return err return err
} }
_, err = amdb.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", c.ConfId, alias) if _, err := amdb.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", c.ConfId, alias); err != nil {
if err != nil {
return err return err
} }
@@ -192,17 +192,14 @@ func (c *Conference) AddAlias(ctx context.Context, alias string, u *User, comm *
// RemoveAlias removes an alias from the conference. // RemoveAlias removes an alias from the conference.
func (c *Conference) RemoveAlias(ctx context.Context, alias string, u *User, comm *Community, ipaddr string) error { func (c *Conference) RemoveAlias(ctx context.Context, alias string, u *User, comm *Community, ipaddr string) error {
row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM confalias WHERE confid = ?", c.ConfId)
aliasCount := 0 aliasCount := 0
err := row.Scan(&aliasCount) if err := amdb.GetContext(ctx, &aliasCount, "SELECT COUNT(*) FROM confalias WHERE confid = ?", c.ConfId); err != nil {
if err != nil {
return err return err
} }
if aliasCount == 1 { if aliasCount == 1 {
row = amdb.QueryRowContext(ctx, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias)
tmp := "" tmp := ""
err = row.Scan(&tmp) err := amdb.GetContext(ctx, &tmp, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias)
if err == nil { if err == nil {
return errors.New("the conference must have at least one alias") return errors.New("the conference must have at least one alias")
} else if err != sql.ErrNoRows { } else if err != sql.ErrNoRows {
@@ -254,9 +251,8 @@ func (c *Conference) HostsQ(ctx context.Context) []*User {
// InCommunity returns true if the specified conference is in the community. // InCommunity returns true if the specified conference is in the community.
func (c *Conference) InCommunity(ctx context.Context, comm *Community) (bool, error) { func (c *Conference) InCommunity(ctx context.Context, comm *Community) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT commid FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId)
var tmp int32 var tmp int32
err := row.Scan(&tmp) err := amdb.GetContext(ctx, &tmp, "SELECT commid FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId)
switch err { switch err {
case nil: case nil:
return true, nil return true, nil
@@ -268,9 +264,8 @@ func (c *Conference) InCommunity(ctx context.Context, comm *Community) (bool, er
// HiddenInList returns whether or not this conference is hidden in the community's conference list. // HiddenInList returns whether or not this conference is hidden in the community's conference list.
func (c *Conference) HiddenInList(ctx context.Context, comm *Community) (bool, error) { func (c *Conference) HiddenInList(ctx context.Context, comm *Community) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT hide_list FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId)
var rc bool var rc bool
err := row.Scan(&rc) err := amdb.GetContext(ctx, &rc, "SELECT hide_list FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId)
switch err { switch err {
case nil: case nil:
return rc, nil return rc, nil
@@ -317,9 +312,8 @@ func (c *Conference) Members(ctx context.Context) ([]ConferenceMember, error) {
// Membership returns a membership flag and granted level for the user in this conference. // Membership returns a membership flag and granted level for the user in this conference.
func (c *Conference) Membership(ctx context.Context, u *User) (bool, uint16, error) { func (c *Conference) Membership(ctx context.Context, u *User) (bool, uint16, error) {
row := amdb.QueryRowContext(ctx, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid)
var level uint16 var level uint16
err := row.Scan(&level) err := amdb.GetContext(ctx, &level, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid)
switch err { switch err {
case nil: case nil:
return true, level, nil return true, level, nil
@@ -335,9 +329,8 @@ func (c *Conference) SetMembership(ctx context.Context, u *User, level uint16, b
_, err := amdb.ExecContext(ctx, "DELETE FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) _, err := amdb.ExecContext(ctx, "DELETE FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid)
return err return err
} }
row := amdb.QueryRowContext(ctx, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid)
var oldLevel uint16 var oldLevel uint16
err := row.Scan(&oldLevel) err := amdb.GetContext(ctx, &oldLevel, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid)
switch err { switch err {
case nil: case nil:
if oldLevel == level { if oldLevel == level {
@@ -464,8 +457,7 @@ func (c *Conference) SetInfo(ctx context.Context, name, descr string, read_lvl,
create_lvl, hide_lvl, nuke_lvl, change_lvl, delete_lvl, c.ConfId) create_lvl, hide_lvl, nuke_lvl, change_lvl, delete_lvl, c.ConfId)
if err == nil { if err == nil {
var tmp Conference var tmp Conference
err := amdb.GetContext(ctx, &tmp, "SELECT * FROM confs WHERE confid = ?", c.ConfId) if err = amdb.GetContext(ctx, &tmp, "SELECT * FROM confs WHERE confid = ?", c.ConfId); err == nil {
if err == nil {
if c.Name != tmp.Name { if c.Name != tmp.Name {
AmStoreAudit(AmNewCommAudit(AuditConferenceName, u.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", c.ConfId), fmt.Sprintf("name='%s'", tmp.Name))) AmStoreAudit(AmNewCommAudit(AuditConferenceName, u.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", c.ConfId), fmt.Sprintf("name='%s'", tmp.Name)))
} }
@@ -576,11 +568,10 @@ func (c *Conference) TouchPost(ctx context.Context, tx *sqlx.Tx, u *User, lastPo
// UnreadMessages returns the total number of unread messages in a conference for a user. // UnreadMessages returns the total number of unread messages in a conference for a user.
func (c *Conference) UnreadMessages(ctx context.Context, u *User) (int32, error) { func (c *Conference) UnreadMessages(ctx context.Context, u *User) (int32, error) {
row := amdb.QueryRowContext(ctx, `SELECT SUM(t.top_message - IFNULL(s.last_message,-1)) var rc int32
err := amdb.GetContext(ctx, &rc, `SELECT SUM(t.top_message - IFNULL(s.last_message,-1))
FROM topics t LEFT JOIN topicsettings s ON t.topicid = s.topicid AND s.uid = ? FROM topics t LEFT JOIN topicsettings s ON t.topicid = s.topicid AND s.uid = ?
WHERE t.confid = ? AND t.archived = 0 AND (s.hidden IS NULL OR s.hidden = 0)`, u.Uid, c.ConfId) WHERE t.confid = ? AND t.archived = 0 AND (s.hidden IS NULL OR s.hidden = 0)`, u.Uid, c.ConfId)
var rc int32
err := row.Scan(&rc)
return rc, err return rc, err
} }
@@ -600,10 +591,8 @@ func (c *Conference) Fixseen(ctx context.Context, u *User) error {
defer rollback() defer rollback()
// Get a count of topics beforehand. // Get a count of topics beforehand.
row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM topics WHERE confid = ?", c.ConfId)
count := 0 count := 0
err := row.Scan(&count) if err := tx.GetContext(ctx, &count, "SELECT COUNT(*) FROM topics WHERE confid = ?", c.ConfId); err != nil {
if err != nil {
return err return err
} }
@@ -661,9 +650,8 @@ func (c *Conference) GetCustomBlocks(ctx context.Context) (string, string, error
func (c *Conference) SetCustomBlocks(ctx context.Context, topBlock, bottomBlock string) error { func (c *Conference) SetCustomBlocks(ctx context.Context, topBlock, bottomBlock string) error {
tx, commit, rollback := transaction(ctx) tx, commit, rollback := transaction(ctx)
defer rollback() defer rollback()
row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM confcustom WHERE confid = ?", c.ConfId)
ct := 0 ct := 0
err := row.Scan(&ct) err := tx.GetContext(ctx, &ct, "SELECT COUNT(*) FROM confcustom WHERE confid = ?", c.ConfId)
if err != nil { if err != nil {
return err return err
} }
@@ -872,18 +860,17 @@ func (c *Conference) Delete(ctx context.Context, comm *Community, u *User, ipadd
defer getConferenceMutex.Unlock() defer getConferenceMutex.Unlock()
// any references to conference other than this community? // any references to conference other than this community?
row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", c.ConfId, comm.Id)
refCount := 0 refCount := 0
err := row.Scan(&refCount) if err := tx.GetContext(ctx, &refCount, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", c.ConfId, comm.Id); err != nil {
if err != nil {
return err return err
} }
// break the link with the community // break the link with the community
if _, err = tx.ExecContext(ctx, "DELETE FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId); err != nil { if _, err := tx.ExecContext(ctx, "DELETE FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId); err != nil {
return err return err
} }
var err error
if refCount == 0 { if refCount == 0 {
// We have to delete all the conference core data now. // We have to delete all the conference core data now.
_, err = tx.ExecContext(ctx, "DELETE FROM confs WHERE confid = ?", c.ConfId) _, err = tx.ExecContext(ctx, "DELETE FROM confs WHERE confid = ?", c.ConfId)
@@ -943,9 +930,8 @@ func (*conferenceServiceVTable) OnDeleteCommunity(ctx context.Context, tx *sqlx.
} }
for i, confid := range confids { for i, confid := range confids {
// any references to conference other than this community? // any references to conference other than this community?
row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", confid, commid)
refCount := 0 refCount := 0
err = row.Scan(&refCount) err := tx.GetContext(ctx, &refCount, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", confid, commid)
if err != nil { if err != nil {
return err return err
} }
@@ -1011,19 +997,21 @@ func (*conferenceServiceVTable) OnUserLeaveCommunity(context.Context, *sqlx.Tx,
* Standard Go error status. * Standard Go error status.
*/ */
func AmGetConference(ctx context.Context, id int32) (*Conference, error) { func AmGetConference(ctx context.Context, id int32) (*Conference, error) {
var err error = nil
getConferenceMutex.Lock() getConferenceMutex.Lock()
defer getConferenceMutex.Unlock() defer getConferenceMutex.Unlock()
rc, ok := conferenceCache.Get(id) if rc, ok := conferenceCache.Get(id); ok {
if !ok { return rc.(*Conference), nil
}
var conf Conference var conf Conference
if err = amdb.GetContext(ctx, &conf, "SELECT * from confs where confid = ?"); err != nil { err := amdb.GetContext(ctx, &conf, "SELECT * from confs where confid = ?")
switch err {
case nil:
conferenceCache.Add(id, &conf)
return &conf, nil
case sql.ErrNoRows:
return nil, ErrNoConference
}
return nil, err return nil, err
}
rc = &conf
conferenceCache.Add(id, rc)
}
return rc.(*Conference), err
} }
/* AmGetConferenceByAlias returns a conference given its alias. /* AmGetConferenceByAlias returns a conference given its alias.
@@ -1040,8 +1028,7 @@ func AmGetConferenceByAlias(ctx context.Context, alias string) (*Conference, err
if ok { if ok {
confid = xconf.(int32) confid = xconf.(int32)
} else { } else {
row := amdb.QueryRowContext(ctx, "SELECT confid FROM confalias WHERE alias = ?", alias) err := amdb.GetContext(ctx, &confid, "SELECT confid FROM confalias WHERE alias = ?", alias)
err := row.Scan(&confid)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, fmt.Errorf("alias not found: %s", alias) return nil, fmt.Errorf("alias not found: %s", alias)
} else if err != nil { } else if err != nil {
@@ -1061,9 +1048,8 @@ func AmGetConferenceByAlias(ctx context.Context, alias string) (*Conference, err
* Standard Go error status. * Standard Go error status.
*/ */
func AmGetConferenceContainingPost(ctx context.Context, postId int64) (*Conference, error) { func AmGetConferenceContainingPost(ctx context.Context, postId int64) (*Conference, error) {
row := amdb.QueryRowContext(ctx, "SELECT t.confid FROM topics t, posts p WHERE p.postid = ? AND p.topicid = t.topicid", postId)
var confId int32 var confId int32
err := row.Scan(&confId) err := amdb.GetContext(ctx, &confId, "SELECT t.confid FROM topics t, posts p WHERE p.postid = ? AND p.topicid = t.topicid", postId)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, fmt.Errorf("post not found: %d", postId) return nil, fmt.Errorf("post not found: %d", postId)
} else if err != nil { } else if err != nil {
@@ -1082,10 +1068,9 @@ func AmGetConferenceContainingPost(ctx context.Context, postId int64) (*Conferen
* Standard Go error status. * Standard Go error status.
*/ */
func AmGetConferenceByAliasInCommunity(ctx context.Context, cid int32, alias string) (*Conference, error) { func AmGetConferenceByAliasInCommunity(ctx context.Context, cid int32, alias string) (*Conference, error) {
row := amdb.QueryRowContext(ctx, `SELECT c.confid FROM commtoconf c, confalias a WHERE c.confid = a.confid
AND c.commid = ? AND a.alias = ?`, cid, alias)
var confid int32 var confid int32
err := row.Scan(&confid) err := amdb.GetContext(ctx, &confid, `SELECT c.confid FROM commtoconf c, confalias a WHERE c.confid = a.confid
AND c.commid = ? AND a.alias = ?`, cid, alias)
switch err { switch err {
case nil: case nil:
return AmGetConference(ctx, confid) return AmGetConference(ctx, confid)
@@ -1124,8 +1109,7 @@ func AmListConferences(ctx context.Context, cid int32, showHidden bool) ([]*Conf
} }
} }
for i := range rc { for i := range rc {
row := amdb.QueryRowContext(ctx, "SELECT alias FROM confalias WHERE confid = ?", rc[i].ConfId) err := amdb.GetContext(ctx, &(rc[i].Alias), "SELECT alias FROM confalias WHERE confid = ?", rc[i].ConfId)
err = row.Scan(&(rc[i].Alias))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1140,20 +1124,22 @@ func AmListConferences(ctx context.Context, cid int32, showHidden bool) ([]*Conf
// internalGetConfProp is a helper used by the conference property functions. // internalGetConfProp is a helper used by the conference property functions.
func internalGetConfProp(ctx context.Context, confid int32, ndx int32) (*ConferenceProperties, error) { func internalGetConfProp(ctx context.Context, confid int32, ndx int32) (*ConferenceProperties, error) {
var err error = nil
key := fmt.Sprintf("%d:%d", confid, ndx) key := fmt.Sprintf("%d:%d", confid, ndx)
getConferencePropMutex.Lock() getConferencePropMutex.Lock()
defer getConferencePropMutex.Unlock() defer getConferencePropMutex.Unlock()
rc, ok := conferencePropCache.Get(key) if rc, ok := conferencePropCache.Get(key); ok {
if !ok {
var prop ConferenceProperties
if err = amdb.GetContext(ctx, &prop, "SELECT * from propconf WHERE confid = ? AND ndx = ?", confid, ndx); err != nil {
return nil, err
}
rc = &prop
conferencePropCache.Add(key, rc)
}
return rc.(*ConferenceProperties), nil return rc.(*ConferenceProperties), nil
}
var prop ConferenceProperties
err := amdb.GetContext(ctx, &prop, "SELECT * from propconf WHERE confid = ? AND ndx = ?", confid, ndx)
switch err {
case nil:
conferencePropCache.Add(key, &prop)
return &prop, nil
case sql.ErrNoRows:
return nil, nil
}
return nil, err
} }
/* AmGetConferenceProperty retrieves the value of a conference property. /* AmGetConferenceProperty retrieves the value of a conference property.
@@ -1268,9 +1254,8 @@ func AmCreateConference(ctx context.Context, comm *Community, name, alias, descr
defer getConferenceMutex.Unlock() defer getConferenceMutex.Unlock()
// Ensure the alias is not in use. // Ensure the alias is not in use.
row := tx.QueryRowContext(ctx, "SELECT confid FROM confalias WHERE alias = ?", alias)
var tmp int32 var tmp int32
err := row.Scan(&tmp) err := tx.GetContext(ctx, &tmp, "SELECT confid FROM confalias WHERE alias = ?", alias)
if err == nil { if err == nil {
return nil, fmt.Errorf("the alias '%s' is already in use by a different conference", alias) return nil, fmt.Errorf("the alias '%s' is already in use by a different conference", alias)
} else if err != sql.ErrNoRows { } else if err != sql.ErrNoRows {
@@ -1294,29 +1279,24 @@ func AmCreateConference(ctx context.Context, comm *Community, name, alias, descr
} }
// Attach the alias to the conference. // Attach the alias to the conference.
_, err = tx.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", rc.ConfId, alias) if _, err = tx.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", rc.ConfId, alias); err != nil {
if err != nil {
return nil, err return nil, err
} }
// Get the current "last" sequence number. // Get the current "last" sequence number.
row = tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM commtoconf WHERE commid = ?", comm.Id)
var seq int var seq int
err = row.Scan(&seq) if err = tx.GetContext(ctx, &seq, "SELECT MAX(sequence) FROM commtoconf WHERE commid = ?", comm.Id); err != nil {
if err != nil {
return nil, err return nil, err
} }
// Link the conference into the community, and set the hide flag. // Link the conference into the community, and set the hide flag.
_, err = tx.ExecContext(ctx, "INSERT INTO commtoconf (commid, confid, sequence, hide_list) VALUES (?, ?, ?, ?)", comm.Id, rc.ConfId, if _, err = tx.ExecContext(ctx, "INSERT INTO commtoconf (commid, confid, sequence, hide_list) VALUES (?, ?, ?, ?)", comm.Id, rc.ConfId,
int16(seq+COMMTOCONF_SEQ_SPACING), hide_list) int16(seq+COMMTOCONF_SEQ_SPACING), hide_list); err != nil {
if err != nil {
return nil, err return nil, err
} }
// Make the specified user the first host of the conference. // Make the specified user the first host of the conference.
_, err = tx.ExecContext(ctx, "INSERT INTO confmember (confid, uid, granted_lvl) VALUES (?, ?, ?)", rc.ConfId, u.Uid, AmDefaultRole("Conference.NewHost").Level()) if _, err = tx.ExecContext(ctx, "INSERT INTO confmember (confid, uid, granted_lvl) VALUES (?, ?, ?)", rc.ConfId, u.Uid, AmDefaultRole("Conference.NewHost").Level()); err != nil {
if err != nil {
return nil, err return nil, err
} }
+9 -16
View File
@@ -56,8 +56,7 @@ type ContactInfo struct {
// lookupCommunityContact looks up the ID of a contact for a community. // lookupCommunityContact looks up the ID of a contact for a community.
func lookupCommunityContact(ctx context.Context, id int32) (int32, error) { func lookupCommunityContact(ctx context.Context, id int32) (int32, error) {
var rc int32 = -1 var rc int32 = -1
row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE owner_commid = ?", id) err := amdb.GetContext(ctx, &rc, "SELECT contactid FROM contacts WHERE owner_commid = ?", id)
err := row.Scan(&rc)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return -1, nil return -1, nil
} }
@@ -67,8 +66,7 @@ func lookupCommunityContact(ctx context.Context, id int32) (int32, error) {
// lookupUserContact looks up the ID of a contact for a user. // lookupUserContact looks up the ID of a contact for a user.
func lookupUserContact(ctx context.Context, uid int32) (int32, error) { func lookupUserContact(ctx context.Context, uid int32) (int32, error) {
var rc int32 = -1 var rc int32 = -1
row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid) err := amdb.GetContext(ctx, &rc, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid)
err := row.Scan(&rc)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return -1, nil return -1, nil
} }
@@ -150,9 +148,8 @@ func (ci *ContactInfo) Save(ctx context.Context, changer *User, ipaddr string) (
} }
if !emailChange { if !emailChange {
// we don't THINK the E-mail address is changing, but we could be wrong... // we don't THINK the E-mail address is changing, but we could be wrong...
row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE contactid = ? AND email = ?", ci.ContactId, ci.Email)
var tmpcid int32 var tmpcid int32
err := row.Scan(&tmpcid) err := amdb.GetContext(ctx, &tmpcid, "SELECT contactid FROM contacts WHERE contactid = ? AND email = ?", ci.ContactId, ci.Email)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
emailChange = true emailChange = true
} else if err != nil { } else if err != nil {
@@ -185,9 +182,7 @@ func (ci *ContactInfo) Save(ctx context.Context, changer *User, ipaddr string) (
contactCache.Add(ci.ContactId, ci) contactCache.Add(ci.ContactId, ci)
} }
// Refresh the last update date. // Refresh the last update date.
row := amdb.QueryRowContext(ctx, "SELECT lastupdate FROM contacts WHERE contactid = ?", ci.ContactId) if err := amdb.GetContext(ctx, &(ci.LastUpdate), "SELECT lastupdate FROM contacts WHERE contactid = ?", ci.ContactId); err != nil {
err := row.Scan(&(ci.LastUpdate))
if err != nil {
return false, err return false, err
} }
if ci.OwnerCommId < 0 { if ci.OwnerCommId < 0 {
@@ -197,7 +192,7 @@ func (ci *ContactInfo) Save(ctx context.Context, changer *User, ipaddr string) (
} else { } else {
AmStoreAudit(AmNewCommAudit(AuditCommunityContactInfo, changer.Uid, ci.OwnerCommId, ipaddr, fmt.Sprintf("contactid=%d", ci.ContactId))) AmStoreAudit(AmNewCommAudit(AuditCommunityContactInfo, changer.Uid, ci.OwnerCommId, ipaddr, fmt.Sprintf("contactid=%d", ci.ContactId)))
} }
return emailChange, err return emailChange, nil
} }
// Clone makes a copy of the ContactInfo. // Clone makes a copy of the ContactInfo.
@@ -251,11 +246,10 @@ func setupContactsCache() {
// internalContactInfo retrieves the contact info from the database. // internalContactInfo retrieves the contact info from the database.
func internalContactInfo(ctx context.Context, id int32) (*ContactInfo, error) { func internalContactInfo(ctx context.Context, id int32) (*ContactInfo, error) {
var cinf ContactInfo var cinf ContactInfo
err := amdb.GetContext(ctx, &cinf, "SELECT * from contacts WHERE contactid = ?", id) if err := amdb.GetContext(ctx, &cinf, "SELECT * from contacts WHERE contactid = ?", id); err != nil {
if err == nil {
return &cinf, nil
}
return nil, err return nil, err
}
return &cinf, nil
} }
/* AmGetContactInfo retrieves the contact info for a given identifier. /* AmGetContactInfo retrieves the contact info for a given identifier.
@@ -292,9 +286,8 @@ func AmGetContactInfo(ctx context.Context, id int32) (*ContactInfo, error) {
* Standard Go error status. * Standard Go error status.
*/ */
func AmGetContactInfoForUser(ctx context.Context, uid int32) (*ContactInfo, error) { func AmGetContactInfoForUser(ctx context.Context, uid int32) (*ContactInfo, error) {
row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid)
var cid int32 var cid int32
err := row.Scan(&cid) err := amdb.GetContext(ctx, &cid, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid)
switch err { switch err {
case nil: case nil:
return AmGetContactInfo(ctx, cid) return AmGetContactInfo(ctx, cid)
+1 -2
View File
@@ -23,9 +23,8 @@ import (
* Standard Go error status. * Standard Go error status.
*/ */
func AmIsEmailAddressBanned(ctx context.Context, address string) (bool, error) { func AmIsEmailAddressBanned(ctx context.Context, address string) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT by_uid FROM emailban WHERE address = ?", address)
var uid int32 var uid int32
err := row.Scan(&uid) err := amdb.GetContext(ctx, &uid, "SELECT by_uid FROM emailban WHERE address = ?", address)
switch err { switch err {
case nil: case nil:
return true, nil return true, nil
+3 -6
View File
@@ -108,8 +108,7 @@ func AmGlobals(ctx context.Context) (*Globals, error) {
defer globalsMutex.Unlock() defer globalsMutex.Unlock()
if theGlobals == nil { if theGlobals == nil {
var g Globals var g Globals
err := amdb.GetContext(ctx, &g, "SELECT * FROM globals") if err := amdb.GetContext(ctx, &g, "SELECT * FROM globals"); err != nil {
if err != nil {
return nil, err return nil, err
} }
theGlobals = &g theGlobals = &g
@@ -146,8 +145,7 @@ func AmGetGlobalProperty(ctx context.Context, index int32) (string, error) {
var err error = nil var err error = nil
rc, ok := globalProps[index] rc, ok := globalProps[index]
if !ok { if !ok {
row := amdb.QueryRowContext(ctx, "SELECT data FROM propglobal WHERE ndx = ?", index) err := amdb.GetContext(ctx, &rc, "SELECT data FROM propglobal WHERE ndx = ?", index)
err = row.Scan(&rc)
switch err { switch err {
case nil: case nil:
globalProps[index] = rc globalProps[index] = rc
@@ -172,9 +170,8 @@ func AmSetGlobalProperty(ctx context.Context, index int32, value string) error {
defer globalPropMutex.Unlock() defer globalPropMutex.Unlock()
_, updateMode := globalProps[index] _, updateMode := globalProps[index]
if !updateMode { if !updateMode {
row := amdb.QueryRowContext(ctx, "SELECT data FROM propglobal WHERE ndx = ?", index)
var tmpdata string var tmpdata string
err := row.Scan(&tmpdata) err := amdb.GetContext(ctx, &tmpdata, "SELECT data FROM propglobal WHERE ndx = ?", index)
switch err { switch err {
case nil: case nil:
updateMode = true updateMode = true
+3 -6
View File
@@ -108,15 +108,13 @@ func AmAppendToHotlist(ctx context.Context, u *User, commid, confid int32) error
defer rollback() defer rollback()
var newseq int16 var newseq int16
row := tx.QueryRowContext(ctx, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid) err := tx.GetContext(ctx, &newseq, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid)
err := row.Scan(&newseq)
if err == nil { if err == nil {
return errors.New("community/conference already exist in hotlist") return errors.New("community/conference already exist in hotlist")
} else if err != sql.ErrNoRows { } else if err != sql.ErrNoRows {
return err return err
} }
row = tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM confhotlist WHERE uid = ?", u.Uid) err = tx.GetContext(ctx, &newseq, "SELECT MAX(sequence) FROM confhotlist WHERE uid = ?", u.Uid)
err = row.Scan(&newseq)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
newseq = 0 newseq = 0
} else if err != nil { } else if err != nil {
@@ -135,9 +133,8 @@ func AmAppendToHotlist(ctx context.Context, u *User, commid, confid int32) error
// AmIsInHotlist returns true if the community/conference pair is in the hotlist. // AmIsInHotlist returns true if the community/conference pair is in the hotlist.
func AmIsInHotlist(ctx context.Context, u *User, commid, confid int32) (bool, error) { func AmIsInHotlist(ctx context.Context, u *User, commid, confid int32) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid)
var tmp int16 var tmp int16
err := row.Scan(&tmp) err := amdb.GetContext(ctx, &tmp, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid)
switch err { switch err {
case nil: case nil:
return true, nil return true, nil
+2 -4
View File
@@ -61,8 +61,7 @@ func (img *ImageStore) Save(ctx context.Context) error {
*/ */
func AmLoadImage(ctx context.Context, id int32) (*ImageStore, error) { func AmLoadImage(ctx context.Context, id int32) (*ImageStore, error) {
var imgdata ImageStore var imgdata ImageStore
err := amdb.GetContext(ctx, &imgdata, "SELECT * FROM imagestore WHERE imgid = ?", id) if err := amdb.GetContext(ctx, &imgdata, "SELECT * FROM imagestore WHERE imgid = ?", id); err != nil {
if err != nil {
return nil, err return nil, err
} }
return &imgdata, nil return &imgdata, nil
@@ -81,9 +80,8 @@ func AmLoadImage(ctx context.Context, id int32) (*ImageStore, error) {
*/ */
func AmStoreImage(ctx context.Context, typecode int16, owner int32, mimetype string, data []byte) (*ImageStore, error) { func AmStoreImage(ctx context.Context, typecode int16, owner int32, mimetype string, data []byte) (*ImageStore, error) {
var img *ImageStore var img *ImageStore
row := amdb.QueryRowContext(ctx, "SELECT imgid FROM imagestore WHERE typecode = ? AND ownerid = ?", typecode, owner)
var id int32 var id int32
err := row.Scan(&id) err := amdb.GetContext(ctx, &id, "SELECT imgid FROM imagestore WHERE typecode = ? AND ownerid = ?", typecode, owner)
switch err { switch err {
case nil: case nil:
img, err = AmLoadImage(ctx, id) img, err = AmLoadImage(ctx, id)
+1 -2
View File
@@ -273,8 +273,7 @@ func AmListIPBans(ctx context.Context) ([]IPBanEntry, error) {
// AmGetIPBan returns a single IP address ban structure. // AmGetIPBan returns a single IP address ban structure.
func AmGetIPBan(ctx context.Context, id int32) (*IPBanEntry, error) { func AmGetIPBan(ctx context.Context, id int32) (*IPBanEntry, error) {
var ban IPBanEntry var ban IPBanEntry
err := amdb.GetContext(ctx, &ban, "SELECT * FROM ipban WHERE id = ?", id) if err := amdb.GetContext(ctx, &ban, "SELECT * FROM ipban WHERE id = ?", id); err != nil {
if err != nil {
return nil, err return nil, err
} }
return &ban, nil return &ban, nil
+20 -28
View File
@@ -71,9 +71,8 @@ func (p *PostHeader) IsScribbled() bool {
// IsPublished returns true if the post has been published to the front page. // IsPublished returns true if the post has been published to the front page.
func (p *PostHeader) IsPublished(ctx context.Context) (bool, error) { func (p *PostHeader) IsPublished(ctx context.Context) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM postpublish WHERE postid = ?", p.PostId)
ct := 0 ct := 0
err := row.Scan(&ct) err := amdb.GetContext(ctx, &ct, "SELECT COUNT(*) FROM postpublish WHERE postid = ?", p.PostId)
return ct > 0, err return ct > 0, err
} }
@@ -237,16 +236,17 @@ func (p *PostHeader) PruneAttachment(ctx context.Context, u *User, comm *Communi
// Text returns the text associated with a post. // Text returns the text associated with a post.
func (p *PostHeader) Text(ctx context.Context) (string, error) { func (p *PostHeader) Text(ctx context.Context) (string, error) {
var pd PostData var pd PostData
if err := amdb.GetContext(ctx, &pd, "SELECT * FROM postdata WHERE postid = ?", p.PostId); err != nil { err := amdb.GetContext(ctx, &pd, "SELECT * FROM postdata WHERE postid = ?", p.PostId)
if err == sql.ErrNoRows { switch err {
return "", ErrNoPostData case nil:
}
return "", err
}
if pd.Data == nil { if pd.Data == nil {
return "", ErrNoPostData return "", ErrNoPostData
} }
return *pd.Data, nil return *pd.Data, nil
case sql.ErrNoRows:
return "", ErrNoPostData
}
return "", err
} }
// Link returns a link string to this post. // Link returns a link string to this post.
@@ -304,9 +304,8 @@ func (p *PostHeader) Scribble(ctx context.Context, u *User, comm *Community, ipa
} }
// Reread the scribble date. // Reread the scribble date.
row := tx.QueryRowContext(ctx, "SELECT scribble_date FROM posts WHERE postid = ?", p.PostId)
var newScribbleDate time.Time var newScribbleDate time.Time
if err = row.Scan(&newScribbleDate); err != nil { if err = tx.GetContext(ctx, &newScribbleDate, "SELECT scribble_date FROM posts WHERE postid = ?", p.PostId); err != nil {
return err return err
} }
@@ -367,10 +366,9 @@ func (p *PostHeader) Nuke(ctx context.Context, u *User, comm *Community, ipaddr
if _, err = tx.ExecContext(ctx, "UPDATE posts SET num = (num - 1) WHERE topicid = ? AND num > ?", p.TopicId, p.Num); err != nil { if _, err = tx.ExecContext(ctx, "UPDATE posts SET num = (num - 1) WHERE topicid = ? AND num > ?", p.TopicId, p.Num); err != nil {
return err return err
} }
row := tx.QueryRowContext(ctx, "SELECT top_message FROM topics WHERE topicid = ?", p.TopicId)
// Renumber phase 2 - reset the top message in this topic // Renumber phase 2 - reset the top message in this topic
var topMessage int32 var topMessage int32
if err = row.Scan(&topMessage); err != nil { if err = tx.GetContext(ctx, &topMessage, "SELECT top_message FROM topics WHERE topicid = ?", p.TopicId); err != nil {
return err return err
} }
topMessage-- topMessage--
@@ -400,9 +398,8 @@ func (p *PostHeader) Publish(ctx context.Context, comm *Community, publisher *Us
defer rollback() defer rollback()
// Check if we were already published. // Check if we were already published.
row := tx.QueryRowContext(ctx, "SELECT by_uid FROM postpublish WHERE postid = ?", p.PostId)
var tmp int32 var tmp int32
err := row.Scan(&tmp) err := tx.GetContext(ctx, &tmp, "SELECT by_uid FROM postpublish WHERE postid = ?", p.PostId)
if err == nil { if err == nil {
return errors.New("post already published") return errors.New("post already published")
} else if err != sql.ErrNoRows { } else if err != sql.ErrNoRows {
@@ -459,10 +456,8 @@ func (p *PostHeader) MoveTo(ctx context.Context, target *Topic, u *User, comm *C
return err return err
} }
// Read back the last update. // Read back the last update.
row := tx.QueryRowContext(ctx, "SELECT lastupdate FROM topics WHERE topicid = ?", target.TopicId)
var lastUpdate time.Time var lastUpdate time.Time
err = row.Scan(&lastUpdate) if err = tx.GetContext(ctx, &lastUpdate, "SELECT lastupdate FROM topics WHERE topicid = ?", target.TopicId); err != nil {
if err != nil {
return err return err
} }
@@ -574,11 +569,10 @@ func AmNewPost(ctx context.Context, conf *Conference, topic *Topic, user *User,
} }
// Read back the post header. // Read back the post header.
var pd PostHeader var hdr PostHeader
if err := tx.GetContext(ctx, &pd, "SELECT * FROM posts WHERE postid = ?", xid); err != nil { if err := tx.GetContext(ctx, &hdr, "SELECT * FROM posts WHERE postid = ?", xid); err != nil {
return nil, err return nil, err
} }
hdr := &pd
// Add the post data. // Add the post data.
_, err = tx.ExecContext(ctx, "INSERT INTO postdata (postid, data) VALUES (?, ?)", hdr.PostId, post) _, err = tx.ExecContext(ctx, "INSERT INTO postdata (postid, data) VALUES (?, ?)", hdr.PostId, post)
@@ -611,7 +605,7 @@ func AmNewPost(ctx context.Context, conf *Conference, topic *Topic, user *User,
AmStoreAudit(AmNewCommAudit(AuditConferencePostMessage, user.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId), AmStoreAudit(AmNewCommAudit(AuditConferencePostMessage, user.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId),
fmt.Sprintf("topic=%d", topic.Number), fmt.Sprintf("post=%d", hdr.PostId), fmt.Sprintf("pseud=%s", *hdr.Pseud))) fmt.Sprintf("topic=%d", topic.Number), fmt.Sprintf("post=%d", hdr.PostId), fmt.Sprintf("pseud=%s", *hdr.Pseud)))
return hdr, nil return &hdr, nil
} }
/* AmGetPublishedPosts gets all posts published to the front page, up to the maximum number configured in the database. /* AmGetPublishedPosts gets all posts published to the front page, up to the maximum number configured in the database.
@@ -785,10 +779,10 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max
} }
// Get the count of matching posts. // Get the count of matching posts.
var row *sql.Row var count int
switch scope { switch scope {
case "global": case "global":
row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) err = amdb.GetContext(ctx, &count, `SELECT COUNT(*)
FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid
JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid
JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid
@@ -796,7 +790,7 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max
WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl
AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, searchTerms) AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, searchTerms)
case "community": case "community":
row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) err = amdb.GetContext(ctx, &count, `SELECT COUNT(*)
FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid
JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid
JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid
@@ -804,7 +798,7 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max
WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl
AND q.commid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, searchTerms) AND q.commid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, searchTerms)
case "conference": case "conference":
row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) err = amdb.GetContext(ctx, &count, `SELECT COUNT(*)
FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid
JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid
JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid
@@ -812,7 +806,7 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max
WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl
AND q.commid = ? AND c.confid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, conf.ConfId, searchTerms) AND q.commid = ? AND c.confid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, conf.ConfId, searchTerms)
case "topic": case "topic":
row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) err = amdb.GetContext(ctx, &count, `SELECT COUNT(*)
FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid
JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid
JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid
@@ -821,8 +815,6 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max
AND q.commid = ? AND c.confid = ? AND t.topicid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, AND q.commid = ? AND c.confid = ? AND t.topicid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`,
u.Uid, confService, comm.Id, conf.ConfId, topic.TopicId, searchTerms) u.Uid, confService, comm.Id, conf.ConfId, topic.TopicId, searchTerms)
} }
var count int
err = row.Scan(&count)
if err != nil { if err != nil {
log.Errorf("AmSearchPosts query 1 error %v", err) log.Errorf("AmSearchPosts query 1 error %v", err)
return nil, -1, err return nil, -1, err
+2 -4
View File
@@ -91,9 +91,8 @@ func AmRemoveSidebox(ctx context.Context, uid int32, boxid int32) error {
defer rollback() defer rollback()
// Get the old sequence number. // Get the old sequence number.
row := tx.QueryRowContext(ctx, "SELECT sequence FROM sideboxes WHERE uid = ? AND boxid = ?", uid, boxid)
var oldseq int32 var oldseq int32
err := row.Scan(&oldseq) err := tx.GetContext(ctx, &oldseq, "SELECT sequence FROM sideboxes WHERE uid = ? AND boxid = ?", uid, boxid)
if err != nil { if err != nil {
return err return err
} }
@@ -118,9 +117,8 @@ func AmAppendSidebox(ctx context.Context, uid int32, boxid int32, param *string)
tx, commit, rollback := transaction(ctx) tx, commit, rollback := transaction(ctx)
defer rollback() defer rollback()
row := tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM sideboxes WHERE uid = ?", uid)
var topseq int32 var topseq int32
err := row.Scan(&topseq) err := tx.GetContext(ctx, &topseq, "SELECT MAX(sequence) FROM sideboxes WHERE uid = ?", uid)
if err != nil { if err != nil {
return err return err
} }
+9 -16
View File
@@ -66,11 +66,10 @@ func (t *Topic) GetPost(ctx context.Context, num int32) (*PostHeader, error) {
return nil, fmt.Errorf("no post %d in topic %d", num, t.TopicId) return nil, fmt.Errorf("no post %d in topic %d", num, t.TopicId)
} }
var pd PostHeader var pd PostHeader
err := amdb.GetContext(ctx, &pd, "SELECT * FROM posts WHERE topicid = ? AND num = ?", t.TopicId, num) if err := amdb.GetContext(ctx, &pd, "SELECT * FROM posts WHERE topicid = ? AND num = ?", t.TopicId, num); err != nil {
if err == nil {
return &pd, nil
}
return nil, err return nil, err
}
return &pd, nil
} }
// GetLastRead returns the "last read" message for a user on a topic. // GetLastRead returns the "last read" message for a user on a topic.
@@ -78,9 +77,8 @@ func (t *Topic) GetLastRead(ctx context.Context, u *User) (int32, error) {
if u.IsAnon { if u.IsAnon {
return -1, nil return -1, nil
} }
row := amdb.QueryRowContext(ctx, "SELECT last_message FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid)
var rc int32 = -1 var rc int32 = -1
err := row.Scan(&rc) err := amdb.GetContext(ctx, &rc, "SELECT last_message FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return -1, nil return -1, nil
} }
@@ -106,9 +104,8 @@ func (t *Topic) SetLastRead(ctx context.Context, u *User, postNum int32) error {
// IsHidden tells us whether the user has the topic hidden. // IsHidden tells us whether the user has the topic hidden.
func (t *Topic) IsHidden(ctx context.Context, u *User) (bool, error) { func (t *Topic) IsHidden(ctx context.Context, u *User) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT hidden FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid)
rc := false rc := false
err := row.Scan(&rc) err := amdb.GetContext(ctx, &rc, "SELECT hidden FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid)
return rc, err return rc, err
} }
@@ -159,9 +156,8 @@ func (t *Topic) IsBozo(ctx context.Context, u *User, testUid int32) (bool, error
if u.IsAnon { if u.IsAnon {
return false, nil return false, nil
} }
row := amdb.QueryRowContext(ctx, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, testUid)
var tmp int32 var tmp int32
err := row.Scan(&tmp) err := amdb.GetContext(ctx, &tmp, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, testUid)
switch err { switch err {
case nil: case nil:
return true, nil return true, nil
@@ -176,9 +172,8 @@ func (t *Topic) SetBozo(ctx context.Context, u *User, subjectUid int32, bozo boo
var err error = nil var err error = nil
if !u.IsAnon { if !u.IsAnon {
if bozo { // Flipping the bozo bit! if bozo { // Flipping the bozo bit!
row := amdb.QueryRowContext(ctx, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, subjectUid)
var tmp int32 var tmp int32
err = row.Scan(&tmp) err = amdb.GetContext(ctx, &tmp, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, subjectUid)
switch err { switch err {
case nil: case nil:
return nil return nil
@@ -225,9 +220,8 @@ func (t *Topic) GetBozos(ctx context.Context, u *User) ([]TopicBozo, error) {
// IsSubscribed returns true if the given user is subscribed to receive E-mails of topic posts. // IsSubscribed returns true if the given user is subscribed to receive E-mails of topic posts.
func (t *Topic) IsSubscribed(ctx context.Context, u *User) (bool, error) { func (t *Topic) IsSubscribed(ctx context.Context, u *User) (bool, error) {
row := amdb.QueryRowContext(ctx, "SELECT subscribe FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid)
var rc bool var rc bool
err := row.Scan(&rc) err := amdb.GetContext(ctx, &rc, "SELECT subscribe FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid)
switch err { switch err {
case nil: case nil:
return rc, nil return rc, nil
@@ -370,9 +364,8 @@ func backgroundPurgeTopic(ctx context.Context, topicid int32) error {
defer rollback() defer rollback()
// Get some stats on the posts we have to remove. // Get some stats on the posts we have to remove.
row := tx.QueryRowContext(ctx, "SELECT MAX(postid) FROM posts WHERE topicid = ?", topicid)
var postMax int32 var postMax int32
err := row.Scan(&postMax) err := tx.GetContext(ctx, &postMax, "SELECT MAX(postid) FROM posts WHERE topicid = ?", topicid)
if err != nil { if err != nil {
return err return err
} }
+46 -42
View File
@@ -32,6 +32,9 @@ import (
"golang.org/x/text/message" "golang.org/x/text/message"
) )
// ErrNoUser is an error returned if the user is not found in the database.
var ErrNoUser error = errors.New("no such user")
// UserPrefs represents the user's preferences in a table (one row per user). // UserPrefs represents the user's preferences in a table (one row per user).
type UserPrefs struct { type UserPrefs struct {
Uid int32 `db:"uid"` // user ID Uid int32 `db:"uid"` // user ID
@@ -434,19 +437,21 @@ func (u *User) SetSecurityData(ctx context.Context, baseLevel uint16, lockout, v
* Standard Go error status * Standard Go error status
*/ */
func AmGetUser(ctx context.Context, uid int32) (*User, error) { func AmGetUser(ctx context.Context, uid int32) (*User, error) {
var err error = nil
getUserMutex.Lock() getUserMutex.Lock()
defer getUserMutex.Unlock() defer getUserMutex.Unlock()
rc, ok := userCache.Get(uid) if rc, ok := userCache.Get(uid); ok {
if !ok { return rc.(*User), nil
}
var user User var user User
if err = amdb.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid); err != nil { err := amdb.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid)
switch err {
case nil:
userCache.Add(uid, &user)
return &user, nil
case sql.ErrNoRows:
return nil, ErrNoUser
}
return nil, err return nil, err
}
rc = &user
userCache.Add(uid, rc)
}
return rc.(*User), err
} }
/* AmGetUserTx returns a reference to the specified user inside a transaction. /* AmGetUserTx returns a reference to the specified user inside a transaction.
@@ -459,19 +464,21 @@ func AmGetUser(ctx context.Context, uid int32) (*User, error) {
* Standard Go error status * Standard Go error status
*/ */
func AmGetUserTx(ctx context.Context, tx *sqlx.Tx, uid int32) (*User, error) { func AmGetUserTx(ctx context.Context, tx *sqlx.Tx, uid int32) (*User, error) {
var err error = nil
getUserMutex.Lock() getUserMutex.Lock()
defer getUserMutex.Unlock() defer getUserMutex.Unlock()
rc, ok := userCache.Get(uid) if rc, ok := userCache.Get(uid); ok {
if !ok { return rc.(*User), nil
}
var user User var user User
if err = tx.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid); err != nil { err := tx.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid)
switch err {
case nil:
userCache.Add(uid, &user)
return &user, nil
case sql.ErrNoRows:
return nil, ErrNoUser
}
return nil, err return nil, err
}
rc = &user
userCache.Add(uid, rc)
}
return rc.(*User), err
} }
/* AmGetUserByName returns a reference to the specified user. /* AmGetUserByName returns a reference to the specified user.
@@ -491,25 +498,26 @@ func AmGetUserByName(ctx context.Context, name string, tx *sqlx.Tx) (*User, erro
} else { } else {
err = amdb.GetContext(ctx, &user, "SELECT * FROM users WHERE username = ?", name) err = amdb.GetContext(ctx, &user, "SELECT * FROM users WHERE username = ?", name)
} }
if err != nil { switch err {
return nil, err case nil:
}
getUserMutex.Lock() getUserMutex.Lock()
rc, ok := userCache.Get(user.Uid) defer getUserMutex.Unlock()
if !ok { if rc, ok := userCache.Get(user.Uid); ok {
rc = &user
userCache.Add(user.Uid, rc)
}
getUserMutex.Unlock()
return rc.(*User), nil return rc.(*User), nil
} else {
userCache.Add(user.Uid, &user)
}
return &user, nil
case sql.ErrNoRows:
return nil, ErrNoUser
}
return nil, err
} }
// getAnonUserID retrieves the UID of the "anonymous" user from the database. // getAnonUserID retrieves the UID of the "anonymous" user from the database.
func getAnonUserID(ctx context.Context) (int32, error) { func getAnonUserID(ctx context.Context) (int32, error) {
if anonUid < 0 { if anonUid < 0 {
row := amdb.QueryRowContext(ctx, "SELECT uid FROM users WHERE is_anon = 1") if err := amdb.GetContext(ctx, &anonUid, "SELECT uid FROM users WHERE is_anon = 1"); err != nil {
err := row.Scan(&anonUid)
if err != nil {
return -1, err return -1, err
} }
} }
@@ -708,9 +716,8 @@ func AmCreateNewUser(ctx context.Context, username string, password string, remi
defer rollback() defer rollback()
// Test if the user name is already taken. // Test if the user name is already taken.
row := tx.QueryRowContext(ctx, "SELECT uid FROM users WHERE username = ?", username)
var tmpuid int32 var tmpuid int32
err := row.Scan(&tmpuid) err := tx.GetContext(ctx, &tmpuid, "SELECT uid FROM users WHERE username = ?", username)
if err == nil { if err == nil {
log.Warnf("username \"%s\" already exists", username) log.Warnf("username \"%s\" already exists", username)
return nil, errors.New("that user name already exists. Please try again") return nil, errors.New("that user name already exists. Please try again")
@@ -775,20 +782,18 @@ func AmCreateNewUser(ctx context.Context, username string, password string, remi
// internalGetProp is a helper used by the property functions. // internalGetProp is a helper used by the property functions.
func internalGetProp(ctx context.Context, uid int32, ndx int32) (*UserProperties, error) { func internalGetProp(ctx context.Context, uid int32, ndx int32) (*UserProperties, error) {
var err error = nil
key := fmt.Sprintf("%d:%d", uid, ndx) key := fmt.Sprintf("%d:%d", uid, ndx)
getUserPropMutex.Lock() getUserPropMutex.Lock()
defer getUserPropMutex.Unlock() defer getUserPropMutex.Unlock()
rc, ok := userPropCache.Get(key) if rc, ok := userPropCache.Get(key); ok {
if !ok { return rc.(*UserProperties), nil
}
var prop UserProperties var prop UserProperties
if err = amdb.GetContext(ctx, &prop, "SELECT * from propuser WHERE uid = ? AND ndx = ?", uid, ndx); err != nil { if err := amdb.GetContext(ctx, &prop, "SELECT * from propuser WHERE uid = ? AND ndx = ?", uid, ndx); err != nil {
return nil, err return nil, err
} }
rc = &prop userPropCache.Add(key, &prop)
userPropCache.Add(key, rc) return &prop, nil
}
return rc.(*UserProperties), nil
} }
/* AmGetUserProperty retrieves the value of a user property. /* AmGetUserProperty retrieves the value of a user property.
@@ -890,9 +895,8 @@ func AmSearchUsers(ctx context.Context, field int, oper int, term string, offset
return nil, -1, errors.New("invalid operator selector") return nil, -1, errors.New("invalid operator selector")
} }
q := queryPortion.String() q := queryPortion.String()
row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM users u, contacts c WHERE u.contactid = c.contactid AND u.is_anon = 0 AND "+q)
var total int var total int
err := row.Scan(&total) err := amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM users u, contacts c WHERE u.contactid = c.contactid AND u.is_anon = 0 AND "+q)
if err != nil { if err != nil {
return nil, -1, err return nil, -1, err
} }