mirror of
https://github.com/mattermost/mattermost.git
synced 2026-05-28 04:35:04 -04:00
MM-62161: Replace SELECT * in session_store.go (#30422)
* MM-62161: Replace SELECT * in session_store.go - Replaced all SELECT * queries with explicit column selection - Used query builder instead of raw SQL strings where possible - Added reusable sessionSelectQuery in the store constructor - Added comprehensive test for GetSessionsWithActiveDeviceIds 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix variable shadowing issues in session_store.go Resolved variable shadowing by reassigning to the existing error variables instead of declaring new ones in scoped blocks. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix code formatting in session_store test file 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
8a21016266
commit
9fc83f24b5
2 changed files with 118 additions and 22 deletions
|
|
@ -22,10 +22,20 @@ const (
|
|||
|
||||
type SqlSessionStore struct {
|
||||
*SqlStore
|
||||
|
||||
sessionSelectQuery sq.SelectBuilder
|
||||
}
|
||||
|
||||
func newSqlSessionStore(sqlStore *SqlStore) store.SessionStore {
|
||||
return &SqlSessionStore{sqlStore}
|
||||
s := &SqlSessionStore{
|
||||
SqlStore: sqlStore,
|
||||
}
|
||||
|
||||
s.sessionSelectQuery = s.getQueryBuilder().
|
||||
Select("Id", "Token", "CreateAt", "ExpiresAt", "LastActivityAt", "UserId", "DeviceId", "Roles", "IsOAuth", "ExpiredNotify", "Props").
|
||||
From("Sessions")
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (me SqlSessionStore) Save(c request.CTX, session *model.Session) (*model.Session, error) {
|
||||
|
|
@ -77,7 +87,20 @@ func (me SqlSessionStore) Save(c request.CTX, session *model.Session) (*model.Se
|
|||
func (me SqlSessionStore) Get(c request.CTX, sessionIdOrToken string) (*model.Session, error) {
|
||||
sessions := []*model.Session{}
|
||||
|
||||
if err := me.DBXFromContext(c.Context()).Select(&sessions, "SELECT * FROM Sessions WHERE Token = ? OR Id = ? LIMIT 1", sessionIdOrToken, sessionIdOrToken); err != nil {
|
||||
query := me.sessionSelectQuery.
|
||||
Where(sq.Or{
|
||||
sq.Eq{"Token": sessionIdOrToken},
|
||||
sq.Eq{"Id": sessionIdOrToken},
|
||||
}).
|
||||
Limit(1)
|
||||
|
||||
sql, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "session_get_tosql")
|
||||
}
|
||||
|
||||
err = me.DBXFromContext(c.Context()).Select(&sessions, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find Sessions with sessionIdOrToken=%s", sessionIdOrToken)
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
|
|
@ -103,7 +126,17 @@ func (me SqlSessionStore) Get(c request.CTX, sessionIdOrToken string) (*model.Se
|
|||
func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Session, error) {
|
||||
sessions := []*model.Session{}
|
||||
|
||||
if err := me.GetReplica().Select(&sessions, "SELECT * FROM Sessions WHERE UserId = ? ORDER BY LastActivityAt DESC", userId); err != nil {
|
||||
query := me.sessionSelectQuery.
|
||||
Where(sq.Eq{"UserId": userId}).
|
||||
OrderBy("LastActivityAt DESC")
|
||||
|
||||
sql, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "session_get_sessions_tosql")
|
||||
}
|
||||
|
||||
err = me.GetReplica().Select(&sessions, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId)
|
||||
}
|
||||
|
||||
|
|
@ -126,9 +159,7 @@ func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Se
|
|||
// GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset
|
||||
// are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query).
|
||||
func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uint64, offset uint64) ([]*model.Session, error) {
|
||||
builder := me.getQueryBuilder().
|
||||
Select("*").
|
||||
From("Sessions").
|
||||
builder := me.sessionSelectQuery.
|
||||
Where(sq.Eq{"UserId": userId}).
|
||||
OrderBy("LastActivityAt DESC").
|
||||
Limit(limit).
|
||||
|
|
@ -146,24 +177,25 @@ func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uin
|
|||
}
|
||||
|
||||
func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) {
|
||||
lastRemovedQuery := `DeviceId != COALESCE(Props->>'last_removed_device_id', '')`
|
||||
now := model.GetMillis()
|
||||
|
||||
// Start with the base query
|
||||
builder := me.sessionSelectQuery.
|
||||
Where(sq.Eq{"UserId": userId}).
|
||||
Where(sq.NotEq{"ExpiresAt": 0}).
|
||||
Where(sq.GtOrEq{"ExpiresAt": now}).
|
||||
Where(sq.NotEq{"DeviceId": ""})
|
||||
|
||||
// Add the last_removed_device_id condition based on the driver
|
||||
if me.DriverName() == model.DatabaseDriverMysql {
|
||||
lastRemovedQuery = `DeviceId != COALESCE(Props->>'$.last_removed_device_id', '')`
|
||||
builder = builder.Where("DeviceId != COALESCE(Props->>'$.last_removed_device_id', '')")
|
||||
} else {
|
||||
builder = builder.Where("DeviceId != COALESCE(Props->>'last_removed_device_id', '')")
|
||||
}
|
||||
query :=
|
||||
`SELECT *
|
||||
FROM
|
||||
Sessions
|
||||
WHERE
|
||||
UserId = ? AND
|
||||
ExpiresAt != 0 AND
|
||||
? <= ExpiresAt AND
|
||||
DeviceId != '' AND
|
||||
` + lastRemovedQuery
|
||||
|
||||
sessions := []*model.Session{}
|
||||
|
||||
if err := me.GetReplica().Select(&sessions, query, userId, model.GetMillis()); err != nil {
|
||||
if err := me.GetReplica().SelectBuilder(&sessions, builder); err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId)
|
||||
}
|
||||
return sessions, nil
|
||||
|
|
@ -203,9 +235,7 @@ func (me SqlSessionStore) GetMobileSessionMetadata() ([]*model.MobileSessionMeta
|
|||
|
||||
func (me SqlSessionStore) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) {
|
||||
now := model.GetMillis()
|
||||
builder := me.getQueryBuilder().
|
||||
Select("*").
|
||||
From("Sessions").
|
||||
builder := me.sessionSelectQuery.
|
||||
Where(sq.NotEq{"ExpiresAt": 0}).
|
||||
Where(sq.Lt{"ExpiresAt": now}).
|
||||
Where(sq.Gt{"ExpiresAt": now - thresholdMillis})
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ func TestSessionStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
|||
t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, rctx, ss) })
|
||||
t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, rctx, ss) })
|
||||
t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, rctx, ss) })
|
||||
t.Run("GetSessionsWithActiveDeviceIds", func(t *testing.T) { testGetSessionsWithActiveDeviceIds(t, rctx, ss) })
|
||||
t.Run("GetMobileSessionMetadata", func(t *testing.T) { testGetMobileSessionMetadata(t, rctx, ss) })
|
||||
}
|
||||
|
||||
|
|
@ -397,6 +398,71 @@ func testGetSessionsExpired(t *testing.T, rctx request.CTX, ss store.Store) {
|
|||
}
|
||||
}
|
||||
|
||||
func testGetSessionsWithActiveDeviceIds(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
userId := model.NewId()
|
||||
|
||||
// Create session 1 with a device ID
|
||||
s1 := &model.Session{}
|
||||
s1.UserId = userId
|
||||
s1.ExpiresAt = model.GetMillis() + 100000
|
||||
s1.DeviceId = model.NewId()
|
||||
s1, err := ss.Session().Save(rctx, s1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create session 2 with a device ID and a prop for last_removed_device_id that doesn't match the device ID
|
||||
s2 := &model.Session{}
|
||||
s2.UserId = userId
|
||||
s2.ExpiresAt = model.GetMillis() + 100000
|
||||
s2.DeviceId = model.NewId()
|
||||
s2.AddProp(model.SessionPropLastRemovedDeviceId, model.NewId())
|
||||
s2, err = ss.Session().Save(rctx, s2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create session 3 with a device ID and a prop for last_removed_device_id that matches the device ID - this should be filtered out
|
||||
s3 := &model.Session{}
|
||||
s3.UserId = userId
|
||||
s3.ExpiresAt = model.GetMillis() + 100000
|
||||
s3.DeviceId = model.NewId()
|
||||
s3.AddProp(model.SessionPropLastRemovedDeviceId, s3.DeviceId)
|
||||
s3, err = ss.Session().Save(rctx, s3)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create session 4 with no device ID - this should be filtered out
|
||||
s4 := &model.Session{}
|
||||
s4.UserId = userId
|
||||
s4.ExpiresAt = model.GetMillis() + 100000
|
||||
s4, err = ss.Session().Save(rctx, s4)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create session 5 with a device ID but expired - this should be filtered out
|
||||
s5 := &model.Session{}
|
||||
s5.UserId = userId
|
||||
s5.ExpiresAt = model.GetMillis() - 100000
|
||||
s5.DeviceId = model.NewId()
|
||||
s5, err = ss.Session().Save(rctx, s5)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get sessions with active device IDs
|
||||
sessions, err := ss.Session().GetSessionsWithActiveDeviceIds(userId)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should have 2 sessions (s1 and s2)
|
||||
require.Len(t, sessions, 2)
|
||||
|
||||
// Verify s1 and s2 are in the result
|
||||
sessionIds := make(map[string]bool)
|
||||
for _, session := range sessions {
|
||||
sessionIds[session.Id] = true
|
||||
}
|
||||
require.True(t, sessionIds[s1.Id])
|
||||
require.True(t, sessionIds[s2.Id])
|
||||
|
||||
// Verify s3, s4, and s5 are not in the result
|
||||
require.False(t, sessionIds[s3.Id])
|
||||
require.False(t, sessionIds[s4.Id])
|
||||
require.False(t, sessionIds[s5.Id])
|
||||
}
|
||||
|
||||
func testUpdateExpiredNotify(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
s1 := &model.Session{}
|
||||
s1.UserId = model.NewId()
|
||||
|
|
|
|||
Loading…
Reference in a new issue