MM-62158: group store no SELECT * (Part 1) (#30276)

* improved test coverage

* initial pass on removing SELECT * from group store
This commit is contained in:
Jesse Hallam 2025-05-01 09:39:05 -03:00 committed by GitHub
parent 80c58a9742
commit e1f47e22e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 450 additions and 154 deletions

View file

@ -49,10 +49,67 @@ type groupChannelJoin struct {
type SqlGroupStore struct {
*SqlStore
userGroupsSelectQuery sq.SelectBuilder
groupMembersSelectQuery sq.SelectBuilder
groupMemberUsersSelectQuery sq.SelectBuilder
groupTeamsSelectQuery sq.SelectBuilder
groupChannelsSelectQuery sq.SelectBuilder
}
func newSqlGroupStore(sqlStore *SqlStore) store.GroupStore {
return &SqlGroupStore{SqlStore: sqlStore}
s := &SqlGroupStore{SqlStore: sqlStore}
s.userGroupsSelectQuery = s.getQueryBuilder().
Select(
"UserGroups.Id",
"UserGroups.Name",
"UserGroups.DisplayName",
"UserGroups.Description",
"UserGroups.Source",
"UserGroups.RemoteId",
"UserGroups.CreateAt",
"UserGroups.UpdateAt",
"UserGroups.DeleteAt",
"UserGroups.AllowReference",
).
From("UserGroups")
s.groupMembersSelectQuery = s.getQueryBuilder().
Select(
"GroupMembers.GroupId",
"GroupMembers.UserId",
"GroupMembers.CreateAt",
"GroupMembers.DeleteAt",
).From("GroupMembers")
s.groupMemberUsersSelectQuery = s.getQueryBuilder().
Select(getUsersColumns()...).
From("GroupMembers").
Join("Users ON Users.Id = GroupMembers.UserId")
s.groupTeamsSelectQuery = s.getQueryBuilder().
Select(
"GroupTeams.GroupId",
"GroupTeams.TeamId",
"GroupTeams.AutoAdd",
"GroupTeams.SchemeAdmin",
"GroupTeams.CreateAt",
"GroupTeams.UpdateAt",
"GroupTeams.DeleteAt",
).From("GroupTeams")
s.groupChannelsSelectQuery = s.getQueryBuilder().
Select(
"GroupChannels.GroupId",
"GroupChannels.ChannelId",
"GroupChannels.AutoAdd",
"GroupChannels.SchemeAdmin",
"GroupChannels.CreateAt",
"GroupChannels.UpdateAt",
"GroupChannels.DeleteAt",
).From("GroupChannels")
return s
}
func (s *SqlGroupStore) Create(group *model.Group) (*model.Group, error) {
@ -132,14 +189,9 @@ func (s *SqlGroupStore) CreateWithUserIds(g *model.GroupWithUserIds) (_ *model.G
return nil, err
}
// Get the new Group along with the member count
groupGroupQuery := `
SELECT
UserGroups.*,
A.Count AS MemberCount
FROM
UserGroups
INNER JOIN (
groupGroupQuery := s.userGroupsSelectQuery.
Column("A.Count AS MemberCount").
InnerJoin(`(
SELECT
UserGroups.Id,
COUNT(GroupMembers.UserId) AS Count
@ -154,12 +206,12 @@ func (s *SqlGroupStore) CreateWithUserIds(g *model.GroupWithUserIds) (_ *model.G
UserGroups.DisplayName,
UserGroups.Id
LIMIT
? OFFSET ?
) AS A ON UserGroups.Id = A.Id
ORDER BY
UserGroups.CreateAt DESC`
1 OFFSET 0
) AS A ON UserGroups.Id = A.Id`, g.Id).
OrderBy("UserGroups.CreateAt DESC")
var newGroup group
if err = txn.Get(&newGroup, groupGroupQuery, g.Id, 1, 0); err != nil {
if err = txn.GetBuilder(&newGroup, groupGroupQuery); err != nil {
return nil, err
}
if err = txn.Commit(); err != nil {
@ -172,16 +224,13 @@ func (s *SqlGroupStore) checkUsersExist(userIDs []string) error {
if len(userIDs) == 0 {
return nil
}
usersSelectQuery, usersSelectArgs, err := s.getQueryBuilder().
usersSelectQuery := s.getQueryBuilder().
Select("Id").
From("Users").
Where(sq.Eq{"Id": userIDs, "DeleteAt": 0}).
ToSql()
if err != nil {
return err
}
Where(sq.Eq{"Id": userIDs, "DeleteAt": 0})
var rows []string
err = s.GetReplica().Select(&rows, usersSelectQuery, usersSelectArgs...)
err := s.GetReplica().SelectBuilder(&rows, usersSelectQuery)
if err != nil {
return err
}
@ -215,9 +264,7 @@ func (s *SqlGroupStore) buildInsertGroupUsersQuery(groupId string, userIds []str
func (s *SqlGroupStore) Get(groupId string) (*model.Group, error) {
var group model.Group
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
builder := s.userGroupsSelectQuery.
Where(sq.Eq{"Id": groupId})
if err := s.GetReplica().GetBuilder(&group, builder); err != nil {
@ -232,16 +279,14 @@ func (s *SqlGroupStore) Get(groupId string) (*model.Group, error) {
func (s *SqlGroupStore) GetByName(name string, opts model.GroupSearchOpts) (*model.Group, error) {
var group model.Group
query := s.getQueryBuilder().Select("*").From("UserGroups").Where(sq.Eq{"Name": name})
query := s.userGroupsSelectQuery.
Where(sq.Eq{"Name": name})
if opts.FilterAllowReference {
query = query.Where("AllowReference = true")
}
queryString, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "get_by_name_tosql")
}
if err := s.GetReplica().Get(&group, queryString, args...); err != nil {
if err := s.GetReplica().GetBuilder(&group, query); err != nil {
if err == sql.ErrNoRows {
return nil, store.NewErrNotFound("Group", fmt.Sprintf("name=%s", name))
}
@ -253,12 +298,8 @@ func (s *SqlGroupStore) GetByName(name string, opts model.GroupSearchOpts) (*mod
func (s *SqlGroupStore) GetByIDs(groupIDs []string) ([]*model.Group, error) {
groups := []*model.Group{}
query := s.getQueryBuilder().Select("*").From("UserGroups").Where(sq.Eq{"Id": groupIDs})
queryString, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "get_by_ids_tosql")
}
if err := s.GetReplica().Select(&groups, queryString, args...); err != nil {
query := s.userGroupsSelectQuery.Where(sq.Eq{"Id": groupIDs})
if err := s.GetReplica().SelectBuilder(&groups, query); err != nil {
return nil, errors.Wrap(err, "failed to find Groups by ids")
}
return groups, nil
@ -266,9 +307,7 @@ func (s *SqlGroupStore) GetByIDs(groupIDs []string) ([]*model.Group, error) {
func (s *SqlGroupStore) GetByRemoteID(remoteID string, groupSource model.GroupSource) (*model.Group, error) {
var group model.Group
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
builder := s.userGroupsSelectQuery.
Where(sq.Eq{
"RemoteId": remoteID,
"Source": groupSource,
@ -286,9 +325,7 @@ func (s *SqlGroupStore) GetByRemoteID(remoteID string, groupSource model.GroupSo
func (s *SqlGroupStore) GetAllBySource(groupSource model.GroupSource) ([]*model.Group, error) {
groups := []*model.Group{}
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
builder := s.userGroupsSelectQuery.
Where(sq.Eq{
"DeleteAt": 0,
"Source": groupSource,
@ -304,13 +341,11 @@ func (s *SqlGroupStore) GetAllBySource(groupSource model.GroupSource) ([]*model.
func (s *SqlGroupStore) GetByUser(userID string, opts model.GroupSearchOpts) ([]*model.Group, error) {
groups := []*model.Group{}
builder := s.getQueryBuilder().
Select("UserGroups.*").
From("GroupMembers").
Join("UserGroups ON UserGroups.Id = GroupMembers.GroupId").
builder := s.userGroupsSelectQuery.
Join("GroupMembers ON GroupMembers.GroupId = UserGroups.Id").
Where(sq.Eq{
"GroupMembers.DeleteAt": 0,
"UserId": userID,
"GroupMembers.UserId": userID,
})
if opts.FilterAllowReference {
@ -326,10 +361,7 @@ func (s *SqlGroupStore) GetByUser(userID string, opts model.GroupSearchOpts) ([]
func (s *SqlGroupStore) Update(group *model.Group) (*model.Group, error) {
var retrievedGroup model.Group
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
Where(sq.Eq{"Id": group.Id})
builder := s.userGroupsSelectQuery.Where(sq.Eq{"Id": group.Id})
if err := s.GetReplica().GetBuilder(&retrievedGroup, builder); err != nil {
if err == sql.ErrNoRows {
@ -371,9 +403,7 @@ func (s *SqlGroupStore) Update(group *model.Group) (*model.Group, error) {
func (s *SqlGroupStore) Delete(groupID string) (*model.Group, error) {
var group model.Group
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
builder := s.userGroupsSelectQuery.
Where(sq.Eq{
"Id": groupID,
"DeleteAt": 0,
@ -400,9 +430,7 @@ func (s *SqlGroupStore) Delete(groupID string) (*model.Group, error) {
func (s *SqlGroupStore) Restore(groupID string) (*model.Group, error) {
var group model.Group
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
builder := s.userGroupsSelectQuery.
Where(sq.And{
sq.Eq{"Id": groupID},
sq.NotEq{"DeleteAt": 0},
@ -427,12 +455,11 @@ func (s *SqlGroupStore) Restore(groupID string) (*model.Group, error) {
}
func (s *SqlGroupStore) GetMember(groupID, userID string) (*model.GroupMember, error) {
builder := s.getQueryBuilder().
Select("*").
From("GroupMembers").
builder := s.groupMembersSelectQuery.
Where(sq.Eq{"UserId": userID}).
Where(sq.Eq{"GroupId": groupID}).
Where(sq.Eq{"DeleteAt": 0})
var groupMember model.GroupMember
if err := s.GetReplica().GetBuilder(&groupMember, builder); err != nil {
return nil, errors.Wrap(err, "GetMember")
@ -443,14 +470,11 @@ func (s *SqlGroupStore) GetMember(groupID, userID string) (*model.GroupMember, e
func (s *SqlGroupStore) GetMemberUsers(groupID string) ([]*model.User, error) {
groupMembers := []*model.User{}
builder := s.getQueryBuilder().
Select("Users.*").
From("GroupMembers").
Join("Users ON Users.Id = GroupMembers.UserId").
builder := s.groupMemberUsersSelectQuery.
Where(sq.Eq{
"GroupMembers.DeleteAt": 0,
"Users.DeleteAt": 0,
"GroupId": groupID,
"GroupMembers.GroupId": groupID,
})
if err := s.GetReplica().SelectBuilder(&groupMembers, builder); err != nil {
@ -467,23 +491,16 @@ func (s *SqlGroupStore) GetMemberUsersPage(groupID string, page int, perPage int
func (s *SqlGroupStore) GetMemberUsersSortedPage(groupID string, page int, perPage int, viewRestrictions *model.ViewUsersRestrictions, teammateNameDisplay string) ([]*model.User, error) {
groupMembers := []*model.User{}
userQuery := s.getQueryBuilder().
Select(`Users.*`).
From("GroupMembers").
Join("Users ON Users.Id = GroupMembers.UserId").
userQuery := s.groupMemberUsersSelectQuery.
Where(sq.Eq{"GroupMembers.DeleteAt": 0}).
Where(sq.Eq{"Users.DeleteAt": 0}).
Where(sq.Eq{"GroupId": groupID})
Where(sq.Eq{"GroupMembers.GroupId": groupID})
userQuery = applyViewRestrictionsFilter(userQuery, viewRestrictions, true)
queryString, args, err := userQuery.ToSql()
if err != nil {
return nil, errors.Wrap(err, "")
}
orderQuery := s.getQueryBuilder().
Select("Users.*").
From("(" + queryString + ") AS Users")
Select(getUsersColumns()...).
FromSelect(userQuery, "Users")
if teammateNameDisplay == model.ShowNicknameFullName {
orderQuery = orderQuery.OrderBy(`
@ -510,12 +527,7 @@ func (s *SqlGroupStore) GetMemberUsersSortedPage(groupID string, page int, perPa
Limit(uint64(perPage)).
Offset(uint64(page * perPage))
queryString, _, err = orderQuery.ToSql()
if err != nil {
return nil, errors.Wrap(err, "")
}
if err := s.GetReplica().Select(&groupMembers, queryString, args...); err != nil {
if err := s.GetReplica().SelectBuilder(&groupMembers, orderQuery); err != nil {
return nil, errors.Wrapf(err, "failed to find member Users for Group with id=%s", groupID)
}
@ -525,9 +537,7 @@ func (s *SqlGroupStore) GetMemberUsersSortedPage(groupID string, page int, perPa
func (s *SqlGroupStore) GetNonMemberUsersPage(groupID string, page int, perPage int, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, error) {
groupMembers := []*model.User{}
builder := s.getQueryBuilder().
Select("*").
From("UserGroups").
builder := s.userGroupsSelectQuery.
Where(sq.Eq{"Id": groupID})
if err := s.GetReplica().GetBuilder(&model.Group{}, builder); err != nil {
@ -535,7 +545,7 @@ func (s *SqlGroupStore) GetNonMemberUsersPage(groupID string, page int, perPage
}
builder = s.getQueryBuilder().
Select("Users.*").
Select(getUsersColumns()...).
From("Users").
LeftJoin("GroupMembers ON (GroupMembers.UserId = Users.Id AND GroupMembers.GroupId = ?)", groupID).
Where(sq.Eq{"Users.DeleteAt": 0}).
@ -568,14 +578,8 @@ func (s *SqlGroupStore) GetMemberCountWithRestrictions(groupID string, viewRestr
query = applyViewRestrictionsFilter(query, viewRestrictions, false)
queryString, args, err := query.ToSql()
if err != nil {
return int64(0), errors.Wrap(err, "")
}
var count int64
err = s.GetReplica().Get(&count, queryString, args...)
if err != nil {
if err := s.GetReplica().GetBuilder(&count, query); err != nil {
return int64(0), errors.Wrapf(err, "failed to count member Users for Group with id=%s", groupID)
}
@ -647,22 +651,22 @@ func (s *SqlGroupStore) GetMemberUsersNotInChannel(groupID string, channelID str
}
func (s *SqlGroupStore) UpsertMember(groupID string, userID string) (*model.GroupMember, error) {
members, query, args, err := s.buildUpsertMembersQuery(groupID, []string{userID})
members, query, err := s.buildUpsertMembersQuery(groupID, []string{userID})
if err != nil {
return nil, err
}
if _, err = s.GetMaster().Exec(query, args...); err != nil {
if _, err = s.GetMaster().ExecBuilder(query); err != nil {
return nil, errors.Wrap(err, "failed to save GroupMember")
}
return members[0], nil
}
func (s *SqlGroupStore) DeleteMember(groupID string, userID string) (*model.GroupMember, error) {
members, query, args, err := s.buildDeleteMembersQuery(groupID, []string{userID})
members, query, err := s.buildDeleteMembersQuery(groupID, []string{userID})
if err != nil {
return nil, err
}
if _, err = s.GetMaster().Exec(query, args...); err != nil {
if _, err = s.GetMaster().ExecBuilder(query); err != nil {
return nil, errors.Wrapf(err, "failed to update GroupMember with groupId=%s and userId=%s", groupID, userID)
}
@ -742,11 +746,17 @@ func (s *SqlGroupStore) getGroupSyncable(groupID string, syncableID string, sync
switch syncableType {
case model.GroupSyncableTypeTeam:
var team groupTeam
err = s.GetReplica().Get(&team, `SELECT * FROM GroupTeams WHERE GroupId=? AND TeamId=?`, groupID, syncableID)
err = s.GetReplica().GetBuilder(&team, s.groupTeamsSelectQuery.Where(sq.Eq{
"GroupTeams.GroupId": groupID,
"GroupTeams.TeamId": syncableID,
}))
result = &team
case model.GroupSyncableTypeChannel:
var ch groupChannel
err = s.GetReplica().Get(&ch, `SELECT * FROM GroupChannels WHERE GroupId=? AND ChannelId=?`, groupID, syncableID)
err = s.GetReplica().GetBuilder(&ch, s.groupChannelsSelectQuery.Where(sq.Eq{
"GroupChannels.GroupId": groupID,
"GroupChannels.ChannelId": syncableID,
}))
result = &ch
}
@ -790,19 +800,16 @@ func (s *SqlGroupStore) GetAllGroupSyncablesByGroupId(groupID string, syncableTy
switch syncableType {
case model.GroupSyncableTypeTeam:
sqlQuery := `
SELECT
GroupTeams.*,
Teams.DisplayName AS TeamDisplayName,
Teams.Type AS TeamType
FROM
GroupTeams
JOIN Teams ON Teams.Id = GroupTeams.TeamId
WHERE
GroupId = ? AND GroupTeams.DeleteAt = 0`
query := s.groupTeamsSelectQuery.
Columns("Teams.DisplayName AS TeamDisplayName", "Teams.Type AS TeamType").
Join("Teams ON Teams.Id = GroupTeams.TeamId").
Where(sq.Eq{
"GroupTeams.GroupId": groupID,
"GroupTeams.DeleteAt": 0,
})
results := []*groupTeamJoin{}
err := s.GetReplica().Select(&results, sqlQuery, groupID)
err := s.GetReplica().SelectBuilder(&results, query)
if err != nil {
return nil, errors.Wrapf(err, "failed to find GroupTeams with groupId=%s", groupID)
}
@ -822,23 +829,22 @@ func (s *SqlGroupStore) GetAllGroupSyncablesByGroupId(groupID string, syncableTy
groupSyncables = append(groupSyncables, groupSyncable)
}
case model.GroupSyncableTypeChannel:
sqlQuery := `
SELECT
GroupChannels.*,
Channels.DisplayName AS ChannelDisplayName,
Teams.DisplayName AS TeamDisplayName,
Channels.Type As ChannelType,
Teams.Type As TeamType,
Teams.Id AS TeamId
FROM
GroupChannels
JOIN Channels ON Channels.Id = GroupChannels.ChannelId
JOIN Teams ON Teams.Id = Channels.TeamId
WHERE
GroupId = ? AND GroupChannels.DeleteAt = 0`
query := s.groupChannelsSelectQuery.
Columns(
"Channels.DisplayName AS ChannelDisplayName",
"Teams.DisplayName AS TeamDisplayName",
"Channels.Type As ChannelType",
"Teams.Type As TeamType",
"Teams.Id AS TeamId",
).Join("Channels ON Channels.Id = GroupChannels.ChannelId").
Join("Teams ON Teams.Id = Channels.TeamId").
Where(sq.Eq{
"GroupChannels.GroupId": groupID,
"GroupChannels.DeleteAt": 0,
})
results := []*groupChannelJoin{}
err := s.GetReplica().Select(&results, sqlQuery, groupID)
err := s.GetReplica().SelectBuilder(&results, query)
if err != nil {
return nil, errors.Wrapf(err, "failed to find GroupChannels with groupId=%s", groupID)
}
@ -1902,22 +1908,22 @@ func (s *SqlGroupStore) countTableWithSelectAndWhere(selectStr, tableName string
}
func (s *SqlGroupStore) UpsertMembers(groupID string, userIDs []string) ([]*model.GroupMember, error) {
members, query, args, err := s.buildUpsertMembersQuery(groupID, userIDs)
members, query, err := s.buildUpsertMembersQuery(groupID, userIDs)
if err != nil {
return nil, err
}
if _, err = s.GetMaster().Exec(query, args...); err != nil {
if _, err = s.GetMaster().ExecBuilder(query); err != nil {
return nil, errors.Wrap(err, "failed to save GroupMember")
}
return members, err
}
func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, query string, args []any, err error) {
func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, builder sq.InsertBuilder, err error) {
var retrievedGroup model.Group
// Check Group exists
if err = s.GetReplica().Get(&retrievedGroup, "SELECT * FROM UserGroups WHERE Id = ?", groupID); err != nil {
if err = s.GetReplica().GetBuilder(&retrievedGroup, s.userGroupsSelectQuery.Where(sq.Eq{"UserGroups.Id": groupID})); err != nil {
err = errors.Wrapf(err, "failed to get UserGroup with groupId=%s", groupID)
return
}
@ -1927,7 +1933,7 @@ func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string
return
}
builder := s.getQueryBuilder().
builder = s.getQueryBuilder().
Insert("GroupMembers").
Columns("GroupId", "UserId", "CreateAt", "DeleteAt")
@ -1950,38 +1956,30 @@ func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string
builder = builder.SuffixExpr(sq.Expr("ON CONFLICT (groupid, userid) DO UPDATE SET CreateAt = ?, DeleteAt = ?", createAt, 0))
}
query, args, err = builder.ToSql()
return
}
func (s *SqlGroupStore) DeleteMembers(groupID string, userIDs []string) ([]*model.GroupMember, error) {
members, query, args, err := s.buildDeleteMembersQuery(groupID, userIDs)
members, query, err := s.buildDeleteMembersQuery(groupID, userIDs)
if err != nil {
return nil, err
}
if _, err = s.GetMaster().Exec(query, args...); err != nil {
if _, err = s.GetMaster().ExecBuilder(query); err != nil {
return nil, errors.Wrap(err, "failed to delete GroupMembers")
}
return members, err
}
func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, query string, args []any, err error) {
membersSelectQuery, membersSelectArgs, err := s.getQueryBuilder().
Select("*").
From("GroupMembers").
func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, builder sq.UpdateBuilder, err error) {
membersSelectQuery := s.groupMembersSelectQuery.
Where(sq.And{
sq.Eq{"GroupId": groupID},
sq.Eq{"UserId": userIDs},
sq.Eq{"DeleteAt": 0},
}).
ToSql()
if err != nil {
return
}
sq.Eq{"GroupMembers.GroupId": groupID},
sq.Eq{"GroupMembers.UserId": userIDs},
sq.Eq{"GroupMembers.DeleteAt": 0},
})
err = s.GetReplica().Select(&members, membersSelectQuery, membersSelectArgs...)
if err != nil {
if err = s.GetReplica().SelectBuilder(&members, membersSelectQuery); err != nil {
return
}
if len(members) != len(userIDs) {
@ -2003,7 +2001,7 @@ func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string
member.DeleteAt = deleteAt
}
builder := s.getQueryBuilder().
builder = s.getQueryBuilder().
Update("GroupMembers").
Set("DeleteAt", deleteAt).
Where(sq.And{
@ -2011,6 +2009,5 @@ func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string
sq.Eq{"UserId": userIDs},
})
query, args, err = builder.ToSql()
return
}

View file

@ -34,6 +34,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) {
t.Run("Update", func(t *testing.T) { testGroupStoreUpdate(t, rctx, ss) })
t.Run("Delete", func(t *testing.T) { testGroupStoreDelete(t, rctx, ss) })
t.Run("Restore", func(t *testing.T) { testGroupStoreRestore(t, rctx, ss) })
t.Run("ToModelChannelAssociations", func(t *testing.T) { testGroupStoreToModelChannelAssociations(t, rctx, ss) })
t.Run("GetMemberUsers", func(t *testing.T) { testGroupGetMemberUsers(t, rctx, ss) })
t.Run("GetMemberUsersPage", func(t *testing.T) { testGroupGetMemberUsersPage(t, rctx, ss) })
@ -50,6 +51,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) {
t.Run("CreateGroupSyncable", func(t *testing.T) { testCreateGroupSyncable(t, rctx, ss) })
t.Run("GetGroupSyncable", func(t *testing.T) { testGetGroupSyncable(t, rctx, ss) })
t.Run("GetGroupSyncableErrors", func(t *testing.T) { testGetGroupSyncableErrors(t, rctx, ss) })
t.Run("GetAllGroupSyncablesByGroupId", func(t *testing.T) { testGetAllGroupSyncablesByGroup(t, rctx, ss) })
t.Run("UpdateGroupSyncable", func(t *testing.T) { testUpdateGroupSyncable(t, rctx, ss) })
t.Run("DeleteGroupSyncable", func(t *testing.T) { testDeleteGroupSyncable(t, rctx, ss) })
@ -74,6 +76,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) {
t.Run("TeamMembersMinusGroupMembers", func(t *testing.T) { testTeamMembersMinusGroupMembers(t, rctx, ss) })
t.Run("ChannelMembersMinusGroupMembers", func(t *testing.T) { testChannelMembersMinusGroupMembers(t, rctx, ss) })
t.Run("CountMembersMinusGroupMembers", func(t *testing.T) { testCountMembersMinusGroupMembers(t, rctx, ss) })
t.Run("GetMemberCount", func(t *testing.T) { groupTestGetMemberCount(t, rctx, ss) })
@ -1618,6 +1621,44 @@ func testGetGroupSyncable(t *testing.T, rctx request.CTX, ss store.Store) {
require.Zero(t, gt1.DeleteAt)
}
func testGetGroupSyncableErrors(t *testing.T, rctx request.CTX, ss store.Store) {
// Create a group
g1 := &model.Group{
Name: model.NewPointer(model.NewId()),
DisplayName: model.NewId(),
Description: model.NewId(),
Source: model.GroupSourceLdap,
RemoteId: model.NewPointer(model.NewId()),
}
group, err := ss.Group().Create(g1)
require.NoError(t, err)
// Test with invalid syncable type
invalidSyncableType := model.GroupSyncableType("invalid")
_, err = ss.Group().GetGroupSyncable(group.Id, model.NewId(), invalidSyncableType)
require.Error(t, err)
var nfErr *store.ErrNotFound
require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err)
// Test with empty group ID
_, err = ss.Group().GetGroupSyncable("", model.NewId(), model.GroupSyncableTypeTeam)
require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err)
// Test with empty syncable ID
_, err = ss.Group().GetGroupSyncable(group.Id, "", model.GroupSyncableTypeTeam)
require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err)
// Test with completely non-existent IDs
randomGroupId := model.NewId()
randomTeamId := model.NewId()
_, err = ss.Group().GetGroupSyncable(randomGroupId, randomTeamId, model.GroupSyncableTypeTeam)
require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err)
// Test with valid group ID but non-existent syncable
_, err = ss.Group().GetGroupSyncable(group.Id, model.NewId(), model.GroupSyncableTypeTeam)
require.True(t, errors.As(err, &nfErr))
}
func testGetAllGroupSyncablesByGroup(t *testing.T, rctx request.CTX, ss store.Store) {
t.Run("team", func(t *testing.T) { testGetAllGroupSyncablesByGroupTeam(t, rctx, ss) })
t.Run("channel", func(t *testing.T) { testGetAllGroupSyncablesByGroupChannel(t, rctx, ss) })
@ -4174,17 +4215,14 @@ func testGetGroups(t *testing.T, rctx request.CTX, ss store.Store) {
PerPage: 100,
Resultf: func(groups []*model.Group) bool {
for _, group := range groups {
fmt.Println(group.Id, group.ChannelMemberCount)
var channelMemberCount int
if group.ChannelMemberCount != nil {
channelMemberCount = *group.ChannelMemberCount
}
if group.Id == group1.Id && channelMemberCount != 2 {
fmt.Println("group1", group.Id, channelMemberCount)
return false
}
if group.Id == group2.Id && channelMemberCount != 1 {
fmt.Println("group2", group.Id, channelMemberCount)
return false
}
}
@ -4205,11 +4243,9 @@ func testGetGroups(t *testing.T, rctx request.CTX, ss store.Store) {
channelMemberCount = *group.ChannelMemberCount
}
if group.Id == group1.Id && channelMemberCount != 1 {
fmt.Println("group1", group.Id, channelMemberCount)
return false
}
if group.Id == group2.Id && channelMemberCount != 2 {
fmt.Println("group2", group.Id, channelMemberCount)
return false
}
}
@ -5800,3 +5836,266 @@ func groupTestGroupCountBySource(t *testing.T, rctx request.CTX, ss store.Store)
require.NoError(t, err)
require.Equal(t, ldapSourceCountAfter-1, ldapSourceCountAfterDelete)
}
func testCountMembersMinusGroupMembers(t *testing.T, rctx request.CTX, ss store.Store) {
// Create test users
u1, err := ss.User().Save(rctx, &model.User{
Email: MakeEmail(),
Username: model.NewUsername(),
})
require.NoError(t, err)
u2, err := ss.User().Save(rctx, &model.User{
Email: MakeEmail(),
Username: model.NewUsername(),
})
require.NoError(t, err)
u3, err := ss.User().Save(rctx, &model.User{
Email: MakeEmail(),
Username: model.NewUsername(),
})
require.NoError(t, err)
// Create test team
team := &model.Team{
DisplayName: "Name",
Description: "Some description",
CompanyName: "Some company name",
AllowOpenInvite: false,
InviteId: "inviteid0",
Name: "z-z-" + model.NewId() + "a",
Email: "success+" + model.NewId() + "@simulator.amazonses.com",
Type: model.TeamOpen,
}
team, err = ss.Team().Save(team)
require.NoError(t, err)
// Create a second team to test team-specific counts
team2 := &model.Team{
DisplayName: "Name 2",
Description: "Some description 2",
CompanyName: "Some company name 2",
AllowOpenInvite: false,
InviteId: "inviteid1",
Name: "z-z-" + model.NewId() + "b",
Email: "success+" + model.NewId() + "@simulator.amazonses.com",
Type: model.TeamOpen,
}
team2, err = ss.Team().Save(team2)
require.NoError(t, err)
// Create test channel in team 1
channel := &model.Channel{
TeamId: team.Id,
DisplayName: "Display Name",
Name: "z-z-" + model.NewId() + "a",
Type: model.ChannelTypeOpen,
}
channel, nErr := ss.Channel().Save(rctx, channel, -1)
require.NoError(t, nErr)
// Create test channel in team 2
channel2 := &model.Channel{
TeamId: team2.Id,
DisplayName: "Display Name 2",
Name: "z-z-" + model.NewId() + "b",
Type: model.ChannelTypeOpen,
}
channel2, nErr = ss.Channel().Save(rctx, channel2, -1)
require.NoError(t, nErr)
// Add users to teams and channels
// u1 and u2 in team 1
_, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team.Id, UserId: u1.Id}, -1)
require.NoError(t, nErr)
_, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team.Id, UserId: u2.Id}, -1)
require.NoError(t, nErr)
// u3 in team 2
_, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team2.Id, UserId: u3.Id}, -1)
require.NoError(t, nErr)
// u1 and u2 in channel 1
_, nErr = ss.Channel().SaveMember(rctx, &model.ChannelMember{ChannelId: channel.Id, UserId: u1.Id, NotifyProps: model.GetDefaultChannelNotifyProps()})
require.NoError(t, nErr)
_, nErr = ss.Channel().SaveMember(rctx, &model.ChannelMember{ChannelId: channel.Id, UserId: u2.Id, NotifyProps: model.GetDefaultChannelNotifyProps()})
require.NoError(t, nErr)
// u3 in channel 2
_, nErr = ss.Channel().SaveMember(rctx, &model.ChannelMember{ChannelId: channel2.Id, UserId: u3.Id, NotifyProps: model.GetDefaultChannelNotifyProps()})
require.NoError(t, nErr)
// Create groups
group1, err := ss.Group().Create(&model.Group{
Name: model.NewPointer(model.NewId()),
DisplayName: model.NewId(),
Description: model.NewId(),
Source: model.GroupSourceCustom,
RemoteId: model.NewPointer(model.NewId()),
})
require.NoError(t, err)
group2, err := ss.Group().Create(&model.Group{
Name: model.NewPointer(model.NewId()),
DisplayName: model.NewId(),
Description: model.NewId(),
Source: model.GroupSourceCustom,
RemoteId: model.NewPointer(model.NewId()),
})
require.NoError(t, err)
// Add u1 to group1, u3 to group2
_, err = ss.Group().UpsertMember(group1.Id, u1.Id)
require.NoError(t, err)
_, err = ss.Group().UpsertMember(group2.Id, u3.Id)
require.NoError(t, err)
// Test CountTeamMembersMinusGroupMembers with empty groupIDs
count, err := ss.Group().CountTeamMembersMinusGroupMembers(team.Id, []string{})
require.NoError(t, err)
require.Equal(t, int64(2), count) // Both u1 and u2 are counted when no groups are excluded
// Test CountTeamMembersMinusGroupMembers with group1 ID
count, err = ss.Group().CountTeamMembersMinusGroupMembers(team.Id, []string{group1.Id})
require.NoError(t, err)
require.Equal(t, int64(1), count) // Only u2 should be counted (not in the group)
// Test with non-existent team ID
count, err = ss.Group().CountTeamMembersMinusGroupMembers(model.NewId(), []string{group1.Id})
require.NoError(t, err)
require.Equal(t, int64(0), count) // No members in a non-existent team
// Test with multiple group IDs
count, err = ss.Group().CountTeamMembersMinusGroupMembers(team.Id, []string{group1.Id, group2.Id})
require.NoError(t, err)
require.Equal(t, int64(1), count) // Only u2 should be counted
// Test team 2
count, err = ss.Group().CountTeamMembersMinusGroupMembers(team2.Id, []string{group2.Id})
require.NoError(t, err)
require.Equal(t, int64(0), count) // No one should be counted (u3 is in group2)
// Test CountChannelMembersMinusGroupMembers with empty groupIDs
count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel.Id, []string{})
require.NoError(t, err)
require.Equal(t, int64(2), count) // Both users are counted when no groups are excluded
// Test CountChannelMembersMinusGroupMembers with the group ID
count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel.Id, []string{group1.Id})
require.NoError(t, err)
require.Equal(t, int64(1), count) // Only u2 should be counted (not in the group)
// Test with multiple group IDs
count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel.Id, []string{group1.Id, group2.Id})
require.NoError(t, err)
require.Equal(t, int64(1), count) // u2 should be counted
// Test channel 2
count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel2.Id, []string{group2.Id})
require.NoError(t, err)
require.Equal(t, int64(0), count) // No one should be counted (u3 is in group2)
// Test with non-existent channel ID
count, err = ss.Group().CountChannelMembersMinusGroupMembers(model.NewId(), []string{group1.Id})
require.NoError(t, err)
require.Equal(t, int64(0), count) // No members in a non-existent channel
// Test error cases - passing invalid parameters
// 1. Empty team ID
count, err = ss.Group().CountTeamMembersMinusGroupMembers("", []string{group1.Id})
require.NoError(t, err) // Should handle this gracefully
require.Equal(t, int64(0), count)
// 2. Empty channel ID
count, err = ss.Group().CountChannelMembersMinusGroupMembers("", []string{group1.Id})
require.NoError(t, err) // Should handle this gracefully
require.Equal(t, int64(0), count)
}
func testGroupStoreToModelChannelAssociations(t *testing.T, rctx request.CTX, ss store.Store) {
// Create test group
group, err := ss.Group().Create(&model.Group{
Name: model.NewPointer(model.NewId()),
DisplayName: model.NewId(),
Description: model.NewId(),
Source: model.GroupSourceCustom,
RemoteId: model.NewPointer(model.NewId()),
})
require.NoError(t, err)
require.NotNil(t, group)
// Create test team
team := &model.Team{
DisplayName: "Name",
Description: "Some description",
CompanyName: "Some company name",
AllowOpenInvite: false,
InviteId: "inviteid0",
Name: "z-z-" + model.NewId() + "a",
Email: "success+" + model.NewId() + "@simulator.amazonses.com",
Type: model.TeamOpen,
}
team, err = ss.Team().Save(team)
require.NoError(t, err)
require.NotNil(t, team)
// Create test channel 1
channel1 := &model.Channel{
TeamId: team.Id,
DisplayName: "Display Name 1",
Name: "z-z-" + model.NewId() + "a",
Type: model.ChannelTypeOpen,
}
channel1, nErr := ss.Channel().Save(rctx, channel1, -1)
require.NoError(t, nErr)
require.NotNil(t, channel1)
// Create test channel 2
channel2 := &model.Channel{
TeamId: team.Id,
DisplayName: "Display Name 2",
Name: "z-z-" + model.NewId() + "b",
Type: model.ChannelTypeOpen,
}
channel2, nErr = ss.Channel().Save(rctx, channel2, -1)
require.NoError(t, nErr)
require.NotNil(t, channel2)
// Create group channel syncables
_, err = ss.Group().CreateGroupSyncable(&model.GroupSyncable{
GroupId: group.Id,
SyncableId: channel1.Id,
Type: model.GroupSyncableTypeChannel,
SchemeAdmin: true,
})
require.NoError(t, err)
_, err = ss.Group().CreateGroupSyncable(&model.GroupSyncable{
GroupId: group.Id,
SyncableId: channel2.Id,
Type: model.GroupSyncableTypeChannel,
SchemeAdmin: false,
})
require.NoError(t, err)
// Test the GetGroupsAssociatedToChannelsByTeam function
// This exercises the groupsAssociatedToChannelWithSchemeAdmin.ToModel method
result, err := ss.Group().GetGroupsAssociatedToChannelsByTeam(team.Id, model.GroupSearchOpts{})
require.NoError(t, err)
require.NotNil(t, result)
// Verify channel 1 results
require.Contains(t, result, channel1.Id)
require.NotEmpty(t, result[channel1.Id])
require.Equal(t, group.Id, result[channel1.Id][0].Id)
require.NotNil(t, result[channel1.Id][0].SchemeAdmin)
require.True(t, *result[channel1.Id][0].SchemeAdmin)
// Verify channel 2 results (with different SchemeAdmin value)
require.Contains(t, result, channel2.Id)
require.NotEmpty(t, result[channel2.Id])
require.Equal(t, group.Id, result[channel2.Id][0].Id)
require.NotNil(t, result[channel2.Id][0].SchemeAdmin)
require.False(t, *result[channel2.Id][0].SchemeAdmin)
}