diff --git a/server/channels/store/sqlstore/session_store.go b/server/channels/store/sqlstore/session_store.go index 10da6609e32..18295d51f11 100644 --- a/server/channels/store/sqlstore/session_store.go +++ b/server/channels/store/sqlstore/session_store.go @@ -22,10 +22,20 @@ const ( type SqlSessionStore struct { *SqlStore + + sessionSelectQuery sq.SelectBuilder } func newSqlSessionStore(sqlStore *SqlStore) store.SessionStore { - return &SqlSessionStore{sqlStore} + s := &SqlSessionStore{ + SqlStore: sqlStore, + } + + s.sessionSelectQuery = s.getQueryBuilder(). + Select("Id", "Token", "CreateAt", "ExpiresAt", "LastActivityAt", "UserId", "DeviceId", "Roles", "IsOAuth", "ExpiredNotify", "Props"). + From("Sessions") + + return s } func (me SqlSessionStore) Save(c request.CTX, session *model.Session) (*model.Session, error) { @@ -77,7 +87,20 @@ func (me SqlSessionStore) Save(c request.CTX, session *model.Session) (*model.Se func (me SqlSessionStore) Get(c request.CTX, sessionIdOrToken string) (*model.Session, error) { sessions := []*model.Session{} - if err := me.DBXFromContext(c.Context()).Select(&sessions, "SELECT * FROM Sessions WHERE Token = ? OR Id = ? LIMIT 1", sessionIdOrToken, sessionIdOrToken); err != nil { + query := me.sessionSelectQuery. + Where(sq.Or{ + sq.Eq{"Token": sessionIdOrToken}, + sq.Eq{"Id": sessionIdOrToken}, + }). + Limit(1) + + sql, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "session_get_tosql") + } + + err = me.DBXFromContext(c.Context()).Select(&sessions, sql, args...) + if err != nil { return nil, errors.Wrapf(err, "failed to find Sessions with sessionIdOrToken=%s", sessionIdOrToken) } if len(sessions) == 0 { @@ -103,7 +126,17 @@ func (me SqlSessionStore) Get(c request.CTX, sessionIdOrToken string) (*model.Se func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Session, error) { sessions := []*model.Session{} - if err := me.GetReplica().Select(&sessions, "SELECT * FROM Sessions WHERE UserId = ? ORDER BY LastActivityAt DESC", userId); err != nil { + query := me.sessionSelectQuery. + Where(sq.Eq{"UserId": userId}). + OrderBy("LastActivityAt DESC") + + sql, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "session_get_sessions_tosql") + } + + err = me.GetReplica().Select(&sessions, sql, args...) + if err != nil { return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId) } @@ -126,9 +159,7 @@ func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Se // GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset // are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query). func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uint64, offset uint64) ([]*model.Session, error) { - builder := me.getQueryBuilder(). - Select("*"). - From("Sessions"). + builder := me.sessionSelectQuery. Where(sq.Eq{"UserId": userId}). OrderBy("LastActivityAt DESC"). Limit(limit). @@ -146,24 +177,25 @@ func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uin } func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) { - lastRemovedQuery := `DeviceId != COALESCE(Props->>'last_removed_device_id', '')` + now := model.GetMillis() + + // Start with the base query + builder := me.sessionSelectQuery. + Where(sq.Eq{"UserId": userId}). + Where(sq.NotEq{"ExpiresAt": 0}). + Where(sq.GtOrEq{"ExpiresAt": now}). + Where(sq.NotEq{"DeviceId": ""}) + + // Add the last_removed_device_id condition based on the driver if me.DriverName() == model.DatabaseDriverMysql { - lastRemovedQuery = `DeviceId != COALESCE(Props->>'$.last_removed_device_id', '')` + builder = builder.Where("DeviceId != COALESCE(Props->>'$.last_removed_device_id', '')") + } else { + builder = builder.Where("DeviceId != COALESCE(Props->>'last_removed_device_id', '')") } - query := - `SELECT * - FROM - Sessions - WHERE - UserId = ? AND - ExpiresAt != 0 AND - ? <= ExpiresAt AND - DeviceId != '' AND - ` + lastRemovedQuery sessions := []*model.Session{} - if err := me.GetReplica().Select(&sessions, query, userId, model.GetMillis()); err != nil { + if err := me.GetReplica().SelectBuilder(&sessions, builder); err != nil { return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId) } return sessions, nil @@ -203,9 +235,7 @@ func (me SqlSessionStore) GetMobileSessionMetadata() ([]*model.MobileSessionMeta func (me SqlSessionStore) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) { now := model.GetMillis() - builder := me.getQueryBuilder(). - Select("*"). - From("Sessions"). + builder := me.sessionSelectQuery. Where(sq.NotEq{"ExpiresAt": 0}). Where(sq.Lt{"ExpiresAt": now}). Where(sq.Gt{"ExpiresAt": now - thresholdMillis}) diff --git a/server/channels/store/storetest/session_store.go b/server/channels/store/storetest/session_store.go index c1042d26228..6d3888b0cd8 100644 --- a/server/channels/store/storetest/session_store.go +++ b/server/channels/store/storetest/session_store.go @@ -38,6 +38,7 @@ func TestSessionStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, rctx, ss) }) t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, rctx, ss) }) t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, rctx, ss) }) + t.Run("GetSessionsWithActiveDeviceIds", func(t *testing.T) { testGetSessionsWithActiveDeviceIds(t, rctx, ss) }) t.Run("GetMobileSessionMetadata", func(t *testing.T) { testGetMobileSessionMetadata(t, rctx, ss) }) } @@ -397,6 +398,71 @@ func testGetSessionsExpired(t *testing.T, rctx request.CTX, ss store.Store) { } } +func testGetSessionsWithActiveDeviceIds(t *testing.T, rctx request.CTX, ss store.Store) { + userId := model.NewId() + + // Create session 1 with a device ID + s1 := &model.Session{} + s1.UserId = userId + s1.ExpiresAt = model.GetMillis() + 100000 + s1.DeviceId = model.NewId() + s1, err := ss.Session().Save(rctx, s1) + require.NoError(t, err) + + // Create session 2 with a device ID and a prop for last_removed_device_id that doesn't match the device ID + s2 := &model.Session{} + s2.UserId = userId + s2.ExpiresAt = model.GetMillis() + 100000 + s2.DeviceId = model.NewId() + s2.AddProp(model.SessionPropLastRemovedDeviceId, model.NewId()) + s2, err = ss.Session().Save(rctx, s2) + require.NoError(t, err) + + // Create session 3 with a device ID and a prop for last_removed_device_id that matches the device ID - this should be filtered out + s3 := &model.Session{} + s3.UserId = userId + s3.ExpiresAt = model.GetMillis() + 100000 + s3.DeviceId = model.NewId() + s3.AddProp(model.SessionPropLastRemovedDeviceId, s3.DeviceId) + s3, err = ss.Session().Save(rctx, s3) + require.NoError(t, err) + + // Create session 4 with no device ID - this should be filtered out + s4 := &model.Session{} + s4.UserId = userId + s4.ExpiresAt = model.GetMillis() + 100000 + s4, err = ss.Session().Save(rctx, s4) + require.NoError(t, err) + + // Create session 5 with a device ID but expired - this should be filtered out + s5 := &model.Session{} + s5.UserId = userId + s5.ExpiresAt = model.GetMillis() - 100000 + s5.DeviceId = model.NewId() + s5, err = ss.Session().Save(rctx, s5) + require.NoError(t, err) + + // Get sessions with active device IDs + sessions, err := ss.Session().GetSessionsWithActiveDeviceIds(userId) + require.NoError(t, err) + + // We should have 2 sessions (s1 and s2) + require.Len(t, sessions, 2) + + // Verify s1 and s2 are in the result + sessionIds := make(map[string]bool) + for _, session := range sessions { + sessionIds[session.Id] = true + } + require.True(t, sessionIds[s1.Id]) + require.True(t, sessionIds[s2.Id]) + + // Verify s3, s4, and s5 are not in the result + require.False(t, sessionIds[s3.Id]) + require.False(t, sessionIds[s4.Id]) + require.False(t, sessionIds[s5.Id]) +} + func testUpdateExpiredNotify(t *testing.T, rctx request.CTX, ss store.Store) { s1 := &model.Session{} s1.UserId = model.NewId()