diff --git a/.gitignore b/.gitignore index 1a13e8e..e78a3fb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ amsterdam __debug_bin* +# Ignore test files +/test.yaml \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 848c589..79982f5 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "program": "${workspaceFolder}" }, { - "name": "Amsterdam Test Config", + "name": "Test Config", "type": "go", "request": "launch", "mode": "auto", diff --git a/ui/amcontext.go b/ui/amcontext.go index 31a2447..4bdb23e 100644 --- a/ui/amcontext.go +++ b/ui/amcontext.go @@ -135,7 +135,7 @@ func (c *amContext) Ctx() context.Context { // CurrentCommunity returns the current community, if one's been set. func (c *amContext) CurrentCommunity() *database.Community { if c.community == nil { - cv, ok := c.session.Values["lastCommunity"] + cv, ok := AmSessionGet(c.session, "lastCommunity") if ok && !c.CurrentUser().IsAnon { c.SetCommunityContext(fmt.Sprintf("%d", cv)) } @@ -146,7 +146,14 @@ func (c *amContext) CurrentCommunity() *database.Community { // CurrentUser returns the current user from the session. func (c *amContext) CurrentUser() *database.User { if c.user == nil { - u, err := database.AmGetUser(c.echoContext.Request().Context(), AmSessionUid(c.session)) + id, ok := AmSessionUid(c.session) + var err error + var u *database.User + if ok { + u, err = database.AmGetUser(c.echoContext.Request().Context(), id) + } else { + u, err = database.AmGetAnonUser(c.echoContext.Request().Context()) + } if err != nil { log.Errorf("unable to retrieve current user") } @@ -158,7 +165,15 @@ func (c *amContext) CurrentUser() *database.User { // CurrentUserId returns the current user ID. func (c *amContext) CurrentUserId() int32 { - return AmSessionUid(c.session) + rc, ok := AmSessionUid(c.session) + if ok { + return rc + } + u, err := database.AmGetAnonUser(c.echoContext.Request().Context()) + if err == nil { + return u.Uid + } + return 0 } // EffectiveLevel returns the user's effective access level (in terms of current community, if any). @@ -241,7 +256,12 @@ func (c *amContext) IsMemberLocked() bool { // LeftMenu returns the current left menu selector. func (c *amContext) LeftMenu() string { - return c.session.Values["leftMenu"].(string) + rc, ok := AmSessionGet(c.session, "leftMenu") + if ok { + return rc.(string) + } else { + return "top" + } } // RC returns the HTTP result code for the current operation. @@ -370,7 +390,7 @@ func (c *amContext) SetCommunityContext(param string) error { c.effectiveLevel = level } if mbr { - c.session.Values["lastCommunity"] = comm.Id + AmSessionPut(c.session, "lastCommunity", comm.Id) } } return nil @@ -378,7 +398,7 @@ func (c *amContext) SetCommunityContext(param string) error { // SetLeftMenu sets the current topmost left menu name value. func (c *amContext) SetLeftMenu(name string) { - c.session.Values["leftMenu"] = name + AmSessionPut(c.session, "leftMenu", name) } /* SetLoginCookie adds the login cookie to the result output. @@ -416,17 +436,18 @@ func (c *amContext) SetScratch(name string, val any) { // GetSession returns a session variable. func (c *amContext) GetSession(name string) any { - return c.session.Values["x."+name] + rc, _ := AmSessionGet(c.session, "x."+name) + return rc } // SetSession sets a session variable. func (c *amContext) SetSession(name string, value any) { - c.session.Values["x."+name] = value + AmSessionPut(c.session, "x."+name, value) } // IsSession tests to see whether a session value is set. func (c *amContext) IsSession(name string) bool { - _, ok := c.session.Values["x."+name] + _, ok := AmSessionGet(c.session, "x."+name) return ok } @@ -507,15 +528,21 @@ func newContext(ctxt echo.Context) (*amContext, error) { AmHitSession(sess) } } - rc.user, err = database.AmGetUser(ctxt.Request().Context(), AmSessionUid(sess)) - if err == nil { - rc.effectiveLevel = rc.user.BaseLevel + id, ok := AmSessionUid(sess) + if ok { + rc.user, err = database.AmGetUser(ctxt.Request().Context(), id) + if err == nil { + rc.effectiveLevel = rc.user.BaseLevel + } else { + rc.user = nil + rc.effectiveLevel = database.AmRole("NotInList").Level() + } } else { rc.user = nil rc.effectiveLevel = database.AmRole("NotInList").Level() } - if !rc.user.IsAnon { - cp, ok := sess.Values["lastCommunity"] + if rc.user != nil && !rc.user.IsAnon { + cp, ok := AmSessionGet(sess, "lastCommunity") if ok { rc.SetCommunityContext(fmt.Sprintf("%d", cp)) } @@ -532,8 +559,11 @@ func newContext(ctxt echo.Context) (*amContext, error) { func AmContextFromEchoContext(ctxt echo.Context) AmContext { myctxt := ctxt.Get("__amsterdam_context") if myctxt != nil { - rc, ok := myctxt.(AmContext) + rc, ok := myctxt.(*amContext) if ok { + if rc.echoContext == nil { + rc.echoContext = ctxt + } return rc } } diff --git a/ui/session_mgr.go b/ui/session_mgr.go index 594c7b1..6913349 100644 --- a/ui/session_mgr.go +++ b/ui/session_mgr.go @@ -27,6 +27,39 @@ import ( log "github.com/sirupsen/logrus" ) +func AmSessionGet(sess *sessions.Session, key any) (any, bool) { + if sess == nil { + return 0, false + } + mtx := sess.Values["_mutex"].(*sync.RWMutex) + mtx.RLock() + defer mtx.RUnlock() + rc, ok := sess.Values[key] + return rc, ok +} + +func AmSessionPut(sess *sessions.Session, key, value any) { + if sess != nil { + mtx := sess.Values["_mutex"].(*sync.RWMutex) + mtx.Lock() + defer mtx.Unlock() + sess.Values[key] = value + } +} + +func AmSessionErase(sess *sessions.Session) { + if sess != nil { + mtx := sess.Values["_mutex"].(*sync.RWMutex) + mtx.Lock() + defer mtx.Unlock() + for k := range sess.Values { + if k != "_mutex" { + delete(sess.Values, k) + } + } + } +} + // AmsterdamStore is our implewmentation of the Gorilla session store that works close to HttpSession in Java. type AmsterdamStore struct { mutex sync.RWMutex @@ -81,6 +114,7 @@ func (st *AmsterdamStore) Get(r *http.Request, name string) (*sessions.Session, func (st *AmsterdamStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(st, name) session.IsNew = true + session.Values["_mutex"] = new(sync.RWMutex) idBytes := make([]byte, 32) if _, err := rand.Read(idBytes); err != nil { return nil, err @@ -121,7 +155,7 @@ func (st *AmsterdamStore) sweep(tick <-chan time.Time, done chan bool) { st.mutex.RLock() zap := make([]string, 0, len(st.sessions)) for k, v := range st.sessions { - lastTime, ok := v.Values["lasthit"] + lastTime, ok := AmSessionGet(v, "lasthit") if ok && time.Since(lastTime.(time.Time)) > st.expiry { zap = append(zap, k) } @@ -134,9 +168,7 @@ func (st *AmsterdamStore) sweep(tick <-chan time.Time, done chan bool) { s, ok := st.sessions[k] if ok { delete(st.sessions, k) - for q := range s.Values { - delete(s.Values, q) - } + AmSessionErase(s) } st.mutex.Unlock() } @@ -153,10 +185,12 @@ func (st *AmsterdamStore) sessionInfo() (int, []string, int) { users := make([]string, 0, len(st.sessions)) st.mutex.RLock() for _, s := range st.sessions { - if s.Values["user_anon"].(bool) { + v, ok := AmSessionGet(s, "user_anon") + if ok && v.(bool) { anons++ } else { - users = append(users, s.Values["user_name"].(string)) + name, _ := AmSessionGet(s, "user_name") + users = append(users, name.(string)) } } st.mutex.RUnlock() @@ -205,8 +239,13 @@ func SetupSessionManager() func() { } // AmSessionUid returns the current user ID of the session. -func AmSessionUid(session *sessions.Session) int32 { - return session.Values["user_id"].(int32) +func AmSessionUid(session *sessions.Session) (int32, bool) { + rc, ok := AmSessionGet(session, "user_id") + if ok { + return rc.(int32), ok + } else { + return -1, ok + } } /* AmSetSessionUser sets the user for the session. @@ -215,9 +254,9 @@ func AmSessionUid(session *sessions.Session) int32 { * user - The user to be associated with the session. */ func AmSetSessionUser(session *sessions.Session, user *database.User) { - session.Values["user_id"] = user.Uid - session.Values["user_name"] = user.Username - session.Values["user_anon"] = user.IsAnon + AmSessionPut(session, "user_id", user.Uid) + AmSessionPut(session, "user_name", user.Username) + AmSessionPut(session, "user_anon", user.IsAnon) } // setSessionAnon sets the user for the session to the anonymous user. @@ -236,25 +275,23 @@ var lastHitMutex sync.Mutex func AmSessionFirstTime(ctx context.Context, session *sessions.Session) { lastHitMutex.Lock() setSessionAnon(ctx, session) - session.Values["lasthit"] = time.Now() + AmSessionPut(session, "lasthit", time.Now()) lastHitMutex.Unlock() } // AmResetSession clears the specified session. func AmResetSession(ctx context.Context, session *sessions.Session) { lastHitMutex.Lock() - for k := range session.Values { - delete(session.Values, k) - } + AmSessionErase(session) setSessionAnon(ctx, session) - session.Values["lasthit"] = time.Now() + AmSessionPut(session, "lasthit", time.Now()) lastHitMutex.Unlock() } // AmHitSession "hits" a session, updating its "last hit" time. func AmHitSession(session *sessions.Session) { lastHitMutex.Lock() - session.Values["lasthit"] = time.Now() + AmSessionPut(session, "lasthit", time.Now()) lastHitMutex.Unlock() }