diff --git a/server/channels/store/sqlstore/group_store.go b/server/channels/store/sqlstore/group_store.go index 41b081d23f2..22ea78b5170 100644 --- a/server/channels/store/sqlstore/group_store.go +++ b/server/channels/store/sqlstore/group_store.go @@ -49,10 +49,67 @@ type groupChannelJoin struct { type SqlGroupStore struct { *SqlStore + userGroupsSelectQuery sq.SelectBuilder + groupMembersSelectQuery sq.SelectBuilder + groupMemberUsersSelectQuery sq.SelectBuilder + groupTeamsSelectQuery sq.SelectBuilder + groupChannelsSelectQuery sq.SelectBuilder } func newSqlGroupStore(sqlStore *SqlStore) store.GroupStore { - return &SqlGroupStore{SqlStore: sqlStore} + s := &SqlGroupStore{SqlStore: sqlStore} + + s.userGroupsSelectQuery = s.getQueryBuilder(). + Select( + "UserGroups.Id", + "UserGroups.Name", + "UserGroups.DisplayName", + "UserGroups.Description", + "UserGroups.Source", + "UserGroups.RemoteId", + "UserGroups.CreateAt", + "UserGroups.UpdateAt", + "UserGroups.DeleteAt", + "UserGroups.AllowReference", + ). + From("UserGroups") + + s.groupMembersSelectQuery = s.getQueryBuilder(). + Select( + "GroupMembers.GroupId", + "GroupMembers.UserId", + "GroupMembers.CreateAt", + "GroupMembers.DeleteAt", + ).From("GroupMembers") + + s.groupMemberUsersSelectQuery = s.getQueryBuilder(). + Select(getUsersColumns()...). + From("GroupMembers"). + Join("Users ON Users.Id = GroupMembers.UserId") + + s.groupTeamsSelectQuery = s.getQueryBuilder(). + Select( + "GroupTeams.GroupId", + "GroupTeams.TeamId", + "GroupTeams.AutoAdd", + "GroupTeams.SchemeAdmin", + "GroupTeams.CreateAt", + "GroupTeams.UpdateAt", + "GroupTeams.DeleteAt", + ).From("GroupTeams") + + s.groupChannelsSelectQuery = s.getQueryBuilder(). + Select( + "GroupChannels.GroupId", + "GroupChannels.ChannelId", + "GroupChannels.AutoAdd", + "GroupChannels.SchemeAdmin", + "GroupChannels.CreateAt", + "GroupChannels.UpdateAt", + "GroupChannels.DeleteAt", + ).From("GroupChannels") + + return s } func (s *SqlGroupStore) Create(group *model.Group) (*model.Group, error) { @@ -132,14 +189,9 @@ func (s *SqlGroupStore) CreateWithUserIds(g *model.GroupWithUserIds) (_ *model.G return nil, err } - // Get the new Group along with the member count - groupGroupQuery := ` - SELECT - UserGroups.*, - A.Count AS MemberCount - FROM - UserGroups - INNER JOIN ( + groupGroupQuery := s.userGroupsSelectQuery. + Column("A.Count AS MemberCount"). + InnerJoin(`( SELECT UserGroups.Id, COUNT(GroupMembers.UserId) AS Count @@ -154,12 +206,12 @@ func (s *SqlGroupStore) CreateWithUserIds(g *model.GroupWithUserIds) (_ *model.G UserGroups.DisplayName, UserGroups.Id LIMIT - ? OFFSET ? - ) AS A ON UserGroups.Id = A.Id - ORDER BY - UserGroups.CreateAt DESC` + 1 OFFSET 0 + ) AS A ON UserGroups.Id = A.Id`, g.Id). + OrderBy("UserGroups.CreateAt DESC") + var newGroup group - if err = txn.Get(&newGroup, groupGroupQuery, g.Id, 1, 0); err != nil { + if err = txn.GetBuilder(&newGroup, groupGroupQuery); err != nil { return nil, err } if err = txn.Commit(); err != nil { @@ -172,16 +224,13 @@ func (s *SqlGroupStore) checkUsersExist(userIDs []string) error { if len(userIDs) == 0 { return nil } - usersSelectQuery, usersSelectArgs, err := s.getQueryBuilder(). + usersSelectQuery := s.getQueryBuilder(). Select("Id"). From("Users"). - Where(sq.Eq{"Id": userIDs, "DeleteAt": 0}). - ToSql() - if err != nil { - return err - } + Where(sq.Eq{"Id": userIDs, "DeleteAt": 0}) + var rows []string - err = s.GetReplica().Select(&rows, usersSelectQuery, usersSelectArgs...) + err := s.GetReplica().SelectBuilder(&rows, usersSelectQuery) if err != nil { return err } @@ -215,9 +264,7 @@ func (s *SqlGroupStore) buildInsertGroupUsersQuery(groupId string, userIds []str func (s *SqlGroupStore) Get(groupId string) (*model.Group, error) { var group model.Group - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). + builder := s.userGroupsSelectQuery. Where(sq.Eq{"Id": groupId}) if err := s.GetReplica().GetBuilder(&group, builder); err != nil { @@ -232,16 +279,14 @@ func (s *SqlGroupStore) Get(groupId string) (*model.Group, error) { func (s *SqlGroupStore) GetByName(name string, opts model.GroupSearchOpts) (*model.Group, error) { var group model.Group - query := s.getQueryBuilder().Select("*").From("UserGroups").Where(sq.Eq{"Name": name}) + query := s.userGroupsSelectQuery. + Where(sq.Eq{"Name": name}) + if opts.FilterAllowReference { query = query.Where("AllowReference = true") } - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_by_name_tosql") - } - if err := s.GetReplica().Get(&group, queryString, args...); err != nil { + if err := s.GetReplica().GetBuilder(&group, query); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Group", fmt.Sprintf("name=%s", name)) } @@ -253,12 +298,8 @@ func (s *SqlGroupStore) GetByName(name string, opts model.GroupSearchOpts) (*mod func (s *SqlGroupStore) GetByIDs(groupIDs []string) ([]*model.Group, error) { groups := []*model.Group{} - query := s.getQueryBuilder().Select("*").From("UserGroups").Where(sq.Eq{"Id": groupIDs}) - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_by_ids_tosql") - } - if err := s.GetReplica().Select(&groups, queryString, args...); err != nil { + query := s.userGroupsSelectQuery.Where(sq.Eq{"Id": groupIDs}) + if err := s.GetReplica().SelectBuilder(&groups, query); err != nil { return nil, errors.Wrap(err, "failed to find Groups by ids") } return groups, nil @@ -266,9 +307,7 @@ func (s *SqlGroupStore) GetByIDs(groupIDs []string) ([]*model.Group, error) { func (s *SqlGroupStore) GetByRemoteID(remoteID string, groupSource model.GroupSource) (*model.Group, error) { var group model.Group - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). + builder := s.userGroupsSelectQuery. Where(sq.Eq{ "RemoteId": remoteID, "Source": groupSource, @@ -286,9 +325,7 @@ func (s *SqlGroupStore) GetByRemoteID(remoteID string, groupSource model.GroupSo func (s *SqlGroupStore) GetAllBySource(groupSource model.GroupSource) ([]*model.Group, error) { groups := []*model.Group{} - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). + builder := s.userGroupsSelectQuery. Where(sq.Eq{ "DeleteAt": 0, "Source": groupSource, @@ -304,13 +341,11 @@ func (s *SqlGroupStore) GetAllBySource(groupSource model.GroupSource) ([]*model. func (s *SqlGroupStore) GetByUser(userID string, opts model.GroupSearchOpts) ([]*model.Group, error) { groups := []*model.Group{} - builder := s.getQueryBuilder(). - Select("UserGroups.*"). - From("GroupMembers"). - Join("UserGroups ON UserGroups.Id = GroupMembers.GroupId"). + builder := s.userGroupsSelectQuery. + Join("GroupMembers ON GroupMembers.GroupId = UserGroups.Id"). Where(sq.Eq{ "GroupMembers.DeleteAt": 0, - "UserId": userID, + "GroupMembers.UserId": userID, }) if opts.FilterAllowReference { @@ -326,10 +361,7 @@ func (s *SqlGroupStore) GetByUser(userID string, opts model.GroupSearchOpts) ([] func (s *SqlGroupStore) Update(group *model.Group) (*model.Group, error) { var retrievedGroup model.Group - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). - Where(sq.Eq{"Id": group.Id}) + builder := s.userGroupsSelectQuery.Where(sq.Eq{"Id": group.Id}) if err := s.GetReplica().GetBuilder(&retrievedGroup, builder); err != nil { if err == sql.ErrNoRows { @@ -371,9 +403,7 @@ func (s *SqlGroupStore) Update(group *model.Group) (*model.Group, error) { func (s *SqlGroupStore) Delete(groupID string) (*model.Group, error) { var group model.Group - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). + builder := s.userGroupsSelectQuery. Where(sq.Eq{ "Id": groupID, "DeleteAt": 0, @@ -400,9 +430,7 @@ func (s *SqlGroupStore) Delete(groupID string) (*model.Group, error) { func (s *SqlGroupStore) Restore(groupID string) (*model.Group, error) { var group model.Group - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). + builder := s.userGroupsSelectQuery. Where(sq.And{ sq.Eq{"Id": groupID}, sq.NotEq{"DeleteAt": 0}, @@ -427,12 +455,11 @@ func (s *SqlGroupStore) Restore(groupID string) (*model.Group, error) { } func (s *SqlGroupStore) GetMember(groupID, userID string) (*model.GroupMember, error) { - builder := s.getQueryBuilder(). - Select("*"). - From("GroupMembers"). + builder := s.groupMembersSelectQuery. Where(sq.Eq{"UserId": userID}). Where(sq.Eq{"GroupId": groupID}). Where(sq.Eq{"DeleteAt": 0}) + var groupMember model.GroupMember if err := s.GetReplica().GetBuilder(&groupMember, builder); err != nil { return nil, errors.Wrap(err, "GetMember") @@ -443,14 +470,11 @@ func (s *SqlGroupStore) GetMember(groupID, userID string) (*model.GroupMember, e func (s *SqlGroupStore) GetMemberUsers(groupID string) ([]*model.User, error) { groupMembers := []*model.User{} - builder := s.getQueryBuilder(). - Select("Users.*"). - From("GroupMembers"). - Join("Users ON Users.Id = GroupMembers.UserId"). + builder := s.groupMemberUsersSelectQuery. Where(sq.Eq{ "GroupMembers.DeleteAt": 0, "Users.DeleteAt": 0, - "GroupId": groupID, + "GroupMembers.GroupId": groupID, }) if err := s.GetReplica().SelectBuilder(&groupMembers, builder); err != nil { @@ -467,23 +491,16 @@ func (s *SqlGroupStore) GetMemberUsersPage(groupID string, page int, perPage int func (s *SqlGroupStore) GetMemberUsersSortedPage(groupID string, page int, perPage int, viewRestrictions *model.ViewUsersRestrictions, teammateNameDisplay string) ([]*model.User, error) { groupMembers := []*model.User{} - userQuery := s.getQueryBuilder(). - Select(`Users.*`). - From("GroupMembers"). - Join("Users ON Users.Id = GroupMembers.UserId"). + userQuery := s.groupMemberUsersSelectQuery. Where(sq.Eq{"GroupMembers.DeleteAt": 0}). Where(sq.Eq{"Users.DeleteAt": 0}). - Where(sq.Eq{"GroupId": groupID}) + Where(sq.Eq{"GroupMembers.GroupId": groupID}) userQuery = applyViewRestrictionsFilter(userQuery, viewRestrictions, true) - queryString, args, err := userQuery.ToSql() - if err != nil { - return nil, errors.Wrap(err, "") - } orderQuery := s.getQueryBuilder(). - Select("Users.*"). - From("(" + queryString + ") AS Users") + Select(getUsersColumns()...). + FromSelect(userQuery, "Users") if teammateNameDisplay == model.ShowNicknameFullName { orderQuery = orderQuery.OrderBy(` @@ -510,12 +527,7 @@ func (s *SqlGroupStore) GetMemberUsersSortedPage(groupID string, page int, perPa Limit(uint64(perPage)). Offset(uint64(page * perPage)) - queryString, _, err = orderQuery.ToSql() - if err != nil { - return nil, errors.Wrap(err, "") - } - - if err := s.GetReplica().Select(&groupMembers, queryString, args...); err != nil { + if err := s.GetReplica().SelectBuilder(&groupMembers, orderQuery); err != nil { return nil, errors.Wrapf(err, "failed to find member Users for Group with id=%s", groupID) } @@ -525,9 +537,7 @@ func (s *SqlGroupStore) GetMemberUsersSortedPage(groupID string, page int, perPa func (s *SqlGroupStore) GetNonMemberUsersPage(groupID string, page int, perPage int, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, error) { groupMembers := []*model.User{} - builder := s.getQueryBuilder(). - Select("*"). - From("UserGroups"). + builder := s.userGroupsSelectQuery. Where(sq.Eq{"Id": groupID}) if err := s.GetReplica().GetBuilder(&model.Group{}, builder); err != nil { @@ -535,7 +545,7 @@ func (s *SqlGroupStore) GetNonMemberUsersPage(groupID string, page int, perPage } builder = s.getQueryBuilder(). - Select("Users.*"). + Select(getUsersColumns()...). From("Users"). LeftJoin("GroupMembers ON (GroupMembers.UserId = Users.Id AND GroupMembers.GroupId = ?)", groupID). Where(sq.Eq{"Users.DeleteAt": 0}). @@ -568,14 +578,8 @@ func (s *SqlGroupStore) GetMemberCountWithRestrictions(groupID string, viewRestr query = applyViewRestrictionsFilter(query, viewRestrictions, false) - queryString, args, err := query.ToSql() - if err != nil { - return int64(0), errors.Wrap(err, "") - } - var count int64 - err = s.GetReplica().Get(&count, queryString, args...) - if err != nil { + if err := s.GetReplica().GetBuilder(&count, query); err != nil { return int64(0), errors.Wrapf(err, "failed to count member Users for Group with id=%s", groupID) } @@ -647,22 +651,22 @@ func (s *SqlGroupStore) GetMemberUsersNotInChannel(groupID string, channelID str } func (s *SqlGroupStore) UpsertMember(groupID string, userID string) (*model.GroupMember, error) { - members, query, args, err := s.buildUpsertMembersQuery(groupID, []string{userID}) + members, query, err := s.buildUpsertMembersQuery(groupID, []string{userID}) if err != nil { return nil, err } - if _, err = s.GetMaster().Exec(query, args...); err != nil { + if _, err = s.GetMaster().ExecBuilder(query); err != nil { return nil, errors.Wrap(err, "failed to save GroupMember") } return members[0], nil } func (s *SqlGroupStore) DeleteMember(groupID string, userID string) (*model.GroupMember, error) { - members, query, args, err := s.buildDeleteMembersQuery(groupID, []string{userID}) + members, query, err := s.buildDeleteMembersQuery(groupID, []string{userID}) if err != nil { return nil, err } - if _, err = s.GetMaster().Exec(query, args...); err != nil { + if _, err = s.GetMaster().ExecBuilder(query); err != nil { return nil, errors.Wrapf(err, "failed to update GroupMember with groupId=%s and userId=%s", groupID, userID) } @@ -742,11 +746,17 @@ func (s *SqlGroupStore) getGroupSyncable(groupID string, syncableID string, sync switch syncableType { case model.GroupSyncableTypeTeam: var team groupTeam - err = s.GetReplica().Get(&team, `SELECT * FROM GroupTeams WHERE GroupId=? AND TeamId=?`, groupID, syncableID) + err = s.GetReplica().GetBuilder(&team, s.groupTeamsSelectQuery.Where(sq.Eq{ + "GroupTeams.GroupId": groupID, + "GroupTeams.TeamId": syncableID, + })) result = &team case model.GroupSyncableTypeChannel: var ch groupChannel - err = s.GetReplica().Get(&ch, `SELECT * FROM GroupChannels WHERE GroupId=? AND ChannelId=?`, groupID, syncableID) + err = s.GetReplica().GetBuilder(&ch, s.groupChannelsSelectQuery.Where(sq.Eq{ + "GroupChannels.GroupId": groupID, + "GroupChannels.ChannelId": syncableID, + })) result = &ch } @@ -790,19 +800,16 @@ func (s *SqlGroupStore) GetAllGroupSyncablesByGroupId(groupID string, syncableTy switch syncableType { case model.GroupSyncableTypeTeam: - sqlQuery := ` - SELECT - GroupTeams.*, - Teams.DisplayName AS TeamDisplayName, - Teams.Type AS TeamType - FROM - GroupTeams - JOIN Teams ON Teams.Id = GroupTeams.TeamId - WHERE - GroupId = ? AND GroupTeams.DeleteAt = 0` + query := s.groupTeamsSelectQuery. + Columns("Teams.DisplayName AS TeamDisplayName", "Teams.Type AS TeamType"). + Join("Teams ON Teams.Id = GroupTeams.TeamId"). + Where(sq.Eq{ + "GroupTeams.GroupId": groupID, + "GroupTeams.DeleteAt": 0, + }) results := []*groupTeamJoin{} - err := s.GetReplica().Select(&results, sqlQuery, groupID) + err := s.GetReplica().SelectBuilder(&results, query) if err != nil { return nil, errors.Wrapf(err, "failed to find GroupTeams with groupId=%s", groupID) } @@ -822,23 +829,22 @@ func (s *SqlGroupStore) GetAllGroupSyncablesByGroupId(groupID string, syncableTy groupSyncables = append(groupSyncables, groupSyncable) } case model.GroupSyncableTypeChannel: - sqlQuery := ` - SELECT - GroupChannels.*, - Channels.DisplayName AS ChannelDisplayName, - Teams.DisplayName AS TeamDisplayName, - Channels.Type As ChannelType, - Teams.Type As TeamType, - Teams.Id AS TeamId - FROM - GroupChannels - JOIN Channels ON Channels.Id = GroupChannels.ChannelId - JOIN Teams ON Teams.Id = Channels.TeamId - WHERE - GroupId = ? AND GroupChannels.DeleteAt = 0` + query := s.groupChannelsSelectQuery. + Columns( + "Channels.DisplayName AS ChannelDisplayName", + "Teams.DisplayName AS TeamDisplayName", + "Channels.Type As ChannelType", + "Teams.Type As TeamType", + "Teams.Id AS TeamId", + ).Join("Channels ON Channels.Id = GroupChannels.ChannelId"). + Join("Teams ON Teams.Id = Channels.TeamId"). + Where(sq.Eq{ + "GroupChannels.GroupId": groupID, + "GroupChannels.DeleteAt": 0, + }) results := []*groupChannelJoin{} - err := s.GetReplica().Select(&results, sqlQuery, groupID) + err := s.GetReplica().SelectBuilder(&results, query) if err != nil { return nil, errors.Wrapf(err, "failed to find GroupChannels with groupId=%s", groupID) } @@ -1902,22 +1908,22 @@ func (s *SqlGroupStore) countTableWithSelectAndWhere(selectStr, tableName string } func (s *SqlGroupStore) UpsertMembers(groupID string, userIDs []string) ([]*model.GroupMember, error) { - members, query, args, err := s.buildUpsertMembersQuery(groupID, userIDs) + members, query, err := s.buildUpsertMembersQuery(groupID, userIDs) if err != nil { return nil, err } - if _, err = s.GetMaster().Exec(query, args...); err != nil { + if _, err = s.GetMaster().ExecBuilder(query); err != nil { return nil, errors.Wrap(err, "failed to save GroupMember") } return members, err } -func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, query string, args []any, err error) { +func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, builder sq.InsertBuilder, err error) { var retrievedGroup model.Group // Check Group exists - if err = s.GetReplica().Get(&retrievedGroup, "SELECT * FROM UserGroups WHERE Id = ?", groupID); err != nil { + if err = s.GetReplica().GetBuilder(&retrievedGroup, s.userGroupsSelectQuery.Where(sq.Eq{"UserGroups.Id": groupID})); err != nil { err = errors.Wrapf(err, "failed to get UserGroup with groupId=%s", groupID) return } @@ -1927,7 +1933,7 @@ func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string return } - builder := s.getQueryBuilder(). + builder = s.getQueryBuilder(). Insert("GroupMembers"). Columns("GroupId", "UserId", "CreateAt", "DeleteAt") @@ -1950,38 +1956,30 @@ func (s *SqlGroupStore) buildUpsertMembersQuery(groupID string, userIDs []string builder = builder.SuffixExpr(sq.Expr("ON CONFLICT (groupid, userid) DO UPDATE SET CreateAt = ?, DeleteAt = ?", createAt, 0)) } - query, args, err = builder.ToSql() return } func (s *SqlGroupStore) DeleteMembers(groupID string, userIDs []string) ([]*model.GroupMember, error) { - members, query, args, err := s.buildDeleteMembersQuery(groupID, userIDs) + members, query, err := s.buildDeleteMembersQuery(groupID, userIDs) if err != nil { return nil, err } - if _, err = s.GetMaster().Exec(query, args...); err != nil { + if _, err = s.GetMaster().ExecBuilder(query); err != nil { return nil, errors.Wrap(err, "failed to delete GroupMembers") } return members, err } -func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, query string, args []any, err error) { - membersSelectQuery, membersSelectArgs, err := s.getQueryBuilder(). - Select("*"). - From("GroupMembers"). +func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string) (members []*model.GroupMember, builder sq.UpdateBuilder, err error) { + membersSelectQuery := s.groupMembersSelectQuery. Where(sq.And{ - sq.Eq{"GroupId": groupID}, - sq.Eq{"UserId": userIDs}, - sq.Eq{"DeleteAt": 0}, - }). - ToSql() - if err != nil { - return - } + sq.Eq{"GroupMembers.GroupId": groupID}, + sq.Eq{"GroupMembers.UserId": userIDs}, + sq.Eq{"GroupMembers.DeleteAt": 0}, + }) - err = s.GetReplica().Select(&members, membersSelectQuery, membersSelectArgs...) - if err != nil { + if err = s.GetReplica().SelectBuilder(&members, membersSelectQuery); err != nil { return } if len(members) != len(userIDs) { @@ -2003,7 +2001,7 @@ func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string member.DeleteAt = deleteAt } - builder := s.getQueryBuilder(). + builder = s.getQueryBuilder(). Update("GroupMembers"). Set("DeleteAt", deleteAt). Where(sq.And{ @@ -2011,6 +2009,5 @@ func (s *SqlGroupStore) buildDeleteMembersQuery(groupID string, userIDs []string sq.Eq{"UserId": userIDs}, }) - query, args, err = builder.ToSql() return } diff --git a/server/channels/store/storetest/group_store.go b/server/channels/store/storetest/group_store.go index a69948bc7e7..0af5151215f 100644 --- a/server/channels/store/storetest/group_store.go +++ b/server/channels/store/storetest/group_store.go @@ -34,6 +34,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("Update", func(t *testing.T) { testGroupStoreUpdate(t, rctx, ss) }) t.Run("Delete", func(t *testing.T) { testGroupStoreDelete(t, rctx, ss) }) t.Run("Restore", func(t *testing.T) { testGroupStoreRestore(t, rctx, ss) }) + t.Run("ToModelChannelAssociations", func(t *testing.T) { testGroupStoreToModelChannelAssociations(t, rctx, ss) }) t.Run("GetMemberUsers", func(t *testing.T) { testGroupGetMemberUsers(t, rctx, ss) }) t.Run("GetMemberUsersPage", func(t *testing.T) { testGroupGetMemberUsersPage(t, rctx, ss) }) @@ -50,6 +51,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("CreateGroupSyncable", func(t *testing.T) { testCreateGroupSyncable(t, rctx, ss) }) t.Run("GetGroupSyncable", func(t *testing.T) { testGetGroupSyncable(t, rctx, ss) }) + t.Run("GetGroupSyncableErrors", func(t *testing.T) { testGetGroupSyncableErrors(t, rctx, ss) }) t.Run("GetAllGroupSyncablesByGroupId", func(t *testing.T) { testGetAllGroupSyncablesByGroup(t, rctx, ss) }) t.Run("UpdateGroupSyncable", func(t *testing.T) { testUpdateGroupSyncable(t, rctx, ss) }) t.Run("DeleteGroupSyncable", func(t *testing.T) { testDeleteGroupSyncable(t, rctx, ss) }) @@ -74,6 +76,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("TeamMembersMinusGroupMembers", func(t *testing.T) { testTeamMembersMinusGroupMembers(t, rctx, ss) }) t.Run("ChannelMembersMinusGroupMembers", func(t *testing.T) { testChannelMembersMinusGroupMembers(t, rctx, ss) }) + t.Run("CountMembersMinusGroupMembers", func(t *testing.T) { testCountMembersMinusGroupMembers(t, rctx, ss) }) t.Run("GetMemberCount", func(t *testing.T) { groupTestGetMemberCount(t, rctx, ss) }) @@ -1618,6 +1621,44 @@ func testGetGroupSyncable(t *testing.T, rctx request.CTX, ss store.Store) { require.Zero(t, gt1.DeleteAt) } +func testGetGroupSyncableErrors(t *testing.T, rctx request.CTX, ss store.Store) { + // Create a group + g1 := &model.Group{ + Name: model.NewPointer(model.NewId()), + DisplayName: model.NewId(), + Description: model.NewId(), + Source: model.GroupSourceLdap, + RemoteId: model.NewPointer(model.NewId()), + } + group, err := ss.Group().Create(g1) + require.NoError(t, err) + + // Test with invalid syncable type + invalidSyncableType := model.GroupSyncableType("invalid") + _, err = ss.Group().GetGroupSyncable(group.Id, model.NewId(), invalidSyncableType) + require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err) + + // Test with empty group ID + _, err = ss.Group().GetGroupSyncable("", model.NewId(), model.GroupSyncableTypeTeam) + require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err) + + // Test with empty syncable ID + _, err = ss.Group().GetGroupSyncable(group.Id, "", model.GroupSyncableTypeTeam) + require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err) + + // Test with completely non-existent IDs + randomGroupId := model.NewId() + randomTeamId := model.NewId() + _, err = ss.Group().GetGroupSyncable(randomGroupId, randomTeamId, model.GroupSyncableTypeTeam) + require.True(t, errors.As(err, &nfErr), "expected ErrNotFound, got %v", err) + + // Test with valid group ID but non-existent syncable + _, err = ss.Group().GetGroupSyncable(group.Id, model.NewId(), model.GroupSyncableTypeTeam) + require.True(t, errors.As(err, &nfErr)) +} + func testGetAllGroupSyncablesByGroup(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("team", func(t *testing.T) { testGetAllGroupSyncablesByGroupTeam(t, rctx, ss) }) t.Run("channel", func(t *testing.T) { testGetAllGroupSyncablesByGroupChannel(t, rctx, ss) }) @@ -4174,17 +4215,14 @@ func testGetGroups(t *testing.T, rctx request.CTX, ss store.Store) { PerPage: 100, Resultf: func(groups []*model.Group) bool { for _, group := range groups { - fmt.Println(group.Id, group.ChannelMemberCount) var channelMemberCount int if group.ChannelMemberCount != nil { channelMemberCount = *group.ChannelMemberCount } if group.Id == group1.Id && channelMemberCount != 2 { - fmt.Println("group1", group.Id, channelMemberCount) return false } if group.Id == group2.Id && channelMemberCount != 1 { - fmt.Println("group2", group.Id, channelMemberCount) return false } } @@ -4205,11 +4243,9 @@ func testGetGroups(t *testing.T, rctx request.CTX, ss store.Store) { channelMemberCount = *group.ChannelMemberCount } if group.Id == group1.Id && channelMemberCount != 1 { - fmt.Println("group1", group.Id, channelMemberCount) return false } if group.Id == group2.Id && channelMemberCount != 2 { - fmt.Println("group2", group.Id, channelMemberCount) return false } } @@ -5800,3 +5836,266 @@ func groupTestGroupCountBySource(t *testing.T, rctx request.CTX, ss store.Store) require.NoError(t, err) require.Equal(t, ldapSourceCountAfter-1, ldapSourceCountAfterDelete) } + +func testCountMembersMinusGroupMembers(t *testing.T, rctx request.CTX, ss store.Store) { + // Create test users + u1, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: model.NewUsername(), + }) + require.NoError(t, err) + + u2, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: model.NewUsername(), + }) + require.NoError(t, err) + + u3, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: model.NewUsername(), + }) + require.NoError(t, err) + + // Create test team + team := &model.Team{ + DisplayName: "Name", + Description: "Some description", + CompanyName: "Some company name", + AllowOpenInvite: false, + InviteId: "inviteid0", + Name: "z-z-" + model.NewId() + "a", + Email: "success+" + model.NewId() + "@simulator.amazonses.com", + Type: model.TeamOpen, + } + team, err = ss.Team().Save(team) + require.NoError(t, err) + + // Create a second team to test team-specific counts + team2 := &model.Team{ + DisplayName: "Name 2", + Description: "Some description 2", + CompanyName: "Some company name 2", + AllowOpenInvite: false, + InviteId: "inviteid1", + Name: "z-z-" + model.NewId() + "b", + Email: "success+" + model.NewId() + "@simulator.amazonses.com", + Type: model.TeamOpen, + } + team2, err = ss.Team().Save(team2) + require.NoError(t, err) + + // Create test channel in team 1 + channel := &model.Channel{ + TeamId: team.Id, + DisplayName: "Display Name", + Name: "z-z-" + model.NewId() + "a", + Type: model.ChannelTypeOpen, + } + channel, nErr := ss.Channel().Save(rctx, channel, -1) + require.NoError(t, nErr) + + // Create test channel in team 2 + channel2 := &model.Channel{ + TeamId: team2.Id, + DisplayName: "Display Name 2", + Name: "z-z-" + model.NewId() + "b", + Type: model.ChannelTypeOpen, + } + channel2, nErr = ss.Channel().Save(rctx, channel2, -1) + require.NoError(t, nErr) + + // Add users to teams and channels + // u1 and u2 in team 1 + _, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team.Id, UserId: u1.Id}, -1) + require.NoError(t, nErr) + _, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team.Id, UserId: u2.Id}, -1) + require.NoError(t, nErr) + + // u3 in team 2 + _, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team2.Id, UserId: u3.Id}, -1) + require.NoError(t, nErr) + + // u1 and u2 in channel 1 + _, nErr = ss.Channel().SaveMember(rctx, &model.ChannelMember{ChannelId: channel.Id, UserId: u1.Id, NotifyProps: model.GetDefaultChannelNotifyProps()}) + require.NoError(t, nErr) + _, nErr = ss.Channel().SaveMember(rctx, &model.ChannelMember{ChannelId: channel.Id, UserId: u2.Id, NotifyProps: model.GetDefaultChannelNotifyProps()}) + require.NoError(t, nErr) + + // u3 in channel 2 + _, nErr = ss.Channel().SaveMember(rctx, &model.ChannelMember{ChannelId: channel2.Id, UserId: u3.Id, NotifyProps: model.GetDefaultChannelNotifyProps()}) + require.NoError(t, nErr) + + // Create groups + group1, err := ss.Group().Create(&model.Group{ + Name: model.NewPointer(model.NewId()), + DisplayName: model.NewId(), + Description: model.NewId(), + Source: model.GroupSourceCustom, + RemoteId: model.NewPointer(model.NewId()), + }) + require.NoError(t, err) + + group2, err := ss.Group().Create(&model.Group{ + Name: model.NewPointer(model.NewId()), + DisplayName: model.NewId(), + Description: model.NewId(), + Source: model.GroupSourceCustom, + RemoteId: model.NewPointer(model.NewId()), + }) + require.NoError(t, err) + + // Add u1 to group1, u3 to group2 + _, err = ss.Group().UpsertMember(group1.Id, u1.Id) + require.NoError(t, err) + _, err = ss.Group().UpsertMember(group2.Id, u3.Id) + require.NoError(t, err) + + // Test CountTeamMembersMinusGroupMembers with empty groupIDs + count, err := ss.Group().CountTeamMembersMinusGroupMembers(team.Id, []string{}) + require.NoError(t, err) + require.Equal(t, int64(2), count) // Both u1 and u2 are counted when no groups are excluded + + // Test CountTeamMembersMinusGroupMembers with group1 ID + count, err = ss.Group().CountTeamMembersMinusGroupMembers(team.Id, []string{group1.Id}) + require.NoError(t, err) + require.Equal(t, int64(1), count) // Only u2 should be counted (not in the group) + + // Test with non-existent team ID + count, err = ss.Group().CountTeamMembersMinusGroupMembers(model.NewId(), []string{group1.Id}) + require.NoError(t, err) + require.Equal(t, int64(0), count) // No members in a non-existent team + + // Test with multiple group IDs + count, err = ss.Group().CountTeamMembersMinusGroupMembers(team.Id, []string{group1.Id, group2.Id}) + require.NoError(t, err) + require.Equal(t, int64(1), count) // Only u2 should be counted + + // Test team 2 + count, err = ss.Group().CountTeamMembersMinusGroupMembers(team2.Id, []string{group2.Id}) + require.NoError(t, err) + require.Equal(t, int64(0), count) // No one should be counted (u3 is in group2) + + // Test CountChannelMembersMinusGroupMembers with empty groupIDs + count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel.Id, []string{}) + require.NoError(t, err) + require.Equal(t, int64(2), count) // Both users are counted when no groups are excluded + + // Test CountChannelMembersMinusGroupMembers with the group ID + count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel.Id, []string{group1.Id}) + require.NoError(t, err) + require.Equal(t, int64(1), count) // Only u2 should be counted (not in the group) + + // Test with multiple group IDs + count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel.Id, []string{group1.Id, group2.Id}) + require.NoError(t, err) + require.Equal(t, int64(1), count) // u2 should be counted + + // Test channel 2 + count, err = ss.Group().CountChannelMembersMinusGroupMembers(channel2.Id, []string{group2.Id}) + require.NoError(t, err) + require.Equal(t, int64(0), count) // No one should be counted (u3 is in group2) + + // Test with non-existent channel ID + count, err = ss.Group().CountChannelMembersMinusGroupMembers(model.NewId(), []string{group1.Id}) + require.NoError(t, err) + require.Equal(t, int64(0), count) // No members in a non-existent channel + + // Test error cases - passing invalid parameters + // 1. Empty team ID + count, err = ss.Group().CountTeamMembersMinusGroupMembers("", []string{group1.Id}) + require.NoError(t, err) // Should handle this gracefully + require.Equal(t, int64(0), count) + + // 2. Empty channel ID + count, err = ss.Group().CountChannelMembersMinusGroupMembers("", []string{group1.Id}) + require.NoError(t, err) // Should handle this gracefully + require.Equal(t, int64(0), count) +} + +func testGroupStoreToModelChannelAssociations(t *testing.T, rctx request.CTX, ss store.Store) { + // Create test group + group, err := ss.Group().Create(&model.Group{ + Name: model.NewPointer(model.NewId()), + DisplayName: model.NewId(), + Description: model.NewId(), + Source: model.GroupSourceCustom, + RemoteId: model.NewPointer(model.NewId()), + }) + require.NoError(t, err) + require.NotNil(t, group) + + // Create test team + team := &model.Team{ + DisplayName: "Name", + Description: "Some description", + CompanyName: "Some company name", + AllowOpenInvite: false, + InviteId: "inviteid0", + Name: "z-z-" + model.NewId() + "a", + Email: "success+" + model.NewId() + "@simulator.amazonses.com", + Type: model.TeamOpen, + } + team, err = ss.Team().Save(team) + require.NoError(t, err) + require.NotNil(t, team) + + // Create test channel 1 + channel1 := &model.Channel{ + TeamId: team.Id, + DisplayName: "Display Name 1", + Name: "z-z-" + model.NewId() + "a", + Type: model.ChannelTypeOpen, + } + channel1, nErr := ss.Channel().Save(rctx, channel1, -1) + require.NoError(t, nErr) + require.NotNil(t, channel1) + + // Create test channel 2 + channel2 := &model.Channel{ + TeamId: team.Id, + DisplayName: "Display Name 2", + Name: "z-z-" + model.NewId() + "b", + Type: model.ChannelTypeOpen, + } + channel2, nErr = ss.Channel().Save(rctx, channel2, -1) + require.NoError(t, nErr) + require.NotNil(t, channel2) + + // Create group channel syncables + _, err = ss.Group().CreateGroupSyncable(&model.GroupSyncable{ + GroupId: group.Id, + SyncableId: channel1.Id, + Type: model.GroupSyncableTypeChannel, + SchemeAdmin: true, + }) + require.NoError(t, err) + + _, err = ss.Group().CreateGroupSyncable(&model.GroupSyncable{ + GroupId: group.Id, + SyncableId: channel2.Id, + Type: model.GroupSyncableTypeChannel, + SchemeAdmin: false, + }) + require.NoError(t, err) + + // Test the GetGroupsAssociatedToChannelsByTeam function + // This exercises the groupsAssociatedToChannelWithSchemeAdmin.ToModel method + result, err := ss.Group().GetGroupsAssociatedToChannelsByTeam(team.Id, model.GroupSearchOpts{}) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify channel 1 results + require.Contains(t, result, channel1.Id) + require.NotEmpty(t, result[channel1.Id]) + require.Equal(t, group.Id, result[channel1.Id][0].Id) + require.NotNil(t, result[channel1.Id][0].SchemeAdmin) + require.True(t, *result[channel1.Id][0].SchemeAdmin) + + // Verify channel 2 results (with different SchemeAdmin value) + require.Contains(t, result, channel2.Id) + require.NotEmpty(t, result[channel2.Id]) + require.Equal(t, group.Id, result[channel2.Id][0].Id) + require.NotNil(t, result[channel2.Id][0].SchemeAdmin) + require.False(t, *result[channel2.Id][0].SchemeAdmin) +}