diff --git a/server/channels/store/sqlstore/user_access_token_store.go b/server/channels/store/sqlstore/user_access_token_store.go index 1b2db40ca31..1a3ca78b8b9 100644 --- a/server/channels/store/sqlstore/user_access_token_store.go +++ b/server/channels/store/sqlstore/user_access_token_store.go @@ -7,6 +7,7 @@ import ( "database/sql" "fmt" + sq "github.com/mattermost/squirrel" "github.com/pkg/errors" "github.com/mattermost/mattermost/server/public/model" @@ -15,10 +16,26 @@ import ( type SqlUserAccessTokenStore struct { *SqlStore + + userAccessTokensSelectQuery sq.SelectBuilder } func newSqlUserAccessTokenStore(sqlStore *SqlStore) store.UserAccessTokenStore { - return &SqlUserAccessTokenStore{sqlStore} + s := &SqlUserAccessTokenStore{ + SqlStore: sqlStore, + } + + s.userAccessTokensSelectQuery = s.getQueryBuilder(). + Select( + "UserAccessTokens.Id", + "UserAccessTokens.Token", + "UserAccessTokens.UserId", + "UserAccessTokens.Description", + "UserAccessTokens.IsActive", + ). + From("UserAccessTokens") + + return s } func (s SqlUserAccessTokenStore) Save(token *model.UserAccessToken) (*model.UserAccessToken, error) { @@ -125,7 +142,9 @@ func (s SqlUserAccessTokenStore) deleteTokensByUser(transaction *sqlxTxWrapper, func (s SqlUserAccessTokenStore) Get(tokenId string) (*model.UserAccessToken, error) { var token model.UserAccessToken - if err := s.GetReplica().Get(&token, "SELECT * FROM UserAccessTokens WHERE Id = ?", tokenId); err != nil { + query := s.userAccessTokensSelectQuery.Where(sq.Eq{"Id": tokenId}) + + if err := s.GetReplica().GetBuilder(&token, query); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("UserAccessToken", tokenId) } @@ -138,7 +157,11 @@ func (s SqlUserAccessTokenStore) Get(tokenId string) (*model.UserAccessToken, er func (s SqlUserAccessTokenStore) GetAll(offset, limit int) ([]*model.UserAccessToken, error) { tokens := []*model.UserAccessToken{} - if err := s.GetReplica().Select(&tokens, "SELECT * FROM UserAccessTokens LIMIT ? OFFSET ?", limit, offset); err != nil { + query := s.userAccessTokensSelectQuery. + Limit(uint64(limit)). + Offset(uint64(offset)) + + if err := s.GetReplica().SelectBuilder(&tokens, query); err != nil { return nil, errors.Wrap(err, "failed to find UserAccessTokens") } @@ -148,7 +171,9 @@ func (s SqlUserAccessTokenStore) GetAll(offset, limit int) ([]*model.UserAccessT func (s SqlUserAccessTokenStore) GetByToken(tokenString string) (*model.UserAccessToken, error) { var token model.UserAccessToken - if err := s.GetReplica().Get(&token, "SELECT * FROM UserAccessTokens WHERE Token = ?", tokenString); err != nil { + query := s.userAccessTokensSelectQuery.Where(sq.Eq{"Token": tokenString}) + + if err := s.GetReplica().GetBuilder(&token, query); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("UserAccessToken", fmt.Sprintf("token=%s", tokenString)) } @@ -161,7 +186,12 @@ func (s SqlUserAccessTokenStore) GetByToken(tokenString string) (*model.UserAcce func (s SqlUserAccessTokenStore) GetByUser(userId string, offset, limit int) ([]*model.UserAccessToken, error) { tokens := []*model.UserAccessToken{} - if err := s.GetReplica().Select(&tokens, "SELECT * FROM UserAccessTokens WHERE UserId = ? LIMIT ? OFFSET ?", userId, limit, offset); err != nil { + query := s.userAccessTokensSelectQuery. + Where(sq.Eq{"UserId": userId}). + Limit(uint64(limit)). + Offset(uint64(offset)) + + if err := s.GetReplica().SelectBuilder(&tokens, query); err != nil { return nil, errors.Wrapf(err, "failed to find UserAccessTokens with userId=%s", userId) } @@ -171,16 +201,16 @@ func (s SqlUserAccessTokenStore) GetByUser(userId string, offset, limit int) ([] func (s SqlUserAccessTokenStore) Search(term string) ([]*model.UserAccessToken, error) { term = sanitizeSearchTerm(term, "\\") tokens := []*model.UserAccessToken{} - params := []any{term, term, term} - query := ` - SELECT - uat.* - FROM UserAccessTokens uat - INNER JOIN Users u - ON uat.UserId = u.Id - WHERE uat.Id LIKE ? OR uat.UserId LIKE ? OR u.Username LIKE ?` - if err := s.GetReplica().Select(&tokens, query, params...); err != nil { + query := s.userAccessTokensSelectQuery. + InnerJoin("Users ON UserAccessTokens.UserId = Users.Id"). + Where(sq.Or{ + sq.Like{"UserAccessTokens.Id": term}, + sq.Like{"UserAccessTokens.UserId": term}, + sq.Like{"Users.Username": term}, + }) + + if err := s.GetReplica().SelectBuilder(&tokens, query); err != nil { return nil, errors.Wrapf(err, "failed to find UserAccessTokens by term with value '%s'", term) } diff --git a/server/channels/store/storetest/user_access_token_store.go b/server/channels/store/storetest/user_access_token_store.go index 83bdc941ad7..2037ec1e4e5 100644 --- a/server/channels/store/storetest/user_access_token_store.go +++ b/server/channels/store/storetest/user_access_token_store.go @@ -17,6 +17,7 @@ func TestUserAccessTokenStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("UserAccessTokenSaveGetDelete", func(t *testing.T) { testUserAccessTokenSaveGetDelete(t, rctx, ss) }) t.Run("UserAccessTokenDisableEnable", func(t *testing.T) { testUserAccessTokenDisableEnable(t, rctx, ss) }) t.Run("UserAccessTokenSearch", func(t *testing.T) { testUserAccessTokenSearch(t, rctx, ss) }) + t.Run("UserAccessTokenPagination", func(t *testing.T) { testUserAccessTokenPagination(t, rctx, ss) }) } func testUserAccessTokenSaveGetDelete(t *testing.T, rctx request.CTX, ss store.Store) { @@ -155,3 +156,92 @@ func testUserAccessTokenSearch(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, nErr) require.Equal(t, 1, len(received), "received incorrect number of tokens after search") } + +func testUserAccessTokenPagination(t *testing.T, rctx request.CTX, ss store.Store) { + // Create a user + u1 := model.User{} + u1.Email = MakeEmail() + u1.Username = model.NewUsername() + + _, err := ss.User().Save(rctx, &u1) + require.NoError(t, err) + + // Create 10 tokens for the user + tokens := make([]*model.UserAccessToken, 10) + for i := 0; i < 10; i++ { + tokens[i] = &model.UserAccessToken{ + Token: model.NewId(), + UserId: u1.Id, + Description: "testtoken" + model.NewId(), + } + + // Create a session for each token + s := &model.Session{} + s.UserId = tokens[i].UserId + s.Token = tokens[i].Token + + _, err = ss.Session().Save(rctx, s) + require.NoError(t, err) + + // Save the token + _, nErr := ss.UserAccessToken().Save(tokens[i]) + require.NoError(t, nErr) + } + + // Set up cleanup to run even if the test fails + t.Cleanup(func() { + for _, token := range tokens { + // Cleanup shouldn't fail the test, but we still log errors + if err := ss.UserAccessToken().Delete(token.Id); err != nil { + t.Logf("Failed to cleanup token %s: %v", token.Id, err) + } + } + }) + + // Test GetAll with pagination + // First page (3 tokens) + result, nErr := ss.UserAccessToken().GetAll(0, 3) + require.NoError(t, nErr) + require.Len(t, result, 3, "Should return 3 tokens for the first page") + + // Second page (3 tokens) + result, nErr = ss.UserAccessToken().GetAll(3, 3) + require.NoError(t, nErr) + require.Len(t, result, 3, "Should return 3 tokens for the second page") + + // Beyond the total number of tokens + result, nErr = ss.UserAccessToken().GetAll(30, 3) + require.NoError(t, nErr) + require.Len(t, result, 0, "Should return 0 tokens when offset is beyond total") + + // All tokens + result, nErr = ss.UserAccessToken().GetAll(0, 100) + require.NoError(t, nErr) + require.GreaterOrEqual(t, len(result), 10, "Should return at least 10 tokens") + + // Test GetByUser with pagination + // First page (3 tokens) + result, nErr = ss.UserAccessToken().GetByUser(u1.Id, 0, 3) + require.NoError(t, nErr) + require.Len(t, result, 3, "Should return 3 tokens for the first page") + + // Second page (3 tokens) + result, nErr = ss.UserAccessToken().GetByUser(u1.Id, 3, 3) + require.NoError(t, nErr) + require.Len(t, result, 3, "Should return 3 tokens for the second page") + + // Beyond the total number of tokens + result, nErr = ss.UserAccessToken().GetByUser(u1.Id, 30, 3) + require.NoError(t, nErr) + require.Len(t, result, 0, "Should return 0 tokens when offset is beyond total") + + // All tokens for the user + result, nErr = ss.UserAccessToken().GetByUser(u1.Id, 0, 100) + require.NoError(t, nErr) + require.Len(t, result, 10, "Should return 10 tokens for the user") + + // Test for a non-existent user + result, nErr = ss.UserAccessToken().GetByUser(model.NewId(), 0, 100) + require.NoError(t, nErr) + require.Len(t, result, 0, "Should return 0 tokens for non-existent user") +}