MM-62161: Replace SELECT * in session_store.go (#30422)

* MM-62161: Replace SELECT * in session_store.go

- Replaced all SELECT * queries with explicit column selection
- Used query builder instead of raw SQL strings where possible
- Added reusable sessionSelectQuery in the store constructor
- Added comprehensive test for GetSessionsWithActiveDeviceIds

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* Fix variable shadowing issues in session_store.go

Resolved variable shadowing by reassigning to the existing error variables instead of declaring new ones in scoped blocks.

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* Fix code formatting in session_store test file

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Jesse Hallam 2025-03-12 16:35:33 -03:00 committed by GitHub
parent 8a21016266
commit 9fc83f24b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 118 additions and 22 deletions

View file

@ -22,10 +22,20 @@ const (
type SqlSessionStore struct {
*SqlStore
sessionSelectQuery sq.SelectBuilder
}
func newSqlSessionStore(sqlStore *SqlStore) store.SessionStore {
return &SqlSessionStore{sqlStore}
s := &SqlSessionStore{
SqlStore: sqlStore,
}
s.sessionSelectQuery = s.getQueryBuilder().
Select("Id", "Token", "CreateAt", "ExpiresAt", "LastActivityAt", "UserId", "DeviceId", "Roles", "IsOAuth", "ExpiredNotify", "Props").
From("Sessions")
return s
}
func (me SqlSessionStore) Save(c request.CTX, session *model.Session) (*model.Session, error) {
@ -77,7 +87,20 @@ func (me SqlSessionStore) Save(c request.CTX, session *model.Session) (*model.Se
func (me SqlSessionStore) Get(c request.CTX, sessionIdOrToken string) (*model.Session, error) {
sessions := []*model.Session{}
if err := me.DBXFromContext(c.Context()).Select(&sessions, "SELECT * FROM Sessions WHERE Token = ? OR Id = ? LIMIT 1", sessionIdOrToken, sessionIdOrToken); err != nil {
query := me.sessionSelectQuery.
Where(sq.Or{
sq.Eq{"Token": sessionIdOrToken},
sq.Eq{"Id": sessionIdOrToken},
}).
Limit(1)
sql, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "session_get_tosql")
}
err = me.DBXFromContext(c.Context()).Select(&sessions, sql, args...)
if err != nil {
return nil, errors.Wrapf(err, "failed to find Sessions with sessionIdOrToken=%s", sessionIdOrToken)
}
if len(sessions) == 0 {
@ -103,7 +126,17 @@ func (me SqlSessionStore) Get(c request.CTX, sessionIdOrToken string) (*model.Se
func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Session, error) {
sessions := []*model.Session{}
if err := me.GetReplica().Select(&sessions, "SELECT * FROM Sessions WHERE UserId = ? ORDER BY LastActivityAt DESC", userId); err != nil {
query := me.sessionSelectQuery.
Where(sq.Eq{"UserId": userId}).
OrderBy("LastActivityAt DESC")
sql, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "session_get_sessions_tosql")
}
err = me.GetReplica().Select(&sessions, sql, args...)
if err != nil {
return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId)
}
@ -126,9 +159,7 @@ func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Se
// GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset
// are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query).
func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uint64, offset uint64) ([]*model.Session, error) {
builder := me.getQueryBuilder().
Select("*").
From("Sessions").
builder := me.sessionSelectQuery.
Where(sq.Eq{"UserId": userId}).
OrderBy("LastActivityAt DESC").
Limit(limit).
@ -146,24 +177,25 @@ func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uin
}
func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) {
lastRemovedQuery := `DeviceId != COALESCE(Props->>'last_removed_device_id', '')`
now := model.GetMillis()
// Start with the base query
builder := me.sessionSelectQuery.
Where(sq.Eq{"UserId": userId}).
Where(sq.NotEq{"ExpiresAt": 0}).
Where(sq.GtOrEq{"ExpiresAt": now}).
Where(sq.NotEq{"DeviceId": ""})
// Add the last_removed_device_id condition based on the driver
if me.DriverName() == model.DatabaseDriverMysql {
lastRemovedQuery = `DeviceId != COALESCE(Props->>'$.last_removed_device_id', '')`
builder = builder.Where("DeviceId != COALESCE(Props->>'$.last_removed_device_id', '')")
} else {
builder = builder.Where("DeviceId != COALESCE(Props->>'last_removed_device_id', '')")
}
query :=
`SELECT *
FROM
Sessions
WHERE
UserId = ? AND
ExpiresAt != 0 AND
? <= ExpiresAt AND
DeviceId != '' AND
` + lastRemovedQuery
sessions := []*model.Session{}
if err := me.GetReplica().Select(&sessions, query, userId, model.GetMillis()); err != nil {
if err := me.GetReplica().SelectBuilder(&sessions, builder); err != nil {
return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId)
}
return sessions, nil
@ -203,9 +235,7 @@ func (me SqlSessionStore) GetMobileSessionMetadata() ([]*model.MobileSessionMeta
func (me SqlSessionStore) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) {
now := model.GetMillis()
builder := me.getQueryBuilder().
Select("*").
From("Sessions").
builder := me.sessionSelectQuery.
Where(sq.NotEq{"ExpiresAt": 0}).
Where(sq.Lt{"ExpiresAt": now}).
Where(sq.Gt{"ExpiresAt": now - thresholdMillis})

View file

@ -38,6 +38,7 @@ func TestSessionStore(t *testing.T, rctx request.CTX, ss store.Store) {
t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, rctx, ss) })
t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, rctx, ss) })
t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, rctx, ss) })
t.Run("GetSessionsWithActiveDeviceIds", func(t *testing.T) { testGetSessionsWithActiveDeviceIds(t, rctx, ss) })
t.Run("GetMobileSessionMetadata", func(t *testing.T) { testGetMobileSessionMetadata(t, rctx, ss) })
}
@ -397,6 +398,71 @@ func testGetSessionsExpired(t *testing.T, rctx request.CTX, ss store.Store) {
}
}
func testGetSessionsWithActiveDeviceIds(t *testing.T, rctx request.CTX, ss store.Store) {
userId := model.NewId()
// Create session 1 with a device ID
s1 := &model.Session{}
s1.UserId = userId
s1.ExpiresAt = model.GetMillis() + 100000
s1.DeviceId = model.NewId()
s1, err := ss.Session().Save(rctx, s1)
require.NoError(t, err)
// Create session 2 with a device ID and a prop for last_removed_device_id that doesn't match the device ID
s2 := &model.Session{}
s2.UserId = userId
s2.ExpiresAt = model.GetMillis() + 100000
s2.DeviceId = model.NewId()
s2.AddProp(model.SessionPropLastRemovedDeviceId, model.NewId())
s2, err = ss.Session().Save(rctx, s2)
require.NoError(t, err)
// Create session 3 with a device ID and a prop for last_removed_device_id that matches the device ID - this should be filtered out
s3 := &model.Session{}
s3.UserId = userId
s3.ExpiresAt = model.GetMillis() + 100000
s3.DeviceId = model.NewId()
s3.AddProp(model.SessionPropLastRemovedDeviceId, s3.DeviceId)
s3, err = ss.Session().Save(rctx, s3)
require.NoError(t, err)
// Create session 4 with no device ID - this should be filtered out
s4 := &model.Session{}
s4.UserId = userId
s4.ExpiresAt = model.GetMillis() + 100000
s4, err = ss.Session().Save(rctx, s4)
require.NoError(t, err)
// Create session 5 with a device ID but expired - this should be filtered out
s5 := &model.Session{}
s5.UserId = userId
s5.ExpiresAt = model.GetMillis() - 100000
s5.DeviceId = model.NewId()
s5, err = ss.Session().Save(rctx, s5)
require.NoError(t, err)
// Get sessions with active device IDs
sessions, err := ss.Session().GetSessionsWithActiveDeviceIds(userId)
require.NoError(t, err)
// We should have 2 sessions (s1 and s2)
require.Len(t, sessions, 2)
// Verify s1 and s2 are in the result
sessionIds := make(map[string]bool)
for _, session := range sessions {
sessionIds[session.Id] = true
}
require.True(t, sessionIds[s1.Id])
require.True(t, sessionIds[s2.Id])
// Verify s3, s4, and s5 are not in the result
require.False(t, sessionIds[s3.Id])
require.False(t, sessionIds[s4.Id])
require.False(t, sessionIds[s5.Id])
}
func testUpdateExpiredNotify(t *testing.T, rctx request.CTX, ss store.Store) {
s1 := &model.Session{}
s1.UserId = model.NewId()