diff --git a/database/audit.go b/database/audit.go index 22a34e7..68b0b38 100644 --- a/database/audit.go +++ b/database/audit.go @@ -11,7 +11,6 @@ package database import ( "context" - "database/sql" _ "embed" "fmt" "time" @@ -240,14 +239,13 @@ func AmStoreAudit(rec *AuditRecord) { // AmListAuditRecords lists a section of the audit records. func AmListAuditRecords(ctx context.Context, comm *Community, offset, max int) ([]AuditRecord, int, error) { - var row *sql.Row - if comm != nil { - row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM audit WHERE commid = ?", comm.Id) - } else { - row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM audit") - } + var err error var count int - err := row.Scan(&count) + if comm != nil { + err = amdb.GetContext(ctx, &count, "SELECT COUNT(*) FROM audit WHERE commid = ?", comm.Id) + } else { + err = amdb.GetContext(ctx, &count, "SELECT COUNT(*) FROM audit") + } if err != nil { return nil, -1, err } diff --git a/database/category.go b/database/category.go index 6073119..39accd9 100644 --- a/database/category.go +++ b/database/category.go @@ -65,9 +65,8 @@ func loadCategories(ctx context.Context) error { categoryMutex.Lock() defer categoryMutex.Unlock() if allCategories == nil { - row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM refcategory") - var ncats int32 - if err := row.Scan(&ncats); err != nil { + var ncats int + if err := amdb.GetContext(ctx, &ncats, "SELECT COUNT(*) FROM refcategory"); err != nil { return err } allCategories = make([]Category, 0, ncats) @@ -226,10 +225,9 @@ func AmSearchCategories(ctx context.Context, oper int, term string, offset int, queryString.WriteString(" AND hide_search = 0") } q := queryString.String() - row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM refcategory WHERE "+q) var total int - if err = row.Scan(&total); err != nil { - return nil, total, err + if err = amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM refcategory WHERE "+q); err != nil { + return nil, -1, err } if total == 0 { return make([]*Category, 0), 0, nil diff --git a/database/community.go b/database/community.go index 60eee10..24a8620 100644 --- a/database/community.go +++ b/database/community.go @@ -28,6 +28,9 @@ import ( "golang.org/x/text/language" ) +// ErrNoCommunity is an error returned for "no community found" errors +var ErrNoCommunity error = errors.New("no such community") + // Community struct contains the high level data for a community. type Community struct { Mutex sync.RWMutex @@ -221,18 +224,17 @@ func (c *Community) Membership(ctx context.Context, u *User) (bool, bool, uint16 // MemberCount returns the number of members in the community. func (c *Community) MemberCount(ctx context.Context, hidden bool) (int, error) { - var row *sql.Row - if hidden { - row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM commmember WHERE commid = ?", c.Id) - } else { - row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM commmember WHERE commid = ? AND hidden = 0", c.Id) - } var rc int - if err := row.Scan(&rc); err == nil { - return rc, nil + var err error + if hidden { + err = amdb.GetContext(ctx, &rc, "SELECT COUNT(*) FROM commmember WHERE commid = ?", c.Id) } else { - return -1, err + err = amdb.GetContext(ctx, &rc, "SELECT COUNT(*) FROM commmember WHERE commid = ? AND hidden = 0", c.Id) } + if err == nil { + return rc, nil + } + return -1, err } /* ListMembers lists or searches for community members matching certain criteria. @@ -294,12 +296,11 @@ func (c *Community) ListMembers(ctx context.Context, field int, oper int, term s query.WriteString(" AND m.hidden = 0") } q := query.String() - row := amdb.QueryRowContext(ctx, `SELECT COUNT(*) FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid - AND u.contactid = c.contactid`+q, c.Id) var total int var err error var rs *sql.Rows - if err = row.Scan(&total); err == nil { + if err = amdb.GetContext(ctx, &total, `SELECT COUNT(*) FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid + AND u.contactid = c.contactid`+q, c.Id); err == nil { if offset > 0 { rs, err = amdb.QueryContext(ctx, `SELECT m.uid FROM commmember m, users u, contacts c WHERE m.commid = ? AND m.uid = u.uid AND u.contactid = c.contactid`+q+" ORDER BY u.username LIMIT ? OFFSET ?", c.Id, max, offset) @@ -519,8 +520,7 @@ func (c *Community) SetProfileData(ctx context.Context, name string, alias strin c.CreateLevel = create_lvl c.DeleteLevel = delete_lvl c.JoinLevel = join_lvl - row := amdb.QueryRowContext(ctx, "SELECT lastupdate FROM communities WHERE commid = ?", c.Id) - err2 := row.Scan(&(c.LastUpdate)) + err2 := amdb.GetContext(ctx, &(c.LastUpdate), "SELECT lastupdate FROM communities WHERE commid = ?", c.Id) if err2 != nil { log.Errorf("SetProfileData scan error: %v", err2) } @@ -545,9 +545,8 @@ func (c *Community) Touch(ctx context.Context) error { defer c.Mutex.Unlock() _, err := amdb.ExecContext(ctx, "UPDATE communities SET lastaccess = NOW() WHERE commid = ?", c.Id) if err == nil { - row := amdb.QueryRowContext(ctx, "SELECT lastaccess FROM communities WHERE commid = ?", c.Id) var na time.Time - if err = row.Scan(&na); err == nil { + if err = amdb.GetContext(ctx, &na, "SELECT lastaccess FROM communities WHERE commid = ?", c.Id); err == nil { c.LastAccess = &na } } @@ -603,16 +602,19 @@ func (c *Community) GetMemberEMailAddrs(ctx context.Context) ([]string, error) { func AmGetCommunity(ctx context.Context, id int32) (*Community, error) { getCommunityMutex.Lock() defer getCommunityMutex.Unlock() - rc, ok := communityCache.Get(id) - if !ok { - var newcomm Community - if err := amdb.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id); err != nil { - return nil, err - } - rc = &newcomm - communityCache.Add(id, rc) + if rc, ok := communityCache.Get(id); ok { + return rc.(*Community), nil } - return rc.(*Community), nil + var newcomm Community + err := amdb.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id) + switch err { + case nil: + communityCache.Add(id, &newcomm) + return &newcomm, nil + case sql.ErrNoRows: + return nil, ErrNoCommunity + } + return nil, err } /* AmGetCommunityTx returns a reference to the specified community, in a transaction. @@ -627,16 +629,19 @@ func AmGetCommunity(ctx context.Context, id int32) (*Community, error) { func AmGetCommunityTx(ctx context.Context, tx *sqlx.Tx, id int32) (*Community, error) { getCommunityMutex.Lock() defer getCommunityMutex.Unlock() - rc, ok := communityCache.Get(id) - if !ok { - var newcomm Community - if err := tx.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id); err != nil { - return nil, err - } - rc = &newcomm - communityCache.Add(id, rc) + if rc, ok := communityCache.Get(id); ok { + return rc.(*Community), nil } - return rc.(*Community), nil + var newcomm Community + err := tx.GetContext(ctx, &newcomm, "SELECT * from communities WHERE commid = ?", id) + switch err { + case nil: + communityCache.Add(id, &newcomm) + return &newcomm, nil + case sql.ErrNoRows: + return nil, ErrNoCommunity + } + return nil, err } /* AmGetCommunityByAlias returns a reference to the specified community. @@ -648,13 +653,11 @@ func AmGetCommunityTx(ctx context.Context, tx *sqlx.Tx, id int32) (*Community, e * Standard Go error status (nil if community not found) */ func AmGetCommunityByAlias(ctx context.Context, alias string) (*Community, error) { - row := amdb.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias) var cid int32 - err := row.Scan(&cid) - if err == nil { - return AmGetCommunity(ctx, cid) + if err := amdb.GetContext(ctx, &cid, "SELECT commid FROM communities WHERE alias = ?", alias); err != nil { + return nil, err } - return nil, err + return AmGetCommunity(ctx, cid) } /* AmGetCommunityByAliasTx returns a reference to the specified community, within a transaction. @@ -667,13 +670,11 @@ func AmGetCommunityByAlias(ctx context.Context, alias string) (*Community, error * Standard Go error status (nil if community not found) */ func AmGetCommunityByAliasTx(ctx context.Context, tx *sqlx.Tx, alias string) (*Community, error) { - row := tx.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias) var cid int32 - err := row.Scan(&cid) - if err == nil { - return AmGetCommunityTx(ctx, tx, cid) + if err := tx.GetContext(ctx, &cid, "SELECT commid FROM communities WHERE alias = ?", alias); err != nil { + return nil, err } - return nil, err + return AmGetCommunityTx(ctx, tx, cid) } /* AmGetCommunityFromParam returns a reference to the specified community based on the parameter. @@ -793,20 +794,22 @@ func AmAutoJoinCommunities(ctx context.Context, user *User) error { // internalGetCommProp is a helper used by the community property functions. func internalGetCommProp(ctx context.Context, cid int32, ndx int32) (*CommunityProperties, error) { - var err error = nil key := fmt.Sprintf("%d:%d", cid, ndx) getCommunityPropMutex.Lock() defer getCommunityPropMutex.Unlock() - rc, ok := communityPropCache.Get(key) - if !ok { - var prop CommunityProperties - if err = amdb.GetContext(ctx, &prop, "SELECT * from propcomm WHERE cid = ? AND ndx = ?", cid, ndx); err != nil { - return nil, err - } - rc = &prop - communityPropCache.Add(key, rc) + if rc, ok := communityPropCache.Get(key); ok { + return rc.(*CommunityProperties), nil } - return rc.(*CommunityProperties), nil + var prop CommunityProperties + err := amdb.GetContext(ctx, &prop, "SELECT * from propcomm WHERE cid = ? AND ndx = ?", cid, ndx) + switch err { + case nil: + communityPropCache.Add(key, &prop) + return &prop, nil + case sql.ErrNoRows: + return nil, nil + } + return nil, err } /* AmGetCommunityProperty retrieves the value of a community property. @@ -882,10 +885,8 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i defer rollback() // validate alias does not already exist - row := tx.QueryRowContext(ctx, "SELECT commid FROM communities WHERE alias = ?", alias) var tmpcid int32 - err := row.Scan(&tmpcid) - if err != sql.ErrNoRows { + if err := tx.GetContext(ctx, &tmpcid, "SELECT commid FROM communities WHERE alias = ?", alias); err != sql.ErrNoRows { if err == nil { err = errors.New("a community with that alias already exists") } @@ -893,7 +894,7 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i } // establish the community record - _, err = tx.ExecContext(ctx, `INSERT INTO communities (createdate, lastaccess, lastupdate, read_lvl, write_lvl, + _, err := tx.ExecContext(ctx, `INSERT INTO communities (createdate, lastaccess, lastupdate, read_lvl, write_lvl, create_lvl, delete_lvl, join_lvl, host_uid, hide_dir, hide_search, commname, language, synopsis, rules, joinkey, alias) VALUES (NOW(), NOW(), NOW(), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, AmRoleList("Community.Read").Default().Level(), AmRoleList("Community.Write").Default().Level(), @@ -950,14 +951,12 @@ func AmCreateCommunity(ctx context.Context, name string, alias string, hostUid i */ func AmGetCommunitiesForCategory(ctx context.Context, catid int32, offset int, max int, showAll bool) ([]*Community, int, error) { var err error - var row *sql.Row - if showAll { - row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM communities WHERE catid = ?", catid) - } else { - row = amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM communities WHERE catid = ? AND hide_dir = 0", catid) - } var total int - err = row.Scan(&total) + if showAll { + err = amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM communities WHERE catid = ?", catid) + } else { + err = amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM communities WHERE catid = ? AND hide_dir = 0", catid) + } if err != nil || total == 0 { return make([]*Community, 0), 0, err // short-circuit return } @@ -1043,9 +1042,8 @@ func AmSearchCommunities(ctx context.Context, field int, oper int, term string, queryPortion.WriteString(" AND hide_search = 0") } q := queryPortion.String() - row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM communities "+q) var total int - err := row.Scan(&total) + err := amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM communities "+q) if err != nil || total == 0 { return make([]*Community, 0), 0, err // short-circuit return } @@ -1123,10 +1121,9 @@ func AmSearchCommunityMembers(ctx context.Context, c *Community, field int, oper return nil, -1, errors.New("invalid operator selector") } q := queryPortion.String() - row := amdb.QueryRowContext(ctx, `SELECT COUNT(*) FROM users u, contacts c, commmember m WHERE u.contactid = c.contactid AND u.uid = m.uid - AND m.commid = ? AND u.is_anon = 0 AND `+q, c.Id) var total int - err := row.Scan(&total) + err := amdb.GetContext(ctx, &total, `SELECT COUNT(*) FROM users u, contacts c, commmember m WHERE u.contactid = c.contactid AND u.uid = m.uid + AND m.commid = ? AND u.is_anon = 0 AND `+q, c.Id) if err != nil { return nil, -1, err } diff --git a/database/conference.go b/database/conference.go index a3e9c22..a4a1fa0 100644 --- a/database/conference.go +++ b/database/conference.go @@ -26,6 +26,9 @@ import ( log "github.com/sirupsen/logrus" ) +// ErrNoConference is an error thrown when a conference is not found. +var ErrNoConference error = errors.New("no such conference") + // Conference struct is the top-level structure for a conference. type Conference struct { Mutex sync.Mutex @@ -172,17 +175,14 @@ func (c *Conference) AliasesQ(ctx context.Context) []string { // AddAlias adds an alias to the conference. func (c *Conference) AddAlias(ctx context.Context, alias string, u *User, comm *Community, ipaddr string) error { - row := amdb.QueryRowContext(ctx, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias) tmp := "" - err := row.Scan(&tmp) - if err != sql.ErrNoRows { + if err := amdb.GetContext(ctx, &tmp, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias); err != sql.ErrNoRows { if err == nil { return fmt.Errorf("the alias '%s' is already in use by another conference", alias) } return err } - _, err = amdb.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", c.ConfId, alias) - if err != nil { + if _, err := amdb.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", c.ConfId, alias); err != nil { return err } @@ -192,17 +192,14 @@ func (c *Conference) AddAlias(ctx context.Context, alias string, u *User, comm * // RemoveAlias removes an alias from the conference. func (c *Conference) RemoveAlias(ctx context.Context, alias string, u *User, comm *Community, ipaddr string) error { - row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM confalias WHERE confid = ?", c.ConfId) aliasCount := 0 - err := row.Scan(&aliasCount) - if err != nil { + if err := amdb.GetContext(ctx, &aliasCount, "SELECT COUNT(*) FROM confalias WHERE confid = ?", c.ConfId); err != nil { return err } if aliasCount == 1 { - row = amdb.QueryRowContext(ctx, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias) tmp := "" - err = row.Scan(&tmp) + err := amdb.GetContext(ctx, &tmp, "SELECT alias FROM confalias WHERE confid = ? AND alias = ?", c.ConfId, alias) if err == nil { return errors.New("the conference must have at least one alias") } else if err != sql.ErrNoRows { @@ -254,9 +251,8 @@ func (c *Conference) HostsQ(ctx context.Context) []*User { // InCommunity returns true if the specified conference is in the community. func (c *Conference) InCommunity(ctx context.Context, comm *Community) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT commid FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId) var tmp int32 - err := row.Scan(&tmp) + err := amdb.GetContext(ctx, &tmp, "SELECT commid FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId) switch err { case nil: return true, nil @@ -268,9 +264,8 @@ func (c *Conference) InCommunity(ctx context.Context, comm *Community) (bool, er // HiddenInList returns whether or not this conference is hidden in the community's conference list. func (c *Conference) HiddenInList(ctx context.Context, comm *Community) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT hide_list FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId) var rc bool - err := row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT hide_list FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId) switch err { case nil: return rc, nil @@ -317,9 +312,8 @@ func (c *Conference) Members(ctx context.Context) ([]ConferenceMember, error) { // Membership returns a membership flag and granted level for the user in this conference. func (c *Conference) Membership(ctx context.Context, u *User) (bool, uint16, error) { - row := amdb.QueryRowContext(ctx, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) var level uint16 - err := row.Scan(&level) + err := amdb.GetContext(ctx, &level, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) switch err { case nil: return true, level, nil @@ -335,9 +329,8 @@ func (c *Conference) SetMembership(ctx context.Context, u *User, level uint16, b _, err := amdb.ExecContext(ctx, "DELETE FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) return err } - row := amdb.QueryRowContext(ctx, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) var oldLevel uint16 - err := row.Scan(&oldLevel) + err := amdb.GetContext(ctx, &oldLevel, "SELECT granted_lvl FROM confmember WHERE confid = ? AND uid = ?", c.ConfId, u.Uid) switch err { case nil: if oldLevel == level { @@ -464,8 +457,7 @@ func (c *Conference) SetInfo(ctx context.Context, name, descr string, read_lvl, create_lvl, hide_lvl, nuke_lvl, change_lvl, delete_lvl, c.ConfId) if err == nil { var tmp Conference - err := amdb.GetContext(ctx, &tmp, "SELECT * FROM confs WHERE confid = ?", c.ConfId) - if err == nil { + if err = amdb.GetContext(ctx, &tmp, "SELECT * FROM confs WHERE confid = ?", c.ConfId); err == nil { if c.Name != tmp.Name { AmStoreAudit(AmNewCommAudit(AuditConferenceName, u.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", c.ConfId), fmt.Sprintf("name='%s'", tmp.Name))) } @@ -576,11 +568,10 @@ func (c *Conference) TouchPost(ctx context.Context, tx *sqlx.Tx, u *User, lastPo // UnreadMessages returns the total number of unread messages in a conference for a user. func (c *Conference) UnreadMessages(ctx context.Context, u *User) (int32, error) { - row := amdb.QueryRowContext(ctx, `SELECT SUM(t.top_message - IFNULL(s.last_message,-1)) + var rc int32 + err := amdb.GetContext(ctx, &rc, `SELECT SUM(t.top_message - IFNULL(s.last_message,-1)) FROM topics t LEFT JOIN topicsettings s ON t.topicid = s.topicid AND s.uid = ? WHERE t.confid = ? AND t.archived = 0 AND (s.hidden IS NULL OR s.hidden = 0)`, u.Uid, c.ConfId) - var rc int32 - err := row.Scan(&rc) return rc, err } @@ -600,10 +591,8 @@ func (c *Conference) Fixseen(ctx context.Context, u *User) error { defer rollback() // Get a count of topics beforehand. - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM topics WHERE confid = ?", c.ConfId) count := 0 - err := row.Scan(&count) - if err != nil { + if err := tx.GetContext(ctx, &count, "SELECT COUNT(*) FROM topics WHERE confid = ?", c.ConfId); err != nil { return err } @@ -661,9 +650,8 @@ func (c *Conference) GetCustomBlocks(ctx context.Context) (string, string, error func (c *Conference) SetCustomBlocks(ctx context.Context, topBlock, bottomBlock string) error { tx, commit, rollback := transaction(ctx) defer rollback() - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM confcustom WHERE confid = ?", c.ConfId) ct := 0 - err := row.Scan(&ct) + err := tx.GetContext(ctx, &ct, "SELECT COUNT(*) FROM confcustom WHERE confid = ?", c.ConfId) if err != nil { return err } @@ -872,18 +860,17 @@ func (c *Conference) Delete(ctx context.Context, comm *Community, u *User, ipadd defer getConferenceMutex.Unlock() // any references to conference other than this community? - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", c.ConfId, comm.Id) refCount := 0 - err := row.Scan(&refCount) - if err != nil { + if err := tx.GetContext(ctx, &refCount, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", c.ConfId, comm.Id); err != nil { return err } // break the link with the community - if _, err = tx.ExecContext(ctx, "DELETE FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId); err != nil { + if _, err := tx.ExecContext(ctx, "DELETE FROM commtoconf WHERE commid = ? AND confid = ?", comm.Id, c.ConfId); err != nil { return err } + var err error if refCount == 0 { // We have to delete all the conference core data now. _, err = tx.ExecContext(ctx, "DELETE FROM confs WHERE confid = ?", c.ConfId) @@ -943,9 +930,8 @@ func (*conferenceServiceVTable) OnDeleteCommunity(ctx context.Context, tx *sqlx. } for i, confid := range confids { // any references to conference other than this community? - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", confid, commid) refCount := 0 - err = row.Scan(&refCount) + err := tx.GetContext(ctx, &refCount, "SELECT COUNT(*) FROM commtoconf WHERE confid = ? AND commid <> ?", confid, commid) if err != nil { return err } @@ -1011,19 +997,21 @@ func (*conferenceServiceVTable) OnUserLeaveCommunity(context.Context, *sqlx.Tx, * Standard Go error status. */ func AmGetConference(ctx context.Context, id int32) (*Conference, error) { - var err error = nil getConferenceMutex.Lock() defer getConferenceMutex.Unlock() - rc, ok := conferenceCache.Get(id) - if !ok { - var conf Conference - if err = amdb.GetContext(ctx, &conf, "SELECT * from confs where confid = ?"); err != nil { - return nil, err - } - rc = &conf - conferenceCache.Add(id, rc) + if rc, ok := conferenceCache.Get(id); ok { + return rc.(*Conference), nil } - return rc.(*Conference), err + var conf Conference + err := amdb.GetContext(ctx, &conf, "SELECT * from confs where confid = ?") + switch err { + case nil: + conferenceCache.Add(id, &conf) + return &conf, nil + case sql.ErrNoRows: + return nil, ErrNoConference + } + return nil, err } /* AmGetConferenceByAlias returns a conference given its alias. @@ -1040,8 +1028,7 @@ func AmGetConferenceByAlias(ctx context.Context, alias string) (*Conference, err if ok { confid = xconf.(int32) } else { - row := amdb.QueryRowContext(ctx, "SELECT confid FROM confalias WHERE alias = ?", alias) - err := row.Scan(&confid) + err := amdb.GetContext(ctx, &confid, "SELECT confid FROM confalias WHERE alias = ?", alias) if err == sql.ErrNoRows { return nil, fmt.Errorf("alias not found: %s", alias) } else if err != nil { @@ -1061,9 +1048,8 @@ func AmGetConferenceByAlias(ctx context.Context, alias string) (*Conference, err * Standard Go error status. */ func AmGetConferenceContainingPost(ctx context.Context, postId int64) (*Conference, error) { - row := amdb.QueryRowContext(ctx, "SELECT t.confid FROM topics t, posts p WHERE p.postid = ? AND p.topicid = t.topicid", postId) var confId int32 - err := row.Scan(&confId) + err := amdb.GetContext(ctx, &confId, "SELECT t.confid FROM topics t, posts p WHERE p.postid = ? AND p.topicid = t.topicid", postId) if err == sql.ErrNoRows { return nil, fmt.Errorf("post not found: %d", postId) } else if err != nil { @@ -1082,10 +1068,9 @@ func AmGetConferenceContainingPost(ctx context.Context, postId int64) (*Conferen * Standard Go error status. */ func AmGetConferenceByAliasInCommunity(ctx context.Context, cid int32, alias string) (*Conference, error) { - row := amdb.QueryRowContext(ctx, `SELECT c.confid FROM commtoconf c, confalias a WHERE c.confid = a.confid - AND c.commid = ? AND a.alias = ?`, cid, alias) var confid int32 - err := row.Scan(&confid) + err := amdb.GetContext(ctx, &confid, `SELECT c.confid FROM commtoconf c, confalias a WHERE c.confid = a.confid + AND c.commid = ? AND a.alias = ?`, cid, alias) switch err { case nil: return AmGetConference(ctx, confid) @@ -1124,8 +1109,7 @@ func AmListConferences(ctx context.Context, cid int32, showHidden bool) ([]*Conf } } for i := range rc { - row := amdb.QueryRowContext(ctx, "SELECT alias FROM confalias WHERE confid = ?", rc[i].ConfId) - err = row.Scan(&(rc[i].Alias)) + err := amdb.GetContext(ctx, &(rc[i].Alias), "SELECT alias FROM confalias WHERE confid = ?", rc[i].ConfId) if err != nil { return nil, err } @@ -1140,20 +1124,22 @@ func AmListConferences(ctx context.Context, cid int32, showHidden bool) ([]*Conf // internalGetConfProp is a helper used by the conference property functions. func internalGetConfProp(ctx context.Context, confid int32, ndx int32) (*ConferenceProperties, error) { - var err error = nil key := fmt.Sprintf("%d:%d", confid, ndx) getConferencePropMutex.Lock() defer getConferencePropMutex.Unlock() - rc, ok := conferencePropCache.Get(key) - if !ok { - var prop ConferenceProperties - if err = amdb.GetContext(ctx, &prop, "SELECT * from propconf WHERE confid = ? AND ndx = ?", confid, ndx); err != nil { - return nil, err - } - rc = &prop - conferencePropCache.Add(key, rc) + if rc, ok := conferencePropCache.Get(key); ok { + return rc.(*ConferenceProperties), nil } - return rc.(*ConferenceProperties), nil + var prop ConferenceProperties + err := amdb.GetContext(ctx, &prop, "SELECT * from propconf WHERE confid = ? AND ndx = ?", confid, ndx) + switch err { + case nil: + conferencePropCache.Add(key, &prop) + return &prop, nil + case sql.ErrNoRows: + return nil, nil + } + return nil, err } /* AmGetConferenceProperty retrieves the value of a conference property. @@ -1268,9 +1254,8 @@ func AmCreateConference(ctx context.Context, comm *Community, name, alias, descr defer getConferenceMutex.Unlock() // Ensure the alias is not in use. - row := tx.QueryRowContext(ctx, "SELECT confid FROM confalias WHERE alias = ?", alias) var tmp int32 - err := row.Scan(&tmp) + err := tx.GetContext(ctx, &tmp, "SELECT confid FROM confalias WHERE alias = ?", alias) if err == nil { return nil, fmt.Errorf("the alias '%s' is already in use by a different conference", alias) } else if err != sql.ErrNoRows { @@ -1294,29 +1279,24 @@ func AmCreateConference(ctx context.Context, comm *Community, name, alias, descr } // Attach the alias to the conference. - _, err = tx.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", rc.ConfId, alias) - if err != nil { + if _, err = tx.ExecContext(ctx, "INSERT INTO confalias (confid, alias) VALUES (?, ?)", rc.ConfId, alias); err != nil { return nil, err } // Get the current "last" sequence number. - row = tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM commtoconf WHERE commid = ?", comm.Id) var seq int - err = row.Scan(&seq) - if err != nil { + if err = tx.GetContext(ctx, &seq, "SELECT MAX(sequence) FROM commtoconf WHERE commid = ?", comm.Id); err != nil { return nil, err } // Link the conference into the community, and set the hide flag. - _, err = tx.ExecContext(ctx, "INSERT INTO commtoconf (commid, confid, sequence, hide_list) VALUES (?, ?, ?, ?)", comm.Id, rc.ConfId, - int16(seq+COMMTOCONF_SEQ_SPACING), hide_list) - if err != nil { + if _, err = tx.ExecContext(ctx, "INSERT INTO commtoconf (commid, confid, sequence, hide_list) VALUES (?, ?, ?, ?)", comm.Id, rc.ConfId, + int16(seq+COMMTOCONF_SEQ_SPACING), hide_list); err != nil { return nil, err } // Make the specified user the first host of the conference. - _, err = tx.ExecContext(ctx, "INSERT INTO confmember (confid, uid, granted_lvl) VALUES (?, ?, ?)", rc.ConfId, u.Uid, AmDefaultRole("Conference.NewHost").Level()) - if err != nil { + if _, err = tx.ExecContext(ctx, "INSERT INTO confmember (confid, uid, granted_lvl) VALUES (?, ?, ?)", rc.ConfId, u.Uid, AmDefaultRole("Conference.NewHost").Level()); err != nil { return nil, err } diff --git a/database/contactinfo.go b/database/contactinfo.go index 99ab872..93db6d7 100644 --- a/database/contactinfo.go +++ b/database/contactinfo.go @@ -56,8 +56,7 @@ type ContactInfo struct { // lookupCommunityContact looks up the ID of a contact for a community. func lookupCommunityContact(ctx context.Context, id int32) (int32, error) { var rc int32 = -1 - row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE owner_commid = ?", id) - err := row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT contactid FROM contacts WHERE owner_commid = ?", id) if err == sql.ErrNoRows { return -1, nil } @@ -67,8 +66,7 @@ func lookupCommunityContact(ctx context.Context, id int32) (int32, error) { // lookupUserContact looks up the ID of a contact for a user. func lookupUserContact(ctx context.Context, uid int32) (int32, error) { var rc int32 = -1 - row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid) - err := row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid) if err == sql.ErrNoRows { return -1, nil } @@ -150,9 +148,8 @@ func (ci *ContactInfo) Save(ctx context.Context, changer *User, ipaddr string) ( } if !emailChange { // we don't THINK the E-mail address is changing, but we could be wrong... - row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE contactid = ? AND email = ?", ci.ContactId, ci.Email) var tmpcid int32 - err := row.Scan(&tmpcid) + err := amdb.GetContext(ctx, &tmpcid, "SELECT contactid FROM contacts WHERE contactid = ? AND email = ?", ci.ContactId, ci.Email) if err == sql.ErrNoRows { emailChange = true } else if err != nil { @@ -185,9 +182,7 @@ func (ci *ContactInfo) Save(ctx context.Context, changer *User, ipaddr string) ( contactCache.Add(ci.ContactId, ci) } // Refresh the last update date. - row := amdb.QueryRowContext(ctx, "SELECT lastupdate FROM contacts WHERE contactid = ?", ci.ContactId) - err := row.Scan(&(ci.LastUpdate)) - if err != nil { + if err := amdb.GetContext(ctx, &(ci.LastUpdate), "SELECT lastupdate FROM contacts WHERE contactid = ?", ci.ContactId); err != nil { return false, err } if ci.OwnerCommId < 0 { @@ -197,7 +192,7 @@ func (ci *ContactInfo) Save(ctx context.Context, changer *User, ipaddr string) ( } else { AmStoreAudit(AmNewCommAudit(AuditCommunityContactInfo, changer.Uid, ci.OwnerCommId, ipaddr, fmt.Sprintf("contactid=%d", ci.ContactId))) } - return emailChange, err + return emailChange, nil } // Clone makes a copy of the ContactInfo. @@ -251,11 +246,10 @@ func setupContactsCache() { // internalContactInfo retrieves the contact info from the database. func internalContactInfo(ctx context.Context, id int32) (*ContactInfo, error) { var cinf ContactInfo - err := amdb.GetContext(ctx, &cinf, "SELECT * from contacts WHERE contactid = ?", id) - if err == nil { - return &cinf, nil + if err := amdb.GetContext(ctx, &cinf, "SELECT * from contacts WHERE contactid = ?", id); err != nil { + return nil, err } - return nil, err + return &cinf, nil } /* AmGetContactInfo retrieves the contact info for a given identifier. @@ -292,9 +286,8 @@ func AmGetContactInfo(ctx context.Context, id int32) (*ContactInfo, error) { * Standard Go error status. */ func AmGetContactInfoForUser(ctx context.Context, uid int32) (*ContactInfo, error) { - row := amdb.QueryRowContext(ctx, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid) var cid int32 - err := row.Scan(&cid) + err := amdb.GetContext(ctx, &cid, "SELECT contactid FROM contacts WHERE owner_uid = ? AND owner_commid = -1", uid) switch err { case nil: return AmGetContactInfo(ctx, cid) diff --git a/database/emailban.go b/database/emailban.go index c852853..70494f4 100644 --- a/database/emailban.go +++ b/database/emailban.go @@ -23,9 +23,8 @@ import ( * Standard Go error status. */ func AmIsEmailAddressBanned(ctx context.Context, address string) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT by_uid FROM emailban WHERE address = ?", address) var uid int32 - err := row.Scan(&uid) + err := amdb.GetContext(ctx, &uid, "SELECT by_uid FROM emailban WHERE address = ?", address) switch err { case nil: return true, nil diff --git a/database/globals.go b/database/globals.go index 6711df5..35d4147 100644 --- a/database/globals.go +++ b/database/globals.go @@ -108,8 +108,7 @@ func AmGlobals(ctx context.Context) (*Globals, error) { defer globalsMutex.Unlock() if theGlobals == nil { var g Globals - err := amdb.GetContext(ctx, &g, "SELECT * FROM globals") - if err != nil { + if err := amdb.GetContext(ctx, &g, "SELECT * FROM globals"); err != nil { return nil, err } theGlobals = &g @@ -146,8 +145,7 @@ func AmGetGlobalProperty(ctx context.Context, index int32) (string, error) { var err error = nil rc, ok := globalProps[index] if !ok { - row := amdb.QueryRowContext(ctx, "SELECT data FROM propglobal WHERE ndx = ?", index) - err = row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT data FROM propglobal WHERE ndx = ?", index) switch err { case nil: globalProps[index] = rc @@ -172,9 +170,8 @@ func AmSetGlobalProperty(ctx context.Context, index int32, value string) error { defer globalPropMutex.Unlock() _, updateMode := globalProps[index] if !updateMode { - row := amdb.QueryRowContext(ctx, "SELECT data FROM propglobal WHERE ndx = ?", index) var tmpdata string - err := row.Scan(&tmpdata) + err := amdb.GetContext(ctx, &tmpdata, "SELECT data FROM propglobal WHERE ndx = ?", index) switch err { case nil: updateMode = true diff --git a/database/hotlist.go b/database/hotlist.go index d188f6f..3af5666 100644 --- a/database/hotlist.go +++ b/database/hotlist.go @@ -108,15 +108,13 @@ func AmAppendToHotlist(ctx context.Context, u *User, commid, confid int32) error defer rollback() var newseq int16 - row := tx.QueryRowContext(ctx, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid) - err := row.Scan(&newseq) + err := tx.GetContext(ctx, &newseq, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid) if err == nil { return errors.New("community/conference already exist in hotlist") } else if err != sql.ErrNoRows { return err } - row = tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM confhotlist WHERE uid = ?", u.Uid) - err = row.Scan(&newseq) + err = tx.GetContext(ctx, &newseq, "SELECT MAX(sequence) FROM confhotlist WHERE uid = ?", u.Uid) if err == sql.ErrNoRows { newseq = 0 } else if err != nil { @@ -135,9 +133,8 @@ func AmAppendToHotlist(ctx context.Context, u *User, commid, confid int32) error // AmIsInHotlist returns true if the community/conference pair is in the hotlist. func AmIsInHotlist(ctx context.Context, u *User, commid, confid int32) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid) var tmp int16 - err := row.Scan(&tmp) + err := amdb.GetContext(ctx, &tmp, "SELECT sequence FROM confhotlist WHERE uid = ? AND commid = ? AND confid = ?", u.Uid, commid, confid) switch err { case nil: return true, nil diff --git a/database/imagestore.go b/database/imagestore.go index 3fb67d4..b2ff4be 100644 --- a/database/imagestore.go +++ b/database/imagestore.go @@ -61,8 +61,7 @@ func (img *ImageStore) Save(ctx context.Context) error { */ func AmLoadImage(ctx context.Context, id int32) (*ImageStore, error) { var imgdata ImageStore - err := amdb.GetContext(ctx, &imgdata, "SELECT * FROM imagestore WHERE imgid = ?", id) - if err != nil { + if err := amdb.GetContext(ctx, &imgdata, "SELECT * FROM imagestore WHERE imgid = ?", id); err != nil { return nil, err } return &imgdata, nil @@ -81,9 +80,8 @@ func AmLoadImage(ctx context.Context, id int32) (*ImageStore, error) { */ func AmStoreImage(ctx context.Context, typecode int16, owner int32, mimetype string, data []byte) (*ImageStore, error) { var img *ImageStore - row := amdb.QueryRowContext(ctx, "SELECT imgid FROM imagestore WHERE typecode = ? AND ownerid = ?", typecode, owner) var id int32 - err := row.Scan(&id) + err := amdb.GetContext(ctx, &id, "SELECT imgid FROM imagestore WHERE typecode = ? AND ownerid = ?", typecode, owner) switch err { case nil: img, err = AmLoadImage(ctx, id) diff --git a/database/ipban.go b/database/ipban.go index aadebbc..5ae5ce1 100644 --- a/database/ipban.go +++ b/database/ipban.go @@ -273,8 +273,7 @@ func AmListIPBans(ctx context.Context) ([]IPBanEntry, error) { // AmGetIPBan returns a single IP address ban structure. func AmGetIPBan(ctx context.Context, id int32) (*IPBanEntry, error) { var ban IPBanEntry - err := amdb.GetContext(ctx, &ban, "SELECT * FROM ipban WHERE id = ?", id) - if err != nil { + if err := amdb.GetContext(ctx, &ban, "SELECT * FROM ipban WHERE id = ?", id); err != nil { return nil, err } return &ban, nil diff --git a/database/post.go b/database/post.go index bb8d3d0..07148c9 100644 --- a/database/post.go +++ b/database/post.go @@ -71,9 +71,8 @@ func (p *PostHeader) IsScribbled() bool { // IsPublished returns true if the post has been published to the front page. func (p *PostHeader) IsPublished(ctx context.Context) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM postpublish WHERE postid = ?", p.PostId) ct := 0 - err := row.Scan(&ct) + err := amdb.GetContext(ctx, &ct, "SELECT COUNT(*) FROM postpublish WHERE postid = ?", p.PostId) return ct > 0, err } @@ -237,16 +236,17 @@ func (p *PostHeader) PruneAttachment(ctx context.Context, u *User, comm *Communi // Text returns the text associated with a post. func (p *PostHeader) Text(ctx context.Context) (string, error) { var pd PostData - if err := amdb.GetContext(ctx, &pd, "SELECT * FROM postdata WHERE postid = ?", p.PostId); err != nil { - if err == sql.ErrNoRows { + err := amdb.GetContext(ctx, &pd, "SELECT * FROM postdata WHERE postid = ?", p.PostId) + switch err { + case nil: + if pd.Data == nil { return "", ErrNoPostData } - return "", err - } - if pd.Data == nil { + return *pd.Data, nil + case sql.ErrNoRows: return "", ErrNoPostData } - return *pd.Data, nil + return "", err } // Link returns a link string to this post. @@ -304,9 +304,8 @@ func (p *PostHeader) Scribble(ctx context.Context, u *User, comm *Community, ipa } // Reread the scribble date. - row := tx.QueryRowContext(ctx, "SELECT scribble_date FROM posts WHERE postid = ?", p.PostId) var newScribbleDate time.Time - if err = row.Scan(&newScribbleDate); err != nil { + if err = tx.GetContext(ctx, &newScribbleDate, "SELECT scribble_date FROM posts WHERE postid = ?", p.PostId); err != nil { return err } @@ -367,10 +366,9 @@ func (p *PostHeader) Nuke(ctx context.Context, u *User, comm *Community, ipaddr if _, err = tx.ExecContext(ctx, "UPDATE posts SET num = (num - 1) WHERE topicid = ? AND num > ?", p.TopicId, p.Num); err != nil { return err } - row := tx.QueryRowContext(ctx, "SELECT top_message FROM topics WHERE topicid = ?", p.TopicId) // Renumber phase 2 - reset the top message in this topic var topMessage int32 - if err = row.Scan(&topMessage); err != nil { + if err = tx.GetContext(ctx, &topMessage, "SELECT top_message FROM topics WHERE topicid = ?", p.TopicId); err != nil { return err } topMessage-- @@ -400,9 +398,8 @@ func (p *PostHeader) Publish(ctx context.Context, comm *Community, publisher *Us defer rollback() // Check if we were already published. - row := tx.QueryRowContext(ctx, "SELECT by_uid FROM postpublish WHERE postid = ?", p.PostId) var tmp int32 - err := row.Scan(&tmp) + err := tx.GetContext(ctx, &tmp, "SELECT by_uid FROM postpublish WHERE postid = ?", p.PostId) if err == nil { return errors.New("post already published") } else if err != sql.ErrNoRows { @@ -459,10 +456,8 @@ func (p *PostHeader) MoveTo(ctx context.Context, target *Topic, u *User, comm *C return err } // Read back the last update. - row := tx.QueryRowContext(ctx, "SELECT lastupdate FROM topics WHERE topicid = ?", target.TopicId) var lastUpdate time.Time - err = row.Scan(&lastUpdate) - if err != nil { + if err = tx.GetContext(ctx, &lastUpdate, "SELECT lastupdate FROM topics WHERE topicid = ?", target.TopicId); err != nil { return err } @@ -574,11 +569,10 @@ func AmNewPost(ctx context.Context, conf *Conference, topic *Topic, user *User, } // Read back the post header. - var pd PostHeader - if err := tx.GetContext(ctx, &pd, "SELECT * FROM posts WHERE postid = ?", xid); err != nil { + var hdr PostHeader + if err := tx.GetContext(ctx, &hdr, "SELECT * FROM posts WHERE postid = ?", xid); err != nil { return nil, err } - hdr := &pd // Add the post data. _, err = tx.ExecContext(ctx, "INSERT INTO postdata (postid, data) VALUES (?, ?)", hdr.PostId, post) @@ -611,7 +605,7 @@ func AmNewPost(ctx context.Context, conf *Conference, topic *Topic, user *User, AmStoreAudit(AmNewCommAudit(AuditConferencePostMessage, user.Uid, comm.Id, ipaddr, fmt.Sprintf("confid=%d", conf.ConfId), fmt.Sprintf("topic=%d", topic.Number), fmt.Sprintf("post=%d", hdr.PostId), fmt.Sprintf("pseud=%s", *hdr.Pseud))) - return hdr, nil + return &hdr, nil } /* AmGetPublishedPosts gets all posts published to the front page, up to the maximum number configured in the database. @@ -785,10 +779,10 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max } // Get the count of matching posts. - var row *sql.Row + var count int switch scope { case "global": - row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) + err = amdb.GetContext(ctx, &count, `SELECT COUNT(*) FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid @@ -796,7 +790,7 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, searchTerms) case "community": - row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) + err = amdb.GetContext(ctx, &count, `SELECT COUNT(*) FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid @@ -804,7 +798,7 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl AND q.commid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, searchTerms) case "conference": - row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) + err = amdb.GetContext(ctx, &count, `SELECT COUNT(*) FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid @@ -812,7 +806,7 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max WHERE u.uid = ? AND f.ftr_code = ? AND GREATEST(u.base_lvl,m.granted_lvl,s.granted_lvl,IFNULL(x.granted_lvl,0)) >= c.read_lvl AND q.commid = ? AND c.confid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, conf.ConfId, searchTerms) case "topic": - row = amdb.QueryRowContext(ctx, `SELECT COUNT(*) + err = amdb.GetContext(ctx, &count, `SELECT COUNT(*) FROM communities q JOIN commtoconf s ON s.commid = q.commid JOIN confs c ON c.confid = s.confid JOIN commmember m ON m.commid = q.commid JOIN users u ON u.uid = m.uid JOIN commftrs f ON f.commid = q.commid JOIN topics t ON t.confid = c.confid JOIN posts p ON p.topicid = t.topicid JOIN postdata d ON d.postid = p.postid JOIN users u2 ON u2.uid = p.creator_uid @@ -821,8 +815,6 @@ func AmSearchPosts(ctx context.Context, searchTerms string, u *User, offset, max AND q.commid = ? AND c.confid = ? AND t.topicid = ? AND p.scribble_uid IS NULL AND MATCH(d.data) AGAINST (?)`, u.Uid, confService, comm.Id, conf.ConfId, topic.TopicId, searchTerms) } - var count int - err = row.Scan(&count) if err != nil { log.Errorf("AmSearchPosts query 1 error %v", err) return nil, -1, err diff --git a/database/sidebox.go b/database/sidebox.go index 73654ab..3a17fec 100644 --- a/database/sidebox.go +++ b/database/sidebox.go @@ -91,9 +91,8 @@ func AmRemoveSidebox(ctx context.Context, uid int32, boxid int32) error { defer rollback() // Get the old sequence number. - row := tx.QueryRowContext(ctx, "SELECT sequence FROM sideboxes WHERE uid = ? AND boxid = ?", uid, boxid) var oldseq int32 - err := row.Scan(&oldseq) + err := tx.GetContext(ctx, &oldseq, "SELECT sequence FROM sideboxes WHERE uid = ? AND boxid = ?", uid, boxid) if err != nil { return err } @@ -118,9 +117,8 @@ func AmAppendSidebox(ctx context.Context, uid int32, boxid int32, param *string) tx, commit, rollback := transaction(ctx) defer rollback() - row := tx.QueryRowContext(ctx, "SELECT MAX(sequence) FROM sideboxes WHERE uid = ?", uid) var topseq int32 - err := row.Scan(&topseq) + err := tx.GetContext(ctx, &topseq, "SELECT MAX(sequence) FROM sideboxes WHERE uid = ?", uid) if err != nil { return err } diff --git a/database/topic.go b/database/topic.go index 91a6e00..4cde109 100644 --- a/database/topic.go +++ b/database/topic.go @@ -66,11 +66,10 @@ func (t *Topic) GetPost(ctx context.Context, num int32) (*PostHeader, error) { return nil, fmt.Errorf("no post %d in topic %d", num, t.TopicId) } var pd PostHeader - err := amdb.GetContext(ctx, &pd, "SELECT * FROM posts WHERE topicid = ? AND num = ?", t.TopicId, num) - if err == nil { - return &pd, nil + if err := amdb.GetContext(ctx, &pd, "SELECT * FROM posts WHERE topicid = ? AND num = ?", t.TopicId, num); err != nil { + return nil, err } - return nil, err + return &pd, nil } // GetLastRead returns the "last read" message for a user on a topic. @@ -78,9 +77,8 @@ func (t *Topic) GetLastRead(ctx context.Context, u *User) (int32, error) { if u.IsAnon { return -1, nil } - row := amdb.QueryRowContext(ctx, "SELECT last_message FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid) var rc int32 = -1 - err := row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT last_message FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid) if err == sql.ErrNoRows { return -1, nil } @@ -106,9 +104,8 @@ func (t *Topic) SetLastRead(ctx context.Context, u *User, postNum int32) error { // IsHidden tells us whether the user has the topic hidden. func (t *Topic) IsHidden(ctx context.Context, u *User) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT hidden FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid) rc := false - err := row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT hidden FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid) return rc, err } @@ -159,9 +156,8 @@ func (t *Topic) IsBozo(ctx context.Context, u *User, testUid int32) (bool, error if u.IsAnon { return false, nil } - row := amdb.QueryRowContext(ctx, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, testUid) var tmp int32 - err := row.Scan(&tmp) + err := amdb.GetContext(ctx, &tmp, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, testUid) switch err { case nil: return true, nil @@ -176,9 +172,8 @@ func (t *Topic) SetBozo(ctx context.Context, u *User, subjectUid int32, bozo boo var err error = nil if !u.IsAnon { if bozo { // Flipping the bozo bit! - row := amdb.QueryRowContext(ctx, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, subjectUid) var tmp int32 - err = row.Scan(&tmp) + err = amdb.GetContext(ctx, &tmp, "SELECT bozo_uid FROM topicbozo WHERE topicid = ? AND uid = ? AND bozo_uid = ?", t.TopicId, u.Uid, subjectUid) switch err { case nil: return nil @@ -225,9 +220,8 @@ func (t *Topic) GetBozos(ctx context.Context, u *User) ([]TopicBozo, error) { // IsSubscribed returns true if the given user is subscribed to receive E-mails of topic posts. func (t *Topic) IsSubscribed(ctx context.Context, u *User) (bool, error) { - row := amdb.QueryRowContext(ctx, "SELECT subscribe FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid) var rc bool - err := row.Scan(&rc) + err := amdb.GetContext(ctx, &rc, "SELECT subscribe FROM topicsettings WHERE topicid = ? AND uid = ?", t.TopicId, u.Uid) switch err { case nil: return rc, nil @@ -370,9 +364,8 @@ func backgroundPurgeTopic(ctx context.Context, topicid int32) error { defer rollback() // Get some stats on the posts we have to remove. - row := tx.QueryRowContext(ctx, "SELECT MAX(postid) FROM posts WHERE topicid = ?", topicid) var postMax int32 - err := row.Scan(&postMax) + err := tx.GetContext(ctx, &postMax, "SELECT MAX(postid) FROM posts WHERE topicid = ?", topicid) if err != nil { return err } diff --git a/database/user.go b/database/user.go index 3bceff7..4e0d0d7 100644 --- a/database/user.go +++ b/database/user.go @@ -32,6 +32,9 @@ import ( "golang.org/x/text/message" ) +// ErrNoUser is an error returned if the user is not found in the database. +var ErrNoUser error = errors.New("no such user") + // UserPrefs represents the user's preferences in a table (one row per user). type UserPrefs struct { Uid int32 `db:"uid"` // user ID @@ -434,19 +437,21 @@ func (u *User) SetSecurityData(ctx context.Context, baseLevel uint16, lockout, v * Standard Go error status */ func AmGetUser(ctx context.Context, uid int32) (*User, error) { - var err error = nil getUserMutex.Lock() defer getUserMutex.Unlock() - rc, ok := userCache.Get(uid) - if !ok { - var user User - if err = amdb.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid); err != nil { - return nil, err - } - rc = &user - userCache.Add(uid, rc) + if rc, ok := userCache.Get(uid); ok { + return rc.(*User), nil } - return rc.(*User), err + var user User + err := amdb.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid) + switch err { + case nil: + userCache.Add(uid, &user) + return &user, nil + case sql.ErrNoRows: + return nil, ErrNoUser + } + return nil, err } /* AmGetUserTx returns a reference to the specified user inside a transaction. @@ -459,19 +464,21 @@ func AmGetUser(ctx context.Context, uid int32) (*User, error) { * Standard Go error status */ func AmGetUserTx(ctx context.Context, tx *sqlx.Tx, uid int32) (*User, error) { - var err error = nil getUserMutex.Lock() defer getUserMutex.Unlock() - rc, ok := userCache.Get(uid) - if !ok { - var user User - if err = tx.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid); err != nil { - return nil, err - } - rc = &user - userCache.Add(uid, rc) + if rc, ok := userCache.Get(uid); ok { + return rc.(*User), nil } - return rc.(*User), err + var user User + err := tx.GetContext(ctx, &user, "SELECT * from users WHERE uid = ?", uid) + switch err { + case nil: + userCache.Add(uid, &user) + return &user, nil + case sql.ErrNoRows: + return nil, ErrNoUser + } + return nil, err } /* AmGetUserByName returns a reference to the specified user. @@ -491,25 +498,26 @@ func AmGetUserByName(ctx context.Context, name string, tx *sqlx.Tx) (*User, erro } else { err = amdb.GetContext(ctx, &user, "SELECT * FROM users WHERE username = ?", name) } - if err != nil { - return nil, err + switch err { + case nil: + getUserMutex.Lock() + defer getUserMutex.Unlock() + if rc, ok := userCache.Get(user.Uid); ok { + return rc.(*User), nil + } else { + userCache.Add(user.Uid, &user) + } + return &user, nil + case sql.ErrNoRows: + return nil, ErrNoUser } - getUserMutex.Lock() - rc, ok := userCache.Get(user.Uid) - if !ok { - rc = &user - userCache.Add(user.Uid, rc) - } - getUserMutex.Unlock() - return rc.(*User), nil + return nil, err } // getAnonUserID retrieves the UID of the "anonymous" user from the database. func getAnonUserID(ctx context.Context) (int32, error) { if anonUid < 0 { - row := amdb.QueryRowContext(ctx, "SELECT uid FROM users WHERE is_anon = 1") - err := row.Scan(&anonUid) - if err != nil { + if err := amdb.GetContext(ctx, &anonUid, "SELECT uid FROM users WHERE is_anon = 1"); err != nil { return -1, err } } @@ -708,9 +716,8 @@ func AmCreateNewUser(ctx context.Context, username string, password string, remi defer rollback() // Test if the user name is already taken. - row := tx.QueryRowContext(ctx, "SELECT uid FROM users WHERE username = ?", username) var tmpuid int32 - err := row.Scan(&tmpuid) + err := tx.GetContext(ctx, &tmpuid, "SELECT uid FROM users WHERE username = ?", username) if err == nil { log.Warnf("username \"%s\" already exists", username) return nil, errors.New("that user name already exists. Please try again") @@ -775,20 +782,18 @@ func AmCreateNewUser(ctx context.Context, username string, password string, remi // internalGetProp is a helper used by the property functions. func internalGetProp(ctx context.Context, uid int32, ndx int32) (*UserProperties, error) { - var err error = nil key := fmt.Sprintf("%d:%d", uid, ndx) getUserPropMutex.Lock() defer getUserPropMutex.Unlock() - rc, ok := userPropCache.Get(key) - if !ok { - var prop UserProperties - if err = amdb.GetContext(ctx, &prop, "SELECT * from propuser WHERE uid = ? AND ndx = ?", uid, ndx); err != nil { - return nil, err - } - rc = &prop - userPropCache.Add(key, rc) + if rc, ok := userPropCache.Get(key); ok { + return rc.(*UserProperties), nil } - return rc.(*UserProperties), nil + var prop UserProperties + if err := amdb.GetContext(ctx, &prop, "SELECT * from propuser WHERE uid = ? AND ndx = ?", uid, ndx); err != nil { + return nil, err + } + userPropCache.Add(key, &prop) + return &prop, nil } /* AmGetUserProperty retrieves the value of a user property. @@ -890,9 +895,8 @@ func AmSearchUsers(ctx context.Context, field int, oper int, term string, offset return nil, -1, errors.New("invalid operator selector") } q := queryPortion.String() - row := amdb.QueryRowContext(ctx, "SELECT COUNT(*) FROM users u, contacts c WHERE u.contactid = c.contactid AND u.is_anon = 0 AND "+q) var total int - err := row.Scan(&total) + err := amdb.GetContext(ctx, &total, "SELECT COUNT(*) FROM users u, contacts c WHERE u.contactid = c.contactid AND u.is_anon = 0 AND "+q) if err != nil { return nil, -1, err }