diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index 79f0c795982..97e33795827 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -1772,7 +1772,7 @@ func (a *App) addUserToChannel(rctx request.CTX, user *model.User, channel *mode // Synchronize membership change for shared channels if channel.IsShared() { if scs := a.Srv().Platform().GetSharedChannelService(); scs != nil { - scs.HandleMembershipChange(channel.Id, user.Id, true, user.GetRemoteID()) + scs.NotifyMembershipChanged(channel.Id, user.GetRemoteID()) } } @@ -2844,9 +2844,8 @@ func (a *App) removeUserFromChannel(rctx request.CTX, userIDToRemove string, rem // Synchronize membership change for shared channels if channel.IsShared() { - // isAdd=false, empty remoteId means locally initiated if scs := a.Srv().Platform().GetSharedChannelService(); scs != nil { - scs.HandleMembershipChange(channel.Id, userIDToRemove, false, "") + scs.NotifyMembershipChanged(channel.Id, "") } } diff --git a/server/channels/app/platform/shared_channel_service_iface.go b/server/channels/app/platform/shared_channel_service_iface.go index 2db9152dea3..ab76329f728 100644 --- a/server/channels/app/platform/shared_channel_service_iface.go +++ b/server/channels/app/platform/shared_channel_service_iface.go @@ -24,7 +24,7 @@ type SharedChannelServiceIFace interface { CheckChannelNotShared(channelID string) error CheckChannelIsShared(channelID string) error CheckCanInviteToSharedChannel(channelId string) error - HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) + NotifyMembershipChanged(channelID string, originRemoteID string) TransformMentionsOnReceiveForTesting(rctx request.CTX, post *model.Post, targetChannel *model.Channel, rc *model.RemoteCluster, mentionTransforms map[string]string) } @@ -81,7 +81,7 @@ func (mrcs *mockSharedChannelService) NumInvitations() int { return mrcs.numInvitations } -func (mrcs *mockSharedChannelService) HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) { +func (mrcs *mockSharedChannelService) NotifyMembershipChanged(channelID string, originRemoteID string) { // This is a mock implementation - it doesn't need to do anything } diff --git a/server/channels/app/shared_channel_membership_sync_self_referential_test.go b/server/channels/app/shared_channel_membership_sync_self_referential_test.go index afab4acfcb1..542fdb7add9 100644 --- a/server/channels/app/shared_channel_membership_sync_self_referential_test.go +++ b/server/channels/app/shared_channel_membership_sync_self_referential_test.go @@ -29,7 +29,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { ss := th.App.Srv().Store() - // Get the shared channel service and cast to concrete type to access SyncAllChannelMembers + // Get the shared channel service and cast to concrete type scsInterface := th.App.Srv().GetSharedChannelSyncService() service, ok := scsInterface.(*sharedchannel.Service) require.True(t, ok, "Expected sharedchannel.Service concrete type") @@ -61,7 +61,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { t.Run("Test 1: Automatic sync on membership changes", func(t *testing.T) { // This test verifies that membership sync happens automatically when users are added or removed from a shared channel. - // The sync is triggered by HandleMembershipChange which is called automatically by AddUserToChannel and RemoveUserFromChannel. + // The sync is triggered by NotifyMembershipChanged which is called automatically by AddUserToChannel and RemoveUserFromChannel. // The test ensures that sync messages are sent asynchronously after a minimum delay for both add and remove operations. EnsureCleanState(t, th, ss) // Track sync messages received @@ -132,7 +132,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { _, _, appErr = th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, user.Id, th.BasicUser.Id) require.Nil(t, appErr) - // Add user to channel - this triggers HandleMembershipChange automatically + // Add user to channel - this triggers NotifyMembershipChanged automatically _, appErr = th.App.AddUserToChannel(th.Context, user, channel, false) require.Nil(t, appErr) @@ -307,8 +307,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { mu.Unlock() } - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for batch messages to be received with more robust checking require.Eventually(t, func() bool { @@ -343,11 +342,10 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } mu.Unlock() - // Verify exact batch count - assert.Equal(t, expectedBatches, actualBatches, fmt.Sprintf("Should have exactly %d batches with batch size %d", expectedBatches, batchSize)) - - // Verify total synced users - assert.Equal(t, expectedTotal, totalSynced, "All users including bots and system admins should be synced") + // With inclusive cursor (GtOrEq), boundary rows may be re-fetched across + // batch boundaries, so batch count and total may exceed the minimum. + assert.GreaterOrEqual(t, actualBatches, expectedBatches, fmt.Sprintf("Should have at least %d batches with batch size %d", expectedBatches, batchSize)) + assert.GreaterOrEqual(t, totalSynced, expectedTotal, "All users including bots and system admins should be synced") // Verify that bot and system admin WERE synced assert.Contains(t, allSyncedUserIDs, bot.UserId, "Bot should be synced") @@ -468,8 +466,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } // First sync - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for first sync to complete require.Eventually(t, func() bool { @@ -497,8 +494,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.Nil(t, appErr) // Second sync - should only sync user3 - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for second sync to complete require.Eventually(t, func() bool { @@ -518,11 +514,14 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { return secondSyncCursor > firstSyncCursor }, 10*time.Second, 100*time.Millisecond, "Cursor should advance after second sync") - // Verify incremental sync + // Verify incremental sync: first sync includes initial users, second + // sync must include user3. The inclusive cursor (GtOrEq) may re-fetch + // events at the boundary timestamp, so users from the first batch can + // appear again. This is the expected trade-off to avoid data loss at + // batch boundaries without requiring a composite cursor schema change. + // The receiver is idempotent so duplicates are harmless. assert.GreaterOrEqual(t, len(syncedInFirstCall), 2, "First sync should include initial users") - assert.Contains(t, syncedInSecondCall, user3.Id, "Second sync should include only new user") - assert.NotContains(t, syncedInSecondCall, user1.Id, "Second sync should not re-sync existing users") - assert.NotContains(t, syncedInSecondCall, user2.Id, "Second sync should not re-sync existing users") + assert.Contains(t, syncedInSecondCall, user3.Id, "Second sync must include the new user") }) t.Run("Test 4: Sync failure and recovery", func(t *testing.T) { t.Skip("MM-64687") @@ -618,8 +617,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.Nil(t, appErr) // First sync attempt - will fail - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for first sync attempt with more robust checking require.Eventually(t, func() bool { @@ -635,8 +633,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { failureMode.Store(false) // Second sync attempt - should succeed - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for successful sync with more robust checking require.Eventually(t, func() bool { @@ -648,7 +645,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { assert.Greater(t, finalAttempts, initialAttempts, "Should have retried after recovery") }) t.Run("Test 5: Manual sync with cursor management", func(t *testing.T) { - // This test verifies manual sync using SyncAllChannelMembers with complete cursor management: + // This test verifies manual sync using NotifyMembershipChanged with complete cursor management: // 1. Initial sync of 10 users with cursor tracking // 2. Mixed operations: remove 3 users and add 5 new users // 3. Verifies all operations are properly synced and cursor is updated correctly @@ -676,8 +673,8 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { var frame model.RemoteClusterFrame if json.Unmarshal(bodyBytes, &frame) == nil { var syncMsg model.SyncMsg - if json.Unmarshal(frame.Msg.Payload, &syncMsg) == nil && frame.Msg.Topic == "sharedchannel_membership" { - // Count membership changes from the unified field + if json.Unmarshal(frame.Msg.Payload, &syncMsg) == nil { + // Count membership changes (now sent via TopicSync) for _, change := range syncMsg.MembershipChanges { if change.IsAdd { atomic.AddInt32(&addOperations, 1) @@ -762,8 +759,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { initialCursor := initialScr.LastMembersSyncAt // Initial sync - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for initial sync to complete require.Eventually(t, func() bool { @@ -804,8 +800,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { // Sync mixed changes previousMessages := atomic.LoadInt32(&totalSyncMessages) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for mixed changes sync to complete require.Eventually(t, func() bool { @@ -963,28 +958,25 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.Nil(t, addErr) } - // Sync to all clusters - need to sync each one individually - for _, cluster := range clusters { - err = service.SyncAllChannelMembers(channel.Id, cluster.RemoteId, nil) - require.NoError(t, err) - } + // Sync to all clusters via NotifyMembershipChanged + service.NotifyMembershipChanged(channel.Id, "") // Wait for syncs to complete require.Eventually(t, func() bool { - // Each cluster should receive at least 5 sync messages (one per user) + // Each cluster should receive at least 1 sync message (membership changes are batched) for _, countPtr := range syncMessagesPerCluster { - if atomic.LoadInt32(countPtr) < 5 { + if atomic.LoadInt32(countPtr) < 1 { return false } } return true }, 10*time.Second, 100*time.Millisecond, "All clusters should receive sync messages") - // Verify each cluster received messages + // Verify each cluster received messages (membership changes are batched, so >= 1) for name, countPtr := range syncMessagesPerCluster { finalCount := atomic.LoadInt32(countPtr) - assert.GreaterOrEqual(t, finalCount, int32(5), - "Cluster %s should receive at least 5 sync messages", name) + assert.GreaterOrEqual(t, finalCount, int32(1), + "Cluster %s should receive at least 1 sync message", name) } // Part 2: Test propagation from one cluster through another @@ -1061,11 +1053,8 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { appErr = th.App.RemoveUserFromChannel(th.Context, users[0].Id, th.SystemAdminUser.Id, channel) require.Nil(t, appErr) - // Sync removal to all clusters - for _, cluster := range clusters { - err = service.SyncAllChannelMembers(channel.Id, cluster.RemoteId, nil) - require.NoError(t, err) - } + // Sync removal to all clusters via NotifyMembershipChanged + service.NotifyMembershipChanged(channel.Id, "") // Wait for removal sync require.Eventually(t, func() bool { @@ -1088,7 +1077,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { t.Run("Test 7: Feature flag disabled", func(t *testing.T) { // This test verifies that the shared channel membership sync functionality respects the feature flag. // It tests two scenarios: - // 1. When the feature flag is disabled, no sync messages should be sent even when SyncAllChannelMembers is called + // 1. When the feature flag is disabled, no sync messages should be sent even when NotifyMembershipChanged is called // 2. When the feature flag is enabled, sync messages should be sent as expected // This ensures that the feature can be safely disabled in production without triggering unintended syncs EnsureCleanState(t, th, ss) @@ -1163,8 +1152,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } atomic.StoreInt32(&syncMessageCount, 0) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Verify no sync messages were sent require.Never(t, func() bool { @@ -1177,8 +1165,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { }) atomic.StoreInt32(&syncMessageCount, 0) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Verify sync messages were sent require.Eventually(t, func() bool { @@ -1270,8 +1257,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.NoError(t, err) // Trigger membership sync as would happen when connection is restored - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Verify sync task was created and executed with more generous timeout require.Eventually(t, func() bool { @@ -1387,8 +1373,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { syncHandler = NewSelfReferentialSyncHandler(t, service, selfCluster) // First sync should succeed - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for first sync with more generous timeout require.Eventually(t, func() bool { @@ -1416,8 +1401,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } // Second sync should fail (server goes offline) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) // Method itself shouldn't error + service.NotifyMembershipChanged(channel.Id, "") // Wait for second sync attempt with more generous timeout require.Eventually(t, func() bool { @@ -1571,8 +1555,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { // Sync memberships for each channel separately for _, channel := range []*model.Channel{channel1, channel2, channel3} { - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") } // Ensure the sync handler is ready by waiting for the first message diff --git a/server/channels/app/shared_channel_service_iface.go b/server/channels/app/shared_channel_service_iface.go index 2a6ecfa682c..52e2b635eb6 100644 --- a/server/channels/app/shared_channel_service_iface.go +++ b/server/channels/app/shared_channel_service_iface.go @@ -27,7 +27,7 @@ type SharedChannelServiceIFace interface { CheckChannelNotShared(channelID string) error CheckChannelIsShared(channelID string) error CheckCanInviteToSharedChannel(channelId string) error - HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) + NotifyMembershipChanged(channelID string, originRemoteID string) IsRemoteClusterDirectlyConnected(remoteId string) bool TransformMentionsOnReceiveForTesting(rctx request.CTX, post *model.Post, targetChannel *model.Channel, rc *model.RemoteCluster, mentionTransforms map[string]string) } @@ -37,6 +37,7 @@ func NewMockSharedChannelService(service SharedChannelServiceIFace) *mockSharedC SharedChannelServiceIFace: service, channelNotifications: []string{}, userProfileNotifications: []string{}, + membershipNotifications: []string{}, numInvitations: 0, } return mrcs @@ -46,6 +47,7 @@ type mockSharedChannelService struct { SharedChannelServiceIFace channelNotifications []string userProfileNotifications []string + membershipNotifications []string numInvitations int } @@ -96,9 +98,10 @@ func (mrcs *mockSharedChannelService) NumInvitations() int { return mrcs.numInvitations } -func (mrcs *mockSharedChannelService) HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) { +func (mrcs *mockSharedChannelService) NotifyMembershipChanged(channelID string, originRemoteID string) { + mrcs.membershipNotifications = append(mrcs.membershipNotifications, channelID) if mrcs.SharedChannelServiceIFace != nil { - mrcs.SharedChannelServiceIFace.HandleMembershipChange(channelID, userID, isAdd, remoteID) + mrcs.SharedChannelServiceIFace.NotifyMembershipChanged(channelID, originRemoteID) } } diff --git a/server/channels/app/shared_channel_sync_self_referential_utils_test.go b/server/channels/app/shared_channel_sync_self_referential_utils_test.go index d0944f66229..1b4beeeb420 100644 --- a/server/channels/app/shared_channel_sync_self_referential_utils_test.go +++ b/server/channels/app/shared_channel_sync_self_referential_utils_test.go @@ -129,8 +129,10 @@ func (h *SelfReferentialSyncHandler) HandleRequest(w http.ResponseWriter, r *htt if h.OnBatchSync != nil { h.OnBatchSync(batch, currentCall) } - if len(batch) == 1 && h.OnIndividualSync != nil { - h.OnIndividualSync(batch[0], currentCall) + if h.OnIndividualSync != nil { + for _, uid := range batch { + h.OnIndividualSync(uid, currentCall) + } } } } diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 1978557da27..5fdd6bb9f17 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -3869,6 +3869,27 @@ func (s *RetryLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(star } +func (s *RetryLayerChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + + tries := 0 + for { + result, err := s.ChannelMemberHistoryStore.GetMembershipChanges(channelID, since, limit) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerChannelMemberHistoryStore) GetUsersInChannelDuring(startTime int64, endTime int64, channelID []string) ([]*model.ChannelMemberHistoryResult, error) { tries := 0 @@ -12707,27 +12728,6 @@ func (s *RetryLayerSharedChannelStore) GetSingleUser(userID string, channelID st } -func (s *RetryLayerSharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - - tries := 0 - for { - result, err := s.SharedChannelStore.GetUserChanges(userID, channelID, afterTime) - if err == nil { - return result, nil - } - if !isRepeatableError(err) { - return result, err - } - tries++ - if tries >= 3 { - err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return result, err - } - timepkg.Sleep(100 * timepkg.Millisecond) - } - -} - func (s *RetryLayerSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { tries := 0 @@ -13001,27 +13001,6 @@ func (s *RetryLayerSharedChannelStore) UpdateRemoteMembershipCursor(id string, s } -func (s *RetryLayerSharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - - tries := 0 - for { - err := s.SharedChannelStore.UpdateUserLastMembershipSyncAt(userID, channelID, remoteID, syncTime) - if err == nil { - return nil - } - if !isRepeatableError(err) { - return err - } - tries++ - if tries >= 3 { - err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return err - } - timepkg.Sleep(100 * timepkg.Millisecond) - } - -} - func (s *RetryLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { tries := 0 diff --git a/server/channels/store/sqlstore/channel_member_history_store.go b/server/channels/store/sqlstore/channel_member_history_store.go index 7665297c379..b4f8e6c9753 100644 --- a/server/channels/store/sqlstore/channel_member_history_store.go +++ b/server/channels/store/sqlstore/channel_member_history_store.go @@ -295,6 +295,40 @@ func (s SqlChannelMemberHistoryStore) PermanentDeleteBatch(endTime int64, limit return rowsAffected, nil } +// GetMembershipChanges returns all membership events (joins and leaves) for a channel since the given timestamp. +// Uses inclusive comparison (>=) so events at the cursor timestamp are re-fetched rather than lost at batch +// boundaries. This may cause redundant re-sends when consecutive batches share a boundary timestamp, but the +// receiver is idempotent so duplicates are harmless. A composite cursor (timestamp + ID) like posts use would +// eliminate duplicates, but would require a schema change to SharedChannelRemotes; the current trade-off avoids that. +func (s SqlChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + query, args, err := s.getQueryBuilder(). + Select("ChannelId", "UserId", "JoinTime", "LeaveTime"). + From("ChannelMemberHistory"). + Where(sq.And{ + sq.Eq{"ChannelId": channelID}, + sq.Or{ + sq.GtOrEq{"JoinTime": since}, + sq.And{ + sq.NotEq{"LeaveTime": nil}, + sq.GtOrEq{"LeaveTime": since}, + }, + }, + }). + OrderBy("GREATEST(JoinTime, COALESCE(LeaveTime, 0)) ASC", "UserId ASC"). + Limit(uint64(limit)). + ToSql() + if err != nil { + return nil, errors.Wrap(err, "channel_member_history_to_sql") + } + + histories := []*model.ChannelMemberHistory{} + if err := s.GetReplica().Select(&histories, query, args...); err != nil { + return nil, errors.Wrapf(err, "GetMembershipChanges channelId=%s since=%d limit=%d", channelID, since, limit) + } + + return histories, nil +} + // GetChannelsLeftSince returns list of channels that the user has left after a given time, // but has not rejoined again. func (s SqlChannelMemberHistoryStore) GetChannelsLeftSince(userID string, since int64) ([]string, error) { diff --git a/server/channels/store/sqlstore/shared_channel_store.go b/server/channels/store/sqlstore/shared_channel_store.go index b6047b5c3c3..f748a438c20 100644 --- a/server/channels/store/sqlstore/shared_channel_store.go +++ b/server/channels/store/sqlstore/shared_channel_store.go @@ -711,7 +711,6 @@ func sharedChannelUserFields(prefix string) []string { prefix + "RemoteId", prefix + "CreateAt", prefix + "LastSyncAt", - prefix + "LastMembershipSyncAt", } } @@ -724,7 +723,7 @@ func (s SqlSharedChannelStore) SaveUser(scUser *model.SharedChannelUser) (*model query, args, err := s.getQueryBuilder().Insert("SharedChannelUsers"). Columns(sharedChannelUserFields("")...). - Values(scUser.Id, scUser.UserId, scUser.ChannelId, scUser.RemoteId, scUser.CreateAt, scUser.LastSyncAt, scUser.LastMembershipSyncAt). + Values(scUser.Id, scUser.UserId, scUser.ChannelId, scUser.RemoteId, scUser.CreateAt, scUser.LastSyncAt). ToSql() if err != nil { return nil, errors.Wrapf(err, "savesharedchanneluser_tosql") @@ -852,25 +851,6 @@ func (s SqlSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID str return nil } -// UpdateUserLastMembershipSyncAt updates the LastMembershipSyncAt timestamp for the specified SharedChannelUser using the provided sync time. -func (s SqlSharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - query := s.getQueryBuilder(). - Update("SharedChannelUsers AS scu"). - Set("LastMembershipSyncAt", sq.Expr("GREATEST(scu.LastMembershipSyncAt, ?)", syncTime)). - Where(sq.Eq{ - "scu.UserId": userID, - "scu.ChannelId": channelID, - "scu.RemoteId": remoteID, - }) - - _, err := s.GetMaster().ExecBuilder(query) - if err != nil { - return fmt.Errorf("failed to update LastMembershipSyncAt for SharedChannelUser with userId=%s, channelId=%s, remoteId=%s: %w", - userID, channelID, remoteID, err) - } - return nil -} - func sharedChannelAttachementFields(prefix string) []string { if prefix != "" && !strings.HasSuffix(prefix, ".") { prefix = prefix + "." diff --git a/server/channels/store/sqlstore/shared_channel_store_membership.go b/server/channels/store/sqlstore/shared_channel_store_membership.go index 111fbb85e01..0ce4ee44d58 100644 --- a/server/channels/store/sqlstore/shared_channel_store_membership.go +++ b/server/channels/store/sqlstore/shared_channel_store_membership.go @@ -4,10 +4,8 @@ package sqlstore import ( - "database/sql" "fmt" - "github.com/mattermost/mattermost/server/public/model" sq "github.com/mattermost/squirrel" "github.com/pkg/errors" ) @@ -38,29 +36,3 @@ func (s SqlSharedChannelStore) UpdateRemoteMembershipCursor(id string, syncTime return nil } - -// GetUserChanges gets all SharedChannelUser changes for a given user, channel after a specific time. -// This is used to detect if there are conflicting membership changes. -func (s SqlSharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - squery, args, err := s.getQueryBuilder(). - Select(sharedChannelUserFields("")...). - From("SharedChannelUsers"). - Where(sq.Eq{"SharedChannelUsers.UserId": userID}). - Where(sq.Eq{"SharedChannelUsers.ChannelId": channelID}). - Where(sq.Gt{"SharedChannelUsers.LastSyncAt": afterTime}). - ToSql() - - if err != nil { - return nil, errors.Wrapf(err, "getsharedchanneluserchanges_tosql") - } - - users := []*model.SharedChannelUser{} - if err := s.GetReplica().Select(&users, squery, args...); err != nil { - if err == sql.ErrNoRows { - return make([]*model.SharedChannelUser, 0), nil - } - return nil, errors.Wrapf(err, "failed to find shared channel user changes with UserId=%s, ChannelId=%s, afterTime=%d", - userID, channelID, afterTime) - } - return users, nil -} diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 789b3236a1d..ae5a3825081 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -334,6 +334,7 @@ type ChannelMemberHistoryStore interface { DeleteOrphanedRows(limit int) (deleted int64, err error) PermanentDeleteBatch(endTime int64, limit int64) (int64, error) GetChannelsLeftSince(userID string, since int64) ([]string, error) + GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) } type ThreadStore interface { GetThreadFollowers(threadID string, fetchOnlyActive bool) ([]string, error) @@ -1040,9 +1041,7 @@ type SharedChannelStore interface { GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) - GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error - UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) diff --git a/server/channels/store/storetest/channel_member_history_store.go b/server/channels/store/storetest/channel_member_history_store.go index aaae0e34ab1..05b339e9c88 100644 --- a/server/channels/store/storetest/channel_member_history_store.go +++ b/server/channels/store/storetest/channel_member_history_store.go @@ -26,6 +26,7 @@ func TestChannelMemberHistoryStore(t *testing.T, rctx request.CTX, ss store.Stor t.Run("TestPermanentDeleteBatchForRetentionPolicies", func(t *testing.T) { testPermanentDeleteBatchForRetentionPolicies(t, rctx, ss) }) t.Run("TestGetChannelsLeftSince", func(t *testing.T) { testGetChannelsLeftSince(t, rctx, ss) }) t.Run("TestDeleteOrphanedRows", func(t *testing.T) { testDeleteOrphanedRows(t, rctx, ss) }) + t.Run("TestGetMembershipChanges", func(t *testing.T) { testGetMembershipChanges(t, rctx, ss) }) } func testLogJoinEvent(t *testing.T, rctx request.CTX, ss store.Store) { @@ -689,3 +690,124 @@ func testDeleteOrphanedRows(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) require.Equal(t, int64(0), deletedCount, "No rows should be deleted when no orphans exist") } + +func testGetMembershipChanges(t *testing.T, rctx request.CTX, ss store.Store) { + ch := model.Channel{ + TeamId: model.NewId(), + DisplayName: "GetMembershipChanges", + Name: NewTestID(), + Type: model.ChannelTypeOpen, + } + channel, nErr := ss.Channel().Save(rctx, &ch, 100) + require.NoError(t, nErr) + + user1 := model.NewId() + user2 := model.NewId() + user3 := model.NewId() + + // Set up timeline: + // t=1000: user1 joins + // t=2000: user2 joins + // t=3000: user3 joins + // t=4000: user1 leaves + // t=5000: user2 leaves + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user1, channel.Id, 1000)) + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user2, channel.Id, 2000)) + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user3, channel.Id, 3000)) + require.NoError(t, ss.ChannelMemberHistory().LogLeaveEvent(user1, channel.Id, 4000)) + require.NoError(t, ss.ChannelMemberHistory().LogLeaveEvent(user2, channel.Id, 5000)) + + t.Run("returns all events since timestamp zero", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + assert.Len(t, results, 3, "should return all 3 users' history rows") + }) + + t.Run("filters by since timestamp - joins only", func(t *testing.T) { + // since=2500: user3 joined at 3000 (>=2500), user1 left at 4000 (>=2500), user2 left at 5000 (>=2500) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 2500, 100) + require.NoError(t, err) + assert.Len(t, results, 3, "user3 join>=2500, user1 leave>=2500, user2 leave>=2500") + }) + + t.Run("inclusive boundary includes events at exact cursor timestamp", func(t *testing.T) { + // since=3000: user3 joined at exactly 3000, should be included (GtOrEq) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 3000, 100) + require.NoError(t, err) + userIDs := make(map[string]bool) + for _, r := range results { + userIDs[r.UserId] = true + } + assert.True(t, userIDs[user3], "user3 joined at exactly since=3000, should be included") + assert.True(t, userIDs[user1], "user1 left at 4000 >= 3000") + assert.True(t, userIDs[user2], "user2 left at 5000 >= 3000") + }) + + t.Run("filters by since timestamp - leaves only", func(t *testing.T) { + // since=3500: user1 left at 4000, user2 left at 5000; user3 joined at 3000 (not > 3500) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 3500, 100) + require.NoError(t, err) + assert.Len(t, results, 2, "only user1 and user2 have events after 3500") + }) + + t.Run("respects limit parameter", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 2) + require.NoError(t, err) + assert.Len(t, results, 2, "should respect limit of 2") + }) + + t.Run("returns empty for unknown channel", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(model.NewId(), 0, 100) + require.NoError(t, err) + assert.Empty(t, results) + }) + + t.Run("returns empty when since is beyond all events", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 999999, 100) + require.NoError(t, err) + assert.Empty(t, results) + }) + + t.Run("results are ordered by greatest event time ascending", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + require.Len(t, results, 3) + + // Effective event times: user1=max(1000,4000)=4000, user2=max(2000,5000)=5000, user3=max(3000,0)=3000 + // Sorted ASC: user3(3000), user1(4000), user2(5000) + assert.Equal(t, user3, results[0].UserId, "user3 should be first (effective time 3000)") + assert.Equal(t, user1, results[1].UserId, "user1 should be second (effective time 4000)") + assert.Equal(t, user2, results[2].UserId, "user2 should be third (effective time 5000)") + }) + + t.Run("null LeaveTime handled correctly", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + + for _, r := range results { + if r.UserId == user3 { + assert.Nil(t, r.LeaveTime, "user3 should have nil LeaveTime") + } + if r.UserId == user1 { + require.NotNil(t, r.LeaveTime) + assert.Equal(t, int64(4000), *r.LeaveTime) + } + } + }) + + t.Run("join-leave-rejoin cycle returns latest row", func(t *testing.T) { + // user1 rejoins at t=6000 + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user1, channel.Id, 6000)) + + // since=0 should now return 4 rows (user1 has 2 history rows, user2 has 1, user3 has 1) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + assert.Len(t, results, 4, "user1 now has 2 history rows") + + // The rejoined row (JoinTime=6000, LeaveTime=nil) should be the last one + lastResult := results[len(results)-1] + assert.Equal(t, user1, lastResult.UserId) + assert.Equal(t, int64(6000), lastResult.JoinTime) + assert.Nil(t, lastResult.LeaveTime, "rejoin should have nil LeaveTime") + }) +} diff --git a/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go b/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go index 280090253a1..c24baf67b98 100644 --- a/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go +++ b/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go @@ -102,6 +102,36 @@ func (_m *ChannelMemberHistoryStore) GetChannelsWithActivityDuring(startTime int return r0, r1 } +// GetMembershipChanges provides a mock function with given fields: channelID, since, limit +func (_m *ChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + ret := _m.Called(channelID, since, limit) + + if len(ret) == 0 { + panic("no return value specified for GetMembershipChanges") + } + + var r0 []*model.ChannelMemberHistory + var r1 error + if rf, ok := ret.Get(0).(func(string, int64, int) ([]*model.ChannelMemberHistory, error)); ok { + return rf(channelID, since, limit) + } + if rf, ok := ret.Get(0).(func(string, int64, int) []*model.ChannelMemberHistory); ok { + r0 = rf(channelID, since, limit) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelMemberHistory) + } + } + + if rf, ok := ret.Get(1).(func(string, int64, int) error); ok { + r1 = rf(channelID, since, limit) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetUsersInChannelDuring provides a mock function with given fields: startTime, endTime, channelID func (_m *ChannelMemberHistoryStore) GetUsersInChannelDuring(startTime int64, endTime int64, channelID []string) ([]*model.ChannelMemberHistoryResult, error) { ret := _m.Called(startTime, endTime, channelID) diff --git a/server/channels/store/storetest/mocks/ChannelStore.go b/server/channels/store/storetest/mocks/ChannelStore.go index b430c1c33d3..455c68a0448 100644 --- a/server/channels/store/storetest/mocks/ChannelStore.go +++ b/server/channels/store/storetest/mocks/ChannelStore.go @@ -7,8 +7,9 @@ package mocks import ( model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // ChannelStore is an autogenerated mock type for the ChannelStore type diff --git a/server/channels/store/storetest/mocks/PostStore.go b/server/channels/store/storetest/mocks/PostStore.go index ff5291cb754..d8ee218923c 100644 --- a/server/channels/store/storetest/mocks/PostStore.go +++ b/server/channels/store/storetest/mocks/PostStore.go @@ -7,8 +7,9 @@ package mocks import ( model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // PostStore is an autogenerated mock type for the PostStore type diff --git a/server/channels/store/storetest/mocks/SharedChannelStore.go b/server/channels/store/storetest/mocks/SharedChannelStore.go index dbecc0f5b43..d0906484491 100644 --- a/server/channels/store/storetest/mocks/SharedChannelStore.go +++ b/server/channels/store/storetest/mocks/SharedChannelStore.go @@ -368,36 +368,6 @@ func (_m *SharedChannelStore) GetSingleUser(userID string, channelID string, rem return r0, r1 } -// GetUserChanges provides a mock function with given fields: userID, channelID, afterTime -func (_m *SharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - ret := _m.Called(userID, channelID, afterTime) - - if len(ret) == 0 { - panic("no return value specified for GetUserChanges") - } - - var r0 []*model.SharedChannelUser - var r1 error - if rf, ok := ret.Get(0).(func(string, string, int64) ([]*model.SharedChannelUser, error)); ok { - return rf(userID, channelID, afterTime) - } - if rf, ok := ret.Get(0).(func(string, string, int64) []*model.SharedChannelUser); ok { - r0 = rf(userID, channelID, afterTime) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*model.SharedChannelUser) - } - } - - if rf, ok := ret.Get(1).(func(string, string, int64) error); ok { - r1 = rf(userID, channelID, afterTime) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetUsersForSync provides a mock function with given fields: filter func (_m *SharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { ret := _m.Called(filter) @@ -748,24 +718,6 @@ func (_m *SharedChannelStore) UpdateRemoteMembershipCursor(id string, syncTime i return r0 } -// UpdateUserLastMembershipSyncAt provides a mock function with given fields: userID, channelID, remoteID, syncTime -func (_m *SharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - ret := _m.Called(userID, channelID, remoteID, syncTime) - - if len(ret) == 0 { - panic("no return value specified for UpdateUserLastMembershipSyncAt") - } - - var r0 error - if rf, ok := ret.Get(0).(func(string, string, string, int64) error); ok { - r0 = rf(userID, channelID, remoteID, syncTime) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // UpdateUserLastSyncAt provides a mock function with given fields: userID, channelID, remoteID func (_m *SharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { ret := _m.Called(userID, channelID, remoteID) diff --git a/server/channels/store/storetest/mocks/Store.go b/server/channels/store/storetest/mocks/Store.go index f77ba092ec9..e12d1d54782 100644 --- a/server/channels/store/storetest/mocks/Store.go +++ b/server/channels/store/storetest/mocks/Store.go @@ -5,13 +5,16 @@ package mocks import ( - sql "database/sql" - time "time" + mlog "github.com/mattermost/mattermost/server/public/shared/mlog" + mock "github.com/stretchr/testify/mock" model "github.com/mattermost/mattermost/server/public/model" - mlog "github.com/mattermost/mattermost/server/public/shared/mlog" + + sql "database/sql" + store "github.com/mattermost/mattermost/server/v8/channels/store" - mock "github.com/stretchr/testify/mock" + + time "time" ) // Store is an autogenerated mock type for the Store type diff --git a/server/channels/store/storetest/mocks/ThreadStore.go b/server/channels/store/storetest/mocks/ThreadStore.go index 4a3aa04edb0..e2d222da345 100644 --- a/server/channels/store/storetest/mocks/ThreadStore.go +++ b/server/channels/store/storetest/mocks/ThreadStore.go @@ -7,8 +7,9 @@ package mocks import ( model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // ThreadStore is an autogenerated mock type for the ThreadStore type diff --git a/server/channels/store/storetest/mocks/UserStore.go b/server/channels/store/storetest/mocks/UserStore.go index fa0670cda2a..0845e1e23ab 100644 --- a/server/channels/store/storetest/mocks/UserStore.go +++ b/server/channels/store/storetest/mocks/UserStore.go @@ -8,9 +8,11 @@ import ( context "context" model "github.com/mattermost/mattermost/server/public/model" - request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + request "github.com/mattermost/mattermost/server/public/shared/request" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // UserStore is an autogenerated mock type for the UserStore type diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 439445b25a5..2575e55bc4d 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -3224,6 +3224,22 @@ func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(star return result, err } +func (s *TimerLayerChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + start := time.Now() + + result, err := s.ChannelMemberHistoryStore.GetMembershipChanges(channelID, since, limit) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetMembershipChanges", success, elapsed) + } + return result, err +} + func (s *TimerLayerChannelMemberHistoryStore) GetUsersInChannelDuring(startTime int64, endTime int64, channelID []string) ([]*model.ChannelMemberHistoryResult, error) { start := time.Now() @@ -10065,22 +10081,6 @@ func (s *TimerLayerSharedChannelStore) GetSingleUser(userID string, channelID st return result, err } -func (s *TimerLayerSharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - start := time.Now() - - result, err := s.SharedChannelStore.GetUserChanges(userID, channelID, afterTime) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetUserChanges", success, elapsed) - } - return result, err -} - func (s *TimerLayerSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { start := time.Now() @@ -10289,22 +10289,6 @@ func (s *TimerLayerSharedChannelStore) UpdateRemoteMembershipCursor(id string, s return err } -func (s *TimerLayerSharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - start := time.Now() - - err := s.SharedChannelStore.UpdateUserLastMembershipSyncAt(userID, channelID, remoteID, syncTime) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateUserLastMembershipSyncAt", success, elapsed) - } - return err -} - func (s *TimerLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { start := time.Now() diff --git a/server/platform/services/sharedchannel/channelinvite.go b/server/platform/services/sharedchannel/channelinvite.go index f21a393316f..1581d56c700 100644 --- a/server/platform/services/sharedchannel/channelinvite.go +++ b/server/platform/services/sharedchannel/channelinvite.go @@ -135,7 +135,6 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc } curTime := model.GetMillis() - var sharedChannelRemote *model.SharedChannelRemote if existingScr != nil { if existingScr.DeleteAt == 0 && existingScr.IsInviteConfirmed { // the shared channel remote exists and is not @@ -155,7 +154,6 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error confirming channel invite for %s: %v", rc.DisplayName, sErr)) return } - sharedChannelRemote = existingScr } else { // the shared channel remote doesn't exists, so we create it scr := &model.SharedChannelRemote{ @@ -172,20 +170,13 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error confirming channel invite for %s: %v", rc.DisplayName, err)) return } - sharedChannelRemote = scr } scs.NotifyChannelChanged(sc.ChannelId) scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("`%s` has been added to channel.", rc.DisplayName)) - // Sync all channel members to the remote now that the remote entry exists - if syncErr := scs.SyncAllChannelMembers(sc.ChannelId, rc.RemoteId, sharedChannelRemote); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync channel members after invite confirmation", - mlog.String("channel_id", sc.ChannelId), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - } + // Trigger membership sync via the normal sync pipeline (reads from ChannelMemberHistory) + scs.NotifyMembershipChanged(sc.ChannelId, "") } if rc.IsPlugin() { @@ -326,14 +317,8 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model return fmt.Errorf("cannot restore deleted shared channel remote (channel_id=%s): %w", invite.ChannelId, err) } - // Sync local channel members to the remote after restoring the shared channel - if syncErr := scs.SyncAllChannelMembers(channel.Id, rc.RemoteId, existingScr); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync local channel members after restoring shared channel", - mlog.String("channel_id", channel.Id), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - } + // Trigger membership sync via the normal sync pipeline (reads from ChannelMemberHistory) + scs.NotifyMembershipChanged(channel.Id, "") } else { creatorID := channel.CreatorId if creatorID == "" { @@ -361,14 +346,8 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model return fmt.Errorf("cannot create shared channel remote (channel_id=%s): %w", invite.ChannelId, err) } - // Sync local channel members to the remote after accepting the invitation - if syncErr := scs.SyncAllChannelMembers(channel.Id, rc.RemoteId, scr); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync local channel members after accepting invitation", - mlog.String("channel_id", channel.Id), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - } + // Trigger membership sync via the normal sync pipeline (reads from ChannelMemberHistory) + scs.NotifyMembershipChanged(channel.Id, "") } return nil } diff --git a/server/platform/services/sharedchannel/membership.go b/server/platform/services/sharedchannel/membership.go index fc8cb1dd373..0cbe19858b8 100644 --- a/server/platform/services/sharedchannel/membership.go +++ b/server/platform/services/sharedchannel/membership.go @@ -4,14 +4,10 @@ package sharedchannel import ( - "context" - "encoding/json" - "fmt" "time" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" - "github.com/mattermost/mattermost/server/v8/platform/services/remotecluster" ) // isChannelMemberSyncEnabled checks if the feature flag is enabled and remote cluster service is available @@ -21,226 +17,25 @@ func (scs *Service) isChannelMemberSyncEnabled() bool { return featureFlagEnabled && remoteClusterService != nil } -// queueMembershipSyncTask creates and queues a task to synchronize channel membership changes -func (scs *Service) queueMembershipSyncTask(channelID, userID, remoteID string, syncMsg *model.SyncMsg, retryMsg *model.SyncMsg) { - task := newSyncTask(channelID, userID, remoteID, syncMsg, retryMsg) +// NotifyMembershipChanged is called when users are added or removed from a shared channel. +// It triggers a sync for the channel so that membership changes are picked up from +// ChannelMemberHistory at sync time, following the same pattern as posts and reactions. +// originRemoteID identifies the remote that initiated the change, so it can be skipped +// during sync to prevent echo-back. +func (scs *Service) NotifyMembershipChanged(channelID string, originRemoteID string) { + if !scs.isChannelMemberSyncEnabled() { + return + } + task := newSyncTask(channelID, "", "", nil, nil) + task.originRemoteID = originRemoteID task.schedule = time.Now().Add(NotifyMinimumDelay) - scs.addTask(task) } -// HandleMembershipChange is called when users are added or removed from a shared channel. -// It creates a task to notify all remote clusters about the membership change. -func (scs *Service) HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) { - if !scs.isChannelMemberSyncEnabled() { - return - } - - // Create timestamp for consistent usage - changeTime := model.GetMillis() - - // Create membership change info - syncMsg := model.NewSyncMsg(channelID) - syncMsg.MembershipChanges = []*model.MembershipChangeMsg{ - { - ChannelId: channelID, - UserId: userID, - IsAdd: isAdd, - RemoteId: remoteID, // which remote initiated this change - ChangeTime: changeTime, - }, - } - - // Queue the membership change task - scs.queueMembershipSyncTask(channelID, userID, "", syncMsg, nil) -} - -// HandleMembershipBatchChange is called to process a batch of membership changes for a shared channel. -// It creates a task to notify all remote clusters about the batch membership changes. -func (scs *Service) HandleMembershipBatchChange(channelID string, userIDs []string, isAdd bool, remoteID string) { - if !scs.isChannelMemberSyncEnabled() { - return - } - - if len(userIDs) == 0 { - return - } - - // Create timestamp for consistent usage - changeTime := model.GetMillis() - - // Create sync message with membership changes - syncMsg := model.NewSyncMsg(channelID) - syncMsg.MembershipChanges = make([]*model.MembershipChangeMsg, 0, len(userIDs)) - - // Add each user to the batch - for _, userID := range userIDs { - syncMsg.MembershipChanges = append(syncMsg.MembershipChanges, &model.MembershipChangeMsg{ - ChannelId: channelID, - UserId: userID, - IsAdd: isAdd, - RemoteId: remoteID, - ChangeTime: changeTime, - }) - } - - // Queue the batch membership sync task - scs.queueMembershipSyncTask(channelID, "", "", syncMsg, nil) -} - -// SyncAllChannelMembers synchronizes channel members to a specific remote. -// This is typically called when a channel is first shared with a remote cluster. -// If remote is provided, it will be used instead of fetching from the database. -// -// When LastMembersSyncAt is non-zero, only members updated after that cursor are synced (i.e., the delta). -// When LastMembersSyncAt is zero (initial share), all current members are synced. -// -// Limitation: this function only detects membership additions and modifications, not removals. -// Channel member rows are hard-deleted from the database, so members removed while a remote was -// offline cannot be detected by this query. Removals rely on real-time HandleMembershipChange events. -func (scs *Service) SyncAllChannelMembers(channelID string, remoteID string, remote *model.SharedChannelRemote) error { - if !scs.isChannelMemberSyncEnabled() { - return nil - } - - // Verify the channel exists and is shared - if _, err := scs.server.GetStore().SharedChannel().Get(channelID); err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Failed to get shared channel", - mlog.String("channel_id", channelID), - mlog.Err(err), - ) - return fmt.Errorf("failed to get shared channel %s: %w", channelID, err) - } - - // Get the remote to ensure it exists (if not provided) - if remote == nil { - var err error - remote, err = scs.server.GetStore().SharedChannel().GetRemoteByIds(channelID, remoteID) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Failed to get remote", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - mlog.Err(err), - ) - return fmt.Errorf("failed to get remote for channel %s: %w", channelID, err) - } - } - - // Use offset-based pagination to handle channels with many members - // This ensures we don't skip members when multiple members have the same LastUpdateAt timestamp - maxPerPage := scs.GetMemberSyncBatchSize() - var allMembers model.ChannelMembers - lastSyncAt := remote.LastMembersSyncAt - offset := 0 - - // Process members incrementally with offset-based pagination - for { - opts := model.ChannelMembersGetOptions{ - ChannelID: channelID, - UpdatedAfter: lastSyncAt, - Limit: maxPerPage, - Offset: offset, - } - - members, err1 := scs.server.GetStore().Channel().GetMembers(opts) - if err1 != nil { - return fmt.Errorf("failed to get members for channel %s: %w", channelID, err1) - } - - if len(members) == 0 { - break // No more members to process - } - - // Add to our collection - allMembers = append(allMembers, members...) - - // Log progress when processing large channels - if len(allMembers)%1000 == 0 { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Processing channel members in batches", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - mlog.Int("processed_so_far", len(allMembers)), - ) - } - - if len(members) < maxPerPage { - break // Last page - } - - // Move to next page - offset += maxPerPage - } - - if len(allMembers) == 0 { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "No members to sync for channel", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - ) - return nil - } - - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Syncing all channel members", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - mlog.Int("member_count", len(allMembers)), - ) - - // Get batch size from config - batchSize := scs.GetMemberSyncBatchSize() - - // For small channels, queue individual membership changes - if len(allMembers) <= batchSize { - return scs.syncMembersIndividually(channelID, remoteID, allMembers, remote) - } - - // For larger channels, use batch processing - return scs.syncMembersInBatches(channelID, remoteID, allMembers, remote) -} - -// syncMembersIndividually processes each member individually -// This is more efficient for small channels -func (scs *Service) syncMembersIndividually(channelID, remoteID string, members model.ChannelMembers, remote *model.SharedChannelRemote) error { - // Queue individual membership changes for each member - for _, member := range members { - // Queue membership change for this user (isAdd=true) - scs.HandleMembershipChange(channelID, member.UserId, true, "") - } - - return nil -} - -// syncMembersInBatches processes members in batches for greater efficiency -// This is better for channels with many members -func (scs *Service) syncMembersInBatches(channelID, remoteID string, members model.ChannelMembers, remote *model.SharedChannelRemote) error { - // Get batch size from config - batchSize := scs.GetMemberSyncBatchSize() - - for i := 0; i < len(members); i += batchSize { - end := min(i+batchSize, len(members)) - - // Create a batch of members - batchMembers := members[i:end] - - // Extract user IDs from the batch - userIDs := make([]string, len(batchMembers)) - for j, member := range batchMembers { - userIDs[j] = member.UserId - } - - // Use the batch handling function to queue the changes - scs.HandleMembershipBatchChange(channelID, userIDs, true, "") - } - - return nil -} - -// ForceMembershipSyncForRemote syncs channel membership for all channels shared with the -// specified remote. Called when a remote comes back online to catch up on any membership -// changes that occurred while it was offline. -// -// Note: SyncAllChannelMembers uses the LastMembersSyncAt cursor on each SharedChannelRemote -// record, so only the membership delta since the last successful sync is sent, not the full -// member list. See SyncAllChannelMembers for known limitations regarding removed members. +// ForceMembershipSyncForRemote triggers a sync for all channels shared with the specified remote. +// Called when a remote comes back online to catch up on any membership changes that occurred +// while it was offline. The sync pipeline will read ChannelMemberHistory using the +// LastMembersSyncAt cursor to detect both additions and removals. func (scs *Service) ForceMembershipSyncForRemote(rc *model.RemoteCluster) { if !scs.isChannelMemberSyncEnabled() { return @@ -260,191 +55,6 @@ func (scs *Service) ForceMembershipSyncForRemote(rc *model.RemoteCluster) { } for _, scr := range scrs { - if syncErr := scs.SyncAllChannelMembers(scr.ChannelId, rc.RemoteId, scr); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync channel members for reconnected remote", - mlog.String("channel_id", scr.ChannelId), - mlog.String("remote", rc.DisplayName), - mlog.String("remoteId", rc.RemoteId), - mlog.Err(syncErr), - ) - } - } -} - -// processMembershipChange processes a channel membership change task. -// It determines which remotes should receive the update and creates tasks for each. -func (scs *Service) processMembershipChange(syncMsg *model.SyncMsg) { - if len(syncMsg.MembershipChanges) == 0 { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Invalid membership change task - no membership changes", - mlog.String("channel_id", syncMsg.ChannelId), - ) - return - } - - // Get the shared channel (to verify it exists) - _, err := scs.server.GetStore().SharedChannel().Get(syncMsg.ChannelId) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to get shared channel for membership change", - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Int("change_count", len(syncMsg.MembershipChanges)), - mlog.Err(err), - ) - return - } - - // Get all remotes for this channel - remotes, err := scs.server.GetStore().SharedChannel().GetRemotes(0, 999999, model.SharedChannelRemoteFilterOpts{ - ChannelId: syncMsg.ChannelId, - }) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to get shared channel remotes for membership change", - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - return - } - - // Always use batch processing for consistency (works for single or multiple changes) - scs.syncMembershipBatchToRemotes(syncMsg, remotes) -} - -// syncMembershipBatchToRemotes synchronizes membership changes (single or batch) with remote clusters. -func (scs *Service) syncMembershipBatchToRemotes(syncMsg *model.SyncMsg, remotes []*model.SharedChannelRemote) { - if len(syncMsg.MembershipChanges) == 0 { - return - } - - // Get the initiating remote ID from the first change (all should be the same) - initiatingRemoteId := "" - if len(syncMsg.MembershipChanges) > 0 { - initiatingRemoteId = syncMsg.MembershipChanges[0].RemoteId - } - - // Send to all remotes except the one that initiated this change - for _, remote := range remotes { - // Skip the remote that initiated this change to prevent loops - if remote.RemoteId == initiatingRemoteId { - continue - } - - // Get the remote cluster - rc, err := scs.server.GetStore().RemoteCluster().Get(remote.RemoteId, false) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to get remote cluster for batch membership sync", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - continue - } - - // Create a copy of the sync message to potentially add user profiles - enrichedSyncMsg := &model.SyncMsg{ - Id: syncMsg.Id, - ChannelId: syncMsg.ChannelId, - MembershipChanges: syncMsg.MembershipChanges, - Users: make(map[string]*model.User), - } - - // Add user profiles for all users being added - for _, change := range syncMsg.MembershipChanges { - if change.IsAdd { - user, pErr := scs.server.GetStore().User().Get(context.Background(), change.UserId) - if pErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Failed to get user for batch membership sync", - mlog.String("user_id", change.UserId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.String("remote_id", remote.RemoteId), - mlog.Err(pErr), - ) - continue - } - - // Check if user profile needs to be synced - doSync, _, sErr := scs.shouldUserSync(user, syncMsg.ChannelId, rc) - if sErr == nil && doSync { - enrichedSyncMsg.Users[user.Id] = sanitizeUserForSyncSafe(user) - } - } - } - - // Send message using the existing remote cluster framework - payload, err := json.Marshal(enrichedSyncMsg) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to marshal batch membership message", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - continue - } - - msg := model.RemoteClusterMsg{ - Id: model.NewId(), - Topic: TopicChannelMembership, - CreateAt: model.GetMillis(), - Payload: payload, - } - - ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) - defer cancel() - - // Define a callback function - callback := func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Error sending batch membership changes to remote", - mlog.String("remote", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Int("change_count", len(syncMsg.MembershipChanges)), - mlog.Err(err), - ) - return - } - - if resp != nil && resp.Err != "" { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Remote error when processing batch membership changes", - mlog.String("remote", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.String("remote_error", resp.Err), - ) - return - } - - // Update sync timestamps - for _, change := range syncMsg.MembershipChanges { - if err := scs.server.GetStore().SharedChannel().UpdateUserLastMembershipSyncAt(change.UserId, change.ChannelId, remote.RemoteId, change.ChangeTime); err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update user membership sync timestamp in batch", - mlog.String("user_id", change.UserId), - mlog.Err(err), - ) - } - } - - // Update the cursor with the latest change time - var maxChangeTime int64 - for _, change := range syncMsg.MembershipChanges { - if change.ChangeTime > maxChangeTime { - maxChangeTime = change.ChangeTime - } - } - - if err := scs.updateMembershipSyncCursor(syncMsg.ChannelId, remote.RemoteId, maxChangeTime); err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update membership sync cursor for batch", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - } - } - - err = scs.server.GetRemoteClusterService().SendMsg(ctx, msg, rc, callback) - - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to send batch membership changes to remote", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - } + scs.NotifyChannelChanged(scr.ChannelId) } } diff --git a/server/platform/services/sharedchannel/membership_recv.go b/server/platform/services/sharedchannel/membership_recv.go index 60d50605360..4ea92537ace 100644 --- a/server/platform/services/sharedchannel/membership_recv.go +++ b/server/platform/services/sharedchannel/membership_recv.go @@ -5,7 +5,6 @@ package sharedchannel import ( "fmt" - "strings" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" @@ -13,36 +12,21 @@ import ( "github.com/mattermost/mattermost/server/v8/platform/services/remotecluster" ) -// checkMembershipConflict checks if there are newer changes that would conflict with this one -// Returns true if this change should be skipped due to a conflict -func (scs *Service) checkMembershipConflict(userID, channelID string, changeTime int64) (bool, error) { - conflicts, err := scs.server.GetStore().SharedChannel().GetUserChanges(userID, channelID, changeTime) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to check for membership change conflicts", - mlog.String("user_id", userID), - mlog.String("channel_id", channelID), - mlog.Err(err), - ) - return false, err - } +const ( + // Error IDs returned by the app layer that indicate an idempotent no-op. + errIDAddUserToChannelFailed = "api.channel.add_user.to.channel.failed.app_error" + errIDSaveMemberExists = "app.channel.save_member.exists.app_error" + errIDGetChannelMemberMissing = "app.channel.get_member.missing.app_error" +) - // If there are conflicting operations, the latest one wins - for _, conflict := range conflicts { - if conflict.LastMembershipSyncAt > changeTime { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Ignoring older membership change due to conflict", - mlog.String("user_id", userID), - mlog.String("channel_id", channelID), - mlog.Int("change_time", int(changeTime)), - mlog.Int("conflicting_time", int(conflict.LastMembershipSyncAt)), - ) - return true, nil - } - } - - return false, nil -} - -// onReceiveMembershipChanges processes channel membership changes from a remote cluster +// onReceiveMembershipChanges processes channel membership changes from a remote cluster. +// In the new model, the sender derives the authoritative net state from ChannelMemberHistory. +// Both processMemberAdd and processMemberRemove are idempotent: +// - processMemberAdd ignores "already added" errors +// - processMemberRemove ignores "not found" errors +// +// Out-of-order messages resolve naturally: if an old "add" arrives after a newer "remove", +// the sender's next sync cycle will send a corrective "remove" because the history shows the user left. func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { // Check if feature flag is enabled if !scs.server.Config().FeatureFlags.EnableSharedChannelsMemberSync { @@ -65,16 +49,8 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model return fmt.Errorf("cannot get shared channel for membership changes: %w", err) } - // Calculate the maximum ChangeTime from all changes in the batch - var maxChangeTime int64 - for _, change := range syncMsg.MembershipChanges { - if change.ChangeTime > maxChangeTime { - maxChangeTime = change.ChangeTime - } - } - // Process each change - var successCount, skipCount, failCount int + var failCount int for _, change := range syncMsg.MembershipChanges { if change.ChannelId != syncMsg.ChannelId { @@ -87,13 +63,6 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model continue } - // Check for conflicts - shouldSkip, _ := scs.checkMembershipConflict(change.UserId, change.ChannelId, change.ChangeTime) - if shouldSkip { - skipCount++ - continue - } - // Process the membership change based on whether it's an add or remove var processErr error if change.IsAdd { @@ -102,14 +71,14 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model mlog.String("channel_id", change.ChannelId), mlog.String("remote_id", rc.RemoteId), ) - processErr = scs.processMemberAdd(change, channel, rc, maxChangeTime, syncMsg) + processErr = scs.processMemberAdd(change, channel, rc, syncMsg) } else { scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Removing user from channel from remote cluster", mlog.String("user_id", change.UserId), mlog.String("channel_id", change.ChannelId), mlog.String("remote_id", rc.RemoteId), ) - processErr = scs.processMemberRemove(change, rc, maxChangeTime) + processErr = scs.processMemberRemove(change, rc) } if processErr != nil { @@ -123,15 +92,21 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model failCount++ continue } + } - successCount++ + if failCount > 0 { + scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Some membership changes failed", + mlog.String("channel_id", syncMsg.ChannelId), + mlog.Int("total", len(syncMsg.MembershipChanges)), + mlog.Int("failed", failCount), + ) } return nil } -// processMemberAdd handles adding a user to a channel as part of batch processing -func (scs *Service) processMemberAdd(change *model.MembershipChangeMsg, channel *model.Channel, rc *model.RemoteCluster, maxChangeTime int64, syncMsg *model.SyncMsg) error { +// processMemberAdd handles adding a user to a channel +func (scs *Service) processMemberAdd(change *model.MembershipChangeMsg, channel *model.Channel, rc *model.RemoteCluster, syncMsg *model.SyncMsg) error { rctx := request.EmptyContext(scs.server.Log()) var user *model.User var err error @@ -166,30 +141,18 @@ func (scs *Service) processMemberAdd(change *model.MembershipChangeMsg, channel // Skip team member check (true) since we already handled team membership above _, appErr := scs.app.AddUserToChannel(rctx, user, channel, true) if appErr != nil { - // Skip "already added" errors - if appErr.Error() != "api.channel.add_user.to_channel.failed.app_error" && - !strings.Contains(appErr.Error(), "channel_member_exists") { + // Skip "already added" errors — idempotent + if appErr.Id != errIDAddUserToChannelFailed && + appErr.Id != errIDSaveMemberExists { return fmt.Errorf("cannot add user to channel: %w", appErr) } - // User is already in the channel, which is fine - } - - // Update the sync status - if syncErr := scs.server.GetStore().SharedChannel().UpdateUserLastMembershipSyncAt(change.UserId, change.ChannelId, rc.RemoteId, maxChangeTime); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update user LastMembershipSyncAt after batch member add", - mlog.String("user_id", change.UserId), - mlog.String("channel_id", change.ChannelId), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - // Continue despite the error - this is not critical } return nil } -// processMemberRemove handles removing a user from a channel as part of batch processing -func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *model.RemoteCluster, maxChangeTime int64) error { +// processMemberRemove handles removing a user from a channel +func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *model.RemoteCluster) error { // Get channel so we can use app layer methods properly channel, err := scs.server.GetStore().Channel().Get(change.ChannelId, true) if err != nil { @@ -198,7 +161,7 @@ func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *m mlog.String("user_id", change.UserId), mlog.Err(err), ) - // Continue anyway to update sync status - the channel might be deleted + return nil // channel might be deleted, nothing to do } rctx := request.EmptyContext(scs.server.Log()) @@ -210,32 +173,14 @@ func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *m return fmt.Errorf("membership remove sync failed: %w", ErrRemoteIDMismatch) } - // Use the app layer's remove user method if channel still exists - if channel != nil { - appErr := scs.app.RemoveUserFromChannel(rctx, change.UserId, "", channel) - if appErr != nil { - // Ignore "not found" errors - the user might already be removed - if !strings.Contains(appErr.Error(), "store.sql_channel.remove_member.missing.app_error") { - scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Error removing user from channel", - mlog.String("channel_id", change.ChannelId), - mlog.String("user_id", change.UserId), - mlog.Err(appErr), - ) - // Continue anyway to update sync status - don't return error here - // to ensure sync status still gets updated - } + // Use the app layer's remove user method + appErr := scs.app.RemoveUserFromChannel(rctx, change.UserId, "", channel) + if appErr != nil { + // Ignore "not found" errors - the user might already be removed + if appErr.Id == errIDGetChannelMemberMissing { + return nil } - } - - // Update the sync status - if syncErr := scs.server.GetStore().SharedChannel().UpdateUserLastMembershipSyncAt(change.UserId, change.ChannelId, rc.RemoteId, maxChangeTime); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update user LastMembershipSyncAt after batch member remove", - mlog.String("user_id", change.UserId), - mlog.String("channel_id", change.ChannelId), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - // Continue despite the error - this is not critical + return fmt.Errorf("cannot remove user from channel: %w", appErr) } return nil diff --git a/server/platform/services/sharedchannel/membership_recv_test.go b/server/platform/services/sharedchannel/membership_recv_test.go index a0478ac6160..cdfedeb2e2b 100644 --- a/server/platform/services/sharedchannel/membership_recv_test.go +++ b/server/platform/services/sharedchannel/membership_recv_test.go @@ -4,6 +4,7 @@ package sharedchannel import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -73,9 +74,6 @@ func TestOnReceiveMembershipChanges_ChannelIdMismatch(t *testing.T) { err := scs.onReceiveMembershipChanges(syncMsg, rc, nil) require.NoError(t, err) // function returns nil even on individual failures - - // The conflict check should never have been called since the mismatch was caught first - mockSharedChannelStore.AssertNotCalled(t, "GetUserChanges", mock.Anything, mock.Anything, mock.Anything) } func TestProcessMemberAdd_RejectsLocalUser(t *testing.T) { @@ -94,9 +92,6 @@ func TestProcessMemberAdd_RejectsLocalUser(t *testing.T) { localUser := &model.User{Id: localUserId} mockUserStore.On("Get", mockTypeContext, localUserId).Return(localUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", localUserId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -132,9 +127,6 @@ func TestProcessMemberAdd_RejectsOtherRemoteUser(t *testing.T) { otherRemoteUser := &model.User{Id: userId, RemoteId: &otherRemoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(otherRemoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -169,12 +161,8 @@ func TestProcessMemberAdd_AllowsOwnRemoteUser(t *testing.T) { remoteUser := &model.User{Id: userId, RemoteId: &remoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - // Expect the add to proceed mockApp.On("AddUserToChannel", mockTypeReqContext, mockTypeUser, mockTypeChannel, true).Return(&model.ChannelMember{}, nil) - mockSharedChannelStore.On("UpdateUserLastMembershipSyncAt", userId, channelId, remoteId, int64(1000)).Return(nil) syncMsg := &model.SyncMsg{ ChannelId: channelId, @@ -210,9 +198,6 @@ func TestProcessMemberRemove_RejectsLocalUser(t *testing.T) { localUser := &model.User{Id: localUserId} mockUserStore.On("Get", mockTypeContext, localUserId).Return(localUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", localUserId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -248,9 +233,6 @@ func TestProcessMemberRemove_RejectsOtherRemoteUser(t *testing.T) { otherRemoteUser := &model.User{Id: userId, RemoteId: &otherRemoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(otherRemoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -285,12 +267,8 @@ func TestProcessMemberRemove_AllowsOwnRemoteUser(t *testing.T) { remoteUser := &model.User{Id: userId, RemoteId: &remoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - // Expect the remove to proceed mockApp.On("RemoveUserFromChannel", mockTypeReqContext, userId, "", channel).Return(nil) - mockSharedChannelStore.On("UpdateUserLastMembershipSyncAt", userId, channelId, remoteId, int64(1000)).Return(nil) syncMsg := &model.SyncMsg{ ChannelId: channelId, @@ -333,7 +311,7 @@ func TestProcessMemberAdd_RejectsLocalUser_ErrorMessage(t *testing.T) { ChannelId: channelId, } - err := scs.processMemberAdd(change, channel, rc, 1000, syncMsg) + err := scs.processMemberAdd(change, channel, rc, syncMsg) require.Error(t, err) assert.ErrorIs(t, err, ErrRemoteIDMismatch) assert.Contains(t, err.Error(), "membership add sync failed") @@ -360,8 +338,67 @@ func TestProcessMemberRemove_RejectsLocalUser_ErrorMessage(t *testing.T) { ChangeTime: 1000, } - err := scs.processMemberRemove(change, rc, 1000) + err := scs.processMemberRemove(change, rc) require.Error(t, err) assert.ErrorIs(t, err, ErrRemoteIDMismatch) assert.Contains(t, err.Error(), "membership remove sync failed") } + +func TestProcessMemberAdd_IdempotentForExistingMember(t *testing.T) { + scs, _, mockApp, _, _, _, mockUserStore := setupMembershipTest(t) + + channelId := model.NewId() + remoteId := model.NewId() + userId := model.NewId() + rc := &model.RemoteCluster{RemoteId: remoteId} + channel := &model.Channel{Id: channelId, Type: model.ChannelTypeOpen} + + // User belongs to the sending remote + remoteUser := &model.User{Id: userId, RemoteId: &remoteId} + mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) + + // Simulate "already added" error from AddUserToChannel + alreadyAddedErr := model.NewAppError("AddUserToChannel", errIDSaveMemberExists, nil, "", http.StatusBadRequest) + mockApp.On("AddUserToChannel", mockTypeReqContext, mockTypeUser, mockTypeChannel, true).Return(nil, alreadyAddedErr) + + syncMsg := &model.SyncMsg{ChannelId: channelId} + change := &model.MembershipChangeMsg{ + ChannelId: channelId, + UserId: userId, + IsAdd: true, + ChangeTime: 1000, + } + + err := scs.processMemberAdd(change, channel, rc, syncMsg) + require.NoError(t, err, "should succeed for already-added user (idempotent)") +} + +func TestProcessMemberRemove_IdempotentForNonMember(t *testing.T) { + scs, _, mockApp, _, _, mockChannelStore, mockUserStore := setupMembershipTest(t) + + channelId := model.NewId() + remoteId := model.NewId() + userId := model.NewId() + rc := &model.RemoteCluster{RemoteId: remoteId} + channel := &model.Channel{Id: channelId, Type: model.ChannelTypeOpen} + + mockChannelStore.On("Get", channelId, true).Return(channel, nil) + + // User belongs to the sending remote + remoteUser := &model.User{Id: userId, RemoteId: &remoteId} + mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) + + // Simulate "not found" error from RemoveUserFromChannel + notFoundErr := model.NewAppError("RemoveUserFromChannel", errIDGetChannelMemberMissing, nil, "", http.StatusNotFound) + mockApp.On("RemoveUserFromChannel", mockTypeReqContext, userId, "", channel).Return(notFoundErr) + + change := &model.MembershipChangeMsg{ + ChannelId: channelId, + UserId: userId, + IsAdd: false, + ChangeTime: 1000, + } + + err := scs.processMemberRemove(change, rc) + require.NoError(t, err, "should succeed for already-removed user (idempotent)") +} diff --git a/server/platform/services/sharedchannel/membership_send_test.go b/server/platform/services/sharedchannel/membership_send_test.go new file mode 100644 index 00000000000..9e683787c30 --- /dev/null +++ b/server/platform/services/sharedchannel/membership_send_test.go @@ -0,0 +1,362 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest/mock" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks" +) + +func setupSendTest(t *testing.T, enableMemberSync bool) (*Service, *mocks.Store, *mocks.ChannelMemberHistoryStore, *mocks.SharedChannelStore, *mocks.UserStore, *mocks.RemoteClusterStore) { + t.Helper() + + mockServer := &MockServerIface{} + logger := mlog.CreateConsoleTestLogger(t) + mockServer.On("Log").Return(logger) + mockServer.On("GetMetrics").Return(nil) + mockServer.On("GetRemoteClusterService").Return(nil) + + mockApp := &MockAppIface{} + scs := &Service{ + server: mockServer, + app: mockApp, + } + + mockStore := &mocks.Store{} + mockCMHStore := &mocks.ChannelMemberHistoryStore{} + mockSharedChannelStore := &mocks.SharedChannelStore{} + mockUserStore := &mocks.UserStore{} + mockRCStore := &mocks.RemoteClusterStore{} + + mockStore.On("ChannelMemberHistory").Return(mockCMHStore) + mockStore.On("SharedChannel").Return(mockSharedChannelStore) + mockStore.On("User").Return(mockUserStore) + mockStore.On("RemoteCluster").Return(mockRCStore) + mockServer.On("GetStore").Return(mockStore) + + mockConfig := model.Config{} + mockConfig.SetDefaults() + mockConfig.FeatureFlags.EnableSharedChannelsMemberSync = enableMemberSync + mockServer.On("Config").Return(&mockConfig) + + return scs, mockStore, mockCMHStore, mockSharedChannelStore, mockUserStore, mockRCStore +} + +func TestFetchMembershipsForSync_FeatureFlagDisabled(t *testing.T) { + scs, _, _, _, _, _ := setupSendTest(t, false) + + sd := &syncData{ + task: syncTask{channelID: model.NewId()}, + rc: &model.RemoteCluster{RemoteId: model.NewId()}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.Empty(t, sd.membershipChanges, "should not fetch when feature flag disabled") +} + +func TestFetchMembershipsForSync_NoChanges(t *testing.T) { + scs, _, mockCMHStore, _, _, _ := setupSendTest(t, true) + + channelID := model.NewId() + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{}, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: model.NewId()}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.Empty(t, sd.membershipChanges) + assert.Equal(t, int64(0), sd.resultNextMembershipCursor) +} + +func TestFetchMembershipsForSync_DeduplicatesJoinLeaveRejoin(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + userID := model.NewId() + + // User joined at 1000, left at 2000, rejoined at 3000 + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{ + {ChannelId: channelID, UserId: userID, JoinTime: 1000, LeaveTime: model.NewPointer(int64(2000))}, + {ChannelId: channelID, UserId: userID, JoinTime: 3000, LeaveTime: nil}, + }, nil) + + // User profile for the add — local user (no RemoteId) so shouldUserSync includes it + user := &model.User{Id: userID} + mockUserStore.On("Get", mock.Anything, userID).Return(user, nil) + + // shouldUserSync needs SharedChannelUser lookup + mockSCStore.On("GetSingleUser", userID, channelID, remoteID). + Return(nil, ¬FoundError{}) + mockSCStore.On("SaveUser", mock.AnythingOfType("*model.SharedChannelUser")). + Return(&model.SharedChannelUser{}, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + + // Should deduplicate to 1 entry: user is currently a member (rejoin at 3000 > leave at 2000) + require.Len(t, sd.membershipChanges, 1) + assert.Equal(t, userID, sd.membershipChanges[0].UserId) + assert.True(t, sd.membershipChanges[0].IsAdd, "user rejoined so should be IsAdd=true") + assert.Equal(t, int64(3000), sd.membershipChanges[0].ChangeTime) + assert.Equal(t, int64(3000), sd.resultNextMembershipCursor) +} + +func TestFetchMembershipsForSync_DeduplicatesJoinThenLeave(t *testing.T) { + scs, _, mockCMHStore, _, _, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + userID := model.NewId() + + // User joined at 1000, then left at 2000 + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{ + {ChannelId: channelID, UserId: userID, JoinTime: 1000, LeaveTime: model.NewPointer(int64(2000))}, + }, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + + require.Len(t, sd.membershipChanges, 1) + assert.False(t, sd.membershipChanges[0].IsAdd, "user left so should be IsAdd=false") + assert.Equal(t, int64(2000), sd.membershipChanges[0].ChangeTime) + assert.Equal(t, int64(2000), sd.resultNextMembershipCursor) +} + +func TestFetchMembershipsForSync_MultipleUsers(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + user1 := model.NewId() + user2 := model.NewId() + user3 := model.NewId() + + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{ + {ChannelId: channelID, UserId: user1, JoinTime: 1000, LeaveTime: nil}, // joined, still member + {ChannelId: channelID, UserId: user2, JoinTime: 2000, LeaveTime: model.NewPointer(int64(3000))}, // joined then left + {ChannelId: channelID, UserId: user3, JoinTime: 4000, LeaveTime: nil}, // joined, still member + }, nil) + + // User profiles for adds (user1 and user3) — users are local (no RemoteId) + // so shouldUserSync won't filter them out for the target remote + for _, uid := range []string{user1, user3} { + u := &model.User{Id: uid} + mockUserStore.On("Get", mock.Anything, uid).Return(u, nil) + mockSCStore.On("GetSingleUser", uid, channelID, remoteID).Return(nil, ¬FoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == uid })). + Return(&model.SharedChannelUser{}, nil) + } + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + + assert.Len(t, sd.membershipChanges, 3) + + // Build a map for easy lookup + changeMap := make(map[string]*model.MembershipChangeMsg) + for _, mc := range sd.membershipChanges { + changeMap[mc.UserId] = mc + } + + assert.True(t, changeMap[user1].IsAdd, "user1 is still a member") + assert.False(t, changeMap[user2].IsAdd, "user2 left") + assert.True(t, changeMap[user3].IsAdd, "user3 is still a member") + + // Max cursor should be 4000 (user3 join) + assert.Equal(t, int64(4000), sd.resultNextMembershipCursor) + + // Only user1 and user3 should be in the users map (adds only) + assert.Contains(t, sd.users, user1) + assert.NotContains(t, sd.users, user2, "user2 left, should not be in users map") + assert.Contains(t, sd.users, user3) +} + +func TestFetchMembershipsForSync_SetsRepeatWhenLimitHit(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + + // Return exactly as many results as the limit (batch size) + batchSize := scs.GetMemberSyncBatchSize() + histories := make([]*model.ChannelMemberHistory, batchSize) + for i := range batchSize { + uid := model.NewId() + histories[i] = &model.ChannelMemberHistory{ + ChannelId: channelID, + UserId: uid, + JoinTime: int64(1000 + i), + LeaveTime: nil, + } + u := &model.User{Id: uid} + mockUserStore.On("Get", mock.Anything, uid).Return(u, nil) + mockSCStore.On("GetSingleUser", uid, channelID, remoteID).Return(nil, ¬FoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == uid })). + Return(&model.SharedChannelUser{}, nil) + } + + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), batchSize). + Return(histories, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.True(t, sd.resultRepeat, "should set resultRepeat when limit is hit") +} + +func TestFetchMembershipsForSync_CursorFromSCR(t *testing.T) { + scs, _, mockCMHStore, _, _, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + + // Verify that LastMembersSyncAt from SCR is used as the cursor + mockCMHStore.On("GetMembershipChanges", channelID, int64(5000), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{}, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 5000}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + mockCMHStore.AssertCalled(t, "GetMembershipChanges", channelID, int64(5000), mock.AnythingOfType("int")) +} + +func TestFetchMembershipsForSync_RepeatedTimestampsAtBoundary(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + batchSize := scs.GetMemberSyncBatchSize() + + // Create batch where the last entries share the same JoinTime (boundary timestamp) + boundaryTime := int64(5000) + histories := make([]*model.ChannelMemberHistory, batchSize) + for i := range batchSize { + uid := model.NewId() + joinTime := int64(1000 + i) + // Last 3 entries share the same timestamp + if i >= batchSize-3 { + joinTime = boundaryTime + } + histories[i] = &model.ChannelMemberHistory{ + ChannelId: channelID, + UserId: uid, + JoinTime: joinTime, + LeaveTime: nil, + } + u := &model.User{Id: uid} + mockUserStore.On("Get", mock.Anything, uid).Return(u, nil) + mockSCStore.On("GetSingleUser", uid, channelID, remoteID).Return(nil, ¬FoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == uid })). + Return(&model.SharedChannelUser{}, nil) + } + + // First fetch returns full batch (hits limit) + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), batchSize). + Return(histories, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.True(t, sd.resultRepeat, "should signal more data when limit hit") + assert.Len(t, sd.membershipChanges, batchSize, "all entries in batch should be processed") + assert.Equal(t, boundaryTime, sd.resultNextMembershipCursor, "cursor should be at boundary timestamp") + + // Second fetch uses the boundary timestamp as cursor (GtOrEq means events + // AT boundaryTime will be re-fetched, but deduplication handles this) + extraUser := model.NewId() + mockCMHStore.On("GetMembershipChanges", channelID, boundaryTime, batchSize). + Return([]*model.ChannelMemberHistory{ + // Re-fetched rows at boundary (already processed — dedup will handle) + histories[batchSize-3], + histories[batchSize-2], + histories[batchSize-1], + // New row beyond the boundary + {ChannelId: channelID, UserId: extraUser, JoinTime: boundaryTime + 1, LeaveTime: nil}, + }, nil) + + extraUserObj := &model.User{Id: extraUser} + mockUserStore.On("Get", mock.Anything, extraUser).Return(extraUserObj, nil) + mockSCStore.On("GetSingleUser", extraUser, channelID, remoteID).Return(nil, ¬FoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == extraUser })). + Return(&model.SharedChannelUser{}, nil) + + sd2 := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: boundaryTime}, + users: make(map[string]*model.User), + } + + err = scs.fetchMembershipsForSync(sd2) + require.NoError(t, err) + // Should have 4 entries (3 re-fetched + 1 new) — dedup at the caller level + // handles duplicate user IDs across batches + assert.Len(t, sd2.membershipChanges, 4, "should include re-fetched boundary rows and new row") + assert.Equal(t, boundaryTime+1, sd2.resultNextMembershipCursor, "cursor should advance past boundary") +} + +// notFoundError implements the errNotFound interface used by the sharedchannel package +type notFoundError struct{} + +func (e *notFoundError) Error() string { return "not found" } +func (e *notFoundError) IsErrNotFound() bool { return true } diff --git a/server/platform/services/sharedchannel/sync_recv.go b/server/platform/services/sharedchannel/sync_recv.go index a182eb34e2d..2ab2b90dadf 100644 --- a/server/platform/services/sharedchannel/sync_recv.go +++ b/server/platform/services/sharedchannel/sync_recv.go @@ -87,6 +87,7 @@ func (scs *Service) processSyncMessage(rctx request.CTX, syncMsg *model.SyncMsg, PostErrors: make([]string, 0), ReactionErrors: make([]string, 0), AcknowledgementErrors: make([]string, 0), + MembershipErrors: make([]string, 0), } // Check if feature flag is enabled for membership changes @@ -281,6 +282,7 @@ func (scs *Service) processSyncMessage(rctx request.CTX, syncMsg *model.SyncMsg, mlog.Int("change_count", len(syncMsg.MembershipChanges)), mlog.Err(err), ) + syncResp.MembershipErrors = append(syncResp.MembershipErrors, err.Error()) // Don't fail the entire sync if membership changes fail } } diff --git a/server/platform/services/sharedchannel/sync_send.go b/server/platform/services/sharedchannel/sync_send.go index 6b7496a7f60..602c40e23e2 100644 --- a/server/platform/services/sharedchannel/sync_send.go +++ b/server/platform/services/sharedchannel/sync_send.go @@ -29,6 +29,9 @@ type syncTask struct { retryCount int retryMsg *model.SyncMsg schedule time.Time + // originRemoteID is the remote that initiated this change; it will be + // skipped when syncing to prevent echo-back. + originRemoteID string } func newSyncTask(channelID, userID string, remoteID string, existingMsg, retryMsg *model.SyncMsg) syncTask { @@ -250,6 +253,17 @@ func (scs *Service) addTask(task syncTask) { // if the task was already scheduled, we only update the // existingMsg in case there is new information originalTask.existingMsg = task.existingMsg + + // originRemoteID identifies which remote initiated a change so processTask + // can skip sending back to that remote. When multiple events merge within + // the NotifyMinimumDelay window we can only safely skip a remote if every + // merged event came from that same remote. If the origins differ (e.g. + // remote-A join + remote-B join, or remote join + local join) we must clear + // originRemoteID so the sync fans out to all remotes. The receiver is + // idempotent, so the worst case is a redundant sync to the originating remote. + if task.originRemoteID != originalTask.originRemoteID { + originalTask.originRemoteID = "" + } scs.tasks[task.id] = originalTask } else { scs.tasks[task.id] = task @@ -375,16 +389,6 @@ func (scs *Service) removeOldestTask() (syncTask, bool, time.Duration) { // processTask updates one or more remote clusters with any new channel content. func (scs *Service) processTask(task syncTask) error { - // Check if this is a membership change task - if task.existingMsg != nil && len(task.existingMsg.MembershipChanges) > 0 { - // Check if feature flag is enabled - if !scs.server.Config().FeatureFlags.EnableSharedChannelsMemberSync { - return nil - } - scs.processMembershipChange(task.existingMsg) - return nil - } - // map is used to ensure remotes don't get sync'd twice, such as when // they have the autoinvited flag and have explicitly subscribed to a channel. remotesMap := make(map[string]*model.RemoteCluster) @@ -399,6 +403,10 @@ func (scs *Service) processTask(task syncTask) error { return err } for _, r := range remotes { + // Skip the remote that originated this membership change + if task.originRemoteID != "" && r.RemoteId == task.originRemoteID { + continue + } remotesMap[r.RemoteId] = r } @@ -412,6 +420,10 @@ func (scs *Service) processTask(task syncTask) error { return err } for _, r := range remotesAutoInvited { + // Skip the remote that originated this membership change + if task.originRemoteID != "" && r.RemoteId == task.originRemoteID { + continue + } remotesMap[r.RemoteId] = r } } else { diff --git a/server/platform/services/sharedchannel/sync_send_remote.go b/server/platform/services/sharedchannel/sync_send_remote.go index f28479e121c..9cf04920c2d 100644 --- a/server/platform/services/sharedchannel/sync_send_remote.go +++ b/server/platform/services/sharedchannel/sync_send_remote.go @@ -41,6 +41,9 @@ type syncData struct { attachments []attachment mentionTransforms map[string]string + membershipChanges []*model.MembershipChangeMsg + resultNextMembershipCursor int64 + resultRepeat bool resultNextCursor model.GetPostsSinceForSyncCursor GlobalUserSyncLastTimestamp int64 @@ -62,7 +65,7 @@ func newSyncData(task syncTask, rc *model.RemoteCluster, scr *model.SharedChanne } func (sd *syncData) isEmpty() bool { - return len(sd.users) == 0 && len(sd.profileImages) == 0 && len(sd.posts) == 0 && len(sd.reactions) == 0 && len(sd.acknowledgements) == 0 && len(sd.attachments) == 0 + return len(sd.users) == 0 && len(sd.profileImages) == 0 && len(sd.posts) == 0 && len(sd.reactions) == 0 && len(sd.acknowledgements) == 0 && len(sd.attachments) == 0 && len(sd.membershipChanges) == 0 } func (sd *syncData) isCursorChanged() bool { @@ -171,6 +174,11 @@ func (scs *Service) syncForRemote(task syncTask, rc *model.RemoteCluster) error return fmt.Errorf("cannot fetch posts for sync %v: %w", sd, err) } + // fetch membership changes from ChannelMemberHistory + if err := scs.fetchMembershipsForSync(sd); err != nil { + return fmt.Errorf("cannot fetch memberships for sync %v: %w", sd, err) + } + if !rc.IsOnline() { if len(sd.posts) != 0 { scs.notifyRemoteOffline(sd.posts, rc) @@ -342,6 +350,158 @@ func (scs *Service) fetchPostsForSync(sd *syncData) error { return nil } +// fetchMembershipsForSync populates the sync data with membership changes from ChannelMemberHistory. +func (scs *Service) fetchMembershipsForSync(sd *syncData) error { + if !scs.server.Config().FeatureFlags.EnableSharedChannelsMemberSync { + return nil + } + + start := time.Now() + defer func() { + if metrics := scs.server.GetMetrics(); metrics != nil { + metrics.ObserveSharedChannelsSyncCollectionStepDuration(sd.rc.RemoteId, "Memberships", time.Since(start).Seconds()) + } + }() + + cursor := sd.scr.LastMembersSyncAt + limit := scs.GetMemberSyncBatchSize() + + histories, err := scs.server.GetStore().ChannelMemberHistory().GetMembershipChanges(sd.task.channelID, cursor, limit) + if err != nil { + return fmt.Errorf("could not fetch membership changes for sync: %w", err) + } + + if len(histories) == 0 { + return nil + } + + // Deduplicate by user — for users with multiple events (join/leave/rejoin cycles), + // resolve to the most recent event to determine current state. + type userState struct { + isAdd bool + eventTime int64 // max of JoinTime/LeaveTime for this user + } + byUser := make(map[string]*userState) + + var maxCursor int64 + for _, h := range histories { + // Track the max event time for the next cursor + if h.JoinTime > maxCursor { + maxCursor = h.JoinTime + } + if h.LeaveTime != nil && *h.LeaveTime > maxCursor { + maxCursor = *h.LeaveTime + } + + // Determine this event's effective time and whether it represents a join or leave + var eventTime int64 + var isAdd bool + if h.LeaveTime == nil || h.JoinTime > *h.LeaveTime { + // User is currently a member (no leave, or rejoined after leaving) + isAdd = true + eventTime = h.JoinTime + } else { + // User left + isAdd = false + eventTime = *h.LeaveTime + } + + // Keep only the most recent state per user + if existing, ok := byUser[h.UserId]; !ok || eventTime > existing.eventTime { + byUser[h.UserId] = &userState{isAdd: isAdd, eventTime: eventTime} + } + } + + // Build MembershipChangeMsg entries + for userID, state := range byUser { + if state.isAdd { + user, userErr := scs.server.GetStore().User().Get(context.Background(), userID) + if userErr != nil { + scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Failed to get user for membership sync", + mlog.String("user_id", userID), + mlog.String("channel_id", sd.task.channelID), + mlog.Err(userErr), + ) + continue + } + + sd.membershipChanges = append(sd.membershipChanges, &model.MembershipChangeMsg{ + ChannelId: sd.task.channelID, + UserId: userID, + IsAdd: true, + ChangeTime: state.eventTime, + }) + + doSync, _, syncErr := scs.shouldUserSync(user, sd.task.channelID, sd.rc) + if syncErr == nil && doSync { + sd.users[user.Id] = user + } + } else { + sd.membershipChanges = append(sd.membershipChanges, &model.MembershipChangeMsg{ + ChannelId: sd.task.channelID, + UserId: userID, + IsAdd: false, + ChangeTime: state.eventTime, + }) + } + } + + sd.resultNextMembershipCursor = maxCursor + + // If we hit the limit, there may be more data to fetch + if len(histories) >= limit { + sd.resultRepeat = true + } + + return nil +} + +// sendMembershipSyncData sends the collected membership changes to the remote cluster. +func (scs *Service) sendMembershipSyncData(sd *syncData) error { + start := time.Now() + defer func() { + if metrics := scs.server.GetMetrics(); metrics != nil { + metrics.ObserveSharedChannelsSyncSendStepDuration(sd.rc.RemoteId, "Memberships", time.Since(start).Seconds()) + } + }() + + msg := model.NewSyncMsg(sd.task.channelID) + msg.MembershipChanges = sd.membershipChanges + + // Include only user profiles for users referenced by membership adds, + // not the full sd.users map which may contain post authors and other users. + memberUsers := make(map[string]*model.User) + for _, mc := range sd.membershipChanges { + if mc.IsAdd { + if u, ok := sd.users[mc.UserId]; ok { + memberUsers[mc.UserId] = u + } + } + } + msg.Users = memberUsers + + return scs.sendSyncMsgToRemote(msg, sd.rc, func(syncResp model.SyncResponse, errResp error) { + if len(syncResp.MembershipErrors) != 0 { + scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Response indicates error for membership(s) sync", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Array("membership_errors", syncResp.MembershipErrors), + ) + } + + // Update the membership cursor on success + if errResp == nil && sd.resultNextMembershipCursor > 0 { + if err := scs.updateMembershipSyncCursor(sd.task.channelID, sd.rc.RemoteId, sd.resultNextMembershipCursor); err != nil { + scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update membership sync cursor", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Err(err), + ) + } + } + }) +} + func appendPosts(dest []*model.Post, posts []*model.Post, postStore store.PostStore, timestamp int64, logger mlog.LoggerIFace) []*model.Post { // Append the posts individually, checking for root posts that might appear later in the list. // This is due to the UpdateAt collision handling algorithm where the order of posts is not based @@ -576,7 +736,7 @@ func (scs *Service) filterPostsForSync(sd *syncData) { // sendSyncData sends all the collected users, posts, reactions, images, and attachments to the // remote cluster. -// The order of items sent is important: users -> attachments -> posts -> reactions -> profile images +// The order of items sent is important: users -> memberships -> attachments -> posts -> reactions -> profile images func (scs *Service) sendSyncData(sd *syncData) error { start := time.Now() defer func() { @@ -595,6 +755,13 @@ func (scs *Service) sendSyncData(sd *syncData) error { } } + // send membership changes (after users so the remote has profiles for added members) + if len(sd.membershipChanges) != 0 { + if err := scs.sendMembershipSyncData(sd); err != nil { + merr.Append(fmt.Errorf("cannot send membership sync data: %w", err)) + } + } + // send attachments if len(sd.attachments) != 0 { scs.sendAttachmentSyncData(sd) diff --git a/server/platform/services/sharedchannel/sync_send_test.go b/server/platform/services/sharedchannel/sync_send_test.go new file mode 100644 index 00000000000..76016881faf --- /dev/null +++ b/server/platform/services/sharedchannel/sync_send_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestService() *Service { + return &Service{ + changeSignal: make(chan struct{}, 1), + tasks: make(map[string]syncTask), + } +} + +func TestAddTask_OriginRemoteIDMerge(t *testing.T) { + tests := []struct { + name string + firstOrigin string + secondOrigin string + expectedOrigin string + }{ + { + name: "same remote origin is preserved", + firstOrigin: "remote-A", + secondOrigin: "remote-A", + expectedOrigin: "remote-A", + }, + { + name: "local then remote clears origin", + firstOrigin: "", + secondOrigin: "remote-A", + expectedOrigin: "", + }, + { + name: "remote then local clears origin", + firstOrigin: "remote-A", + secondOrigin: "", + expectedOrigin: "", + }, + { + name: "different remotes clears origin", + firstOrigin: "remote-A", + secondOrigin: "remote-B", + expectedOrigin: "", + }, + { + name: "both local stays empty", + firstOrigin: "", + secondOrigin: "", + expectedOrigin: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + scs := newTestService() + channelID := "channel-1" + + first := newSyncTask(channelID, "", "", nil, nil) + first.originRemoteID = tc.firstOrigin + scs.addTask(first) + + second := newSyncTask(channelID, "", "", nil, nil) + second.originRemoteID = tc.secondOrigin + scs.addTask(second) + + merged, ok := scs.tasks[first.id] + require.True(t, ok, "task should exist") + assert.Equal(t, tc.expectedOrigin, merged.originRemoteID) + }) + } +} diff --git a/server/platform/services/sharedchannel/util.go b/server/platform/services/sharedchannel/util.go index 07f7cf16713..1c027462af3 100644 --- a/server/platform/services/sharedchannel/util.go +++ b/server/platform/services/sharedchannel/util.go @@ -48,12 +48,6 @@ func sanitizeUserForSync(user *model.User) *model.User { return user } -func sanitizeUserForSyncSafe(user *model.User) *model.User { - // Create a copy to avoid modifying the original user object - userCopy := *user - return sanitizeUserForSync(&userCopy) -} - const MungUsernameSeparator = "-" // mungUsername creates a new username by combining username and remote cluster name, plus diff --git a/server/public/model/shared_channel.go b/server/public/model/shared_channel.go index cccba8ea013..f81392f9918 100644 --- a/server/public/model/shared_channel.go +++ b/server/public/model/shared_channel.go @@ -172,13 +172,12 @@ type SharedChannelRemoteStatus struct { // SharedChannelUser stores a lastSyncAt timestamp on behalf of a remote cluster for // each user that has been synchronized. type SharedChannelUser struct { - Id string `json:"id"` - UserId string `json:"user_id"` - ChannelId string `json:"channel_id"` - RemoteId string `json:"remote_id"` - CreateAt int64 `json:"create_at"` - LastSyncAt int64 `json:"last_sync_at"` - LastMembershipSyncAt int64 `json:"last_membership_sync_at"` + Id string `json:"id"` + UserId string `json:"user_id"` + ChannelId string `json:"channel_id"` + RemoteId string `json:"remote_id"` + CreateAt int64 `json:"create_at"` + LastSyncAt int64 `json:"last_sync_at"` } func (scu *SharedChannelUser) PreSave() { @@ -336,6 +335,8 @@ type SyncResponse struct { AcknowledgementErrors []string `json:"acknowledgement_errors"` StatusErrors []string `json:"status_errors"` // user IDs for which the status sync failed + + MembershipErrors []string `json:"membership_errors,omitempty"` } // RegisterPluginOpts is passed by plugins to the `RegisterPluginForSharedChannels` plugin API