From f0b2a36dbc4a9e8a5e5f6232ba51a1940e90fddf Mon Sep 17 00:00:00 2001 From: Doug Lauder Date: Mon, 23 Mar 2026 10:12:17 -0400 Subject: [PATCH] MM-67616: Refactor shared channel membership sync to use ChannelMemberHistory (#35619) * Refactor shared channel membership sync to use ChannelMemberHistory (MM-67616) Replace the trigger-time membership sync mechanism with a cursor-based approach using ChannelMemberHistory, aligning membership sync with the established pattern used by posts and reactions. Previously, membership changes were built into SyncMsg at trigger time and sent via a separate TopicChannelMembership code path. This meant removals were lost if a remote was offline, since ChannelMembers hard-deletes rows. Now, membership changes are fetched from ChannelMemberHistory at sync time using the LastMembersSyncAt cursor, detecting both joins and leaves reliably. The data flows through the normal syncForRemote pipeline alongside posts, reactions, and other sync data. Key changes: - Add GetMembershipChanges store method for ChannelMemberHistory - Add fetchMembershipsForSync and sendMembershipSyncData to sync pipeline - Replace HandleMembershipChange with NotifyMembershipChanged (trigger-only) - Remove conflict detection (idempotent add/remove resolves naturally) - Remove per-user membership tracking (GetUserChanges, UpdateUserLastMembershipSyncAt) - Add MembershipErrors to SyncResponse - Keep TopicChannelMembership receiver for one release cycle (backward compat) --- server/channels/app/channel.go | 5 +- .../platform/shared_channel_service_iface.go | 4 +- ...l_membership_sync_self_referential_test.go | 97 ++-- .../app/shared_channel_service_iface.go | 9 +- ...hannel_sync_self_referential_utils_test.go | 6 +- .../channels/store/retrylayer/retrylayer.go | 63 +-- .../sqlstore/channel_member_history_store.go | 34 ++ .../store/sqlstore/shared_channel_store.go | 22 +- .../shared_channel_store_membership.go | 28 -- server/channels/store/store.go | 3 +- 
.../storetest/channel_member_history_store.go | 122 +++++ .../mocks/ChannelMemberHistoryStore.go | 30 ++ .../store/storetest/mocks/ChannelStore.go | 3 +- .../store/storetest/mocks/PostStore.go | 3 +- .../storetest/mocks/SharedChannelStore.go | 48 -- .../channels/store/storetest/mocks/Store.go | 11 +- .../store/storetest/mocks/ThreadStore.go | 3 +- .../store/storetest/mocks/UserStore.go | 6 +- .../channels/store/timerlayer/timerlayer.go | 48 +- .../services/sharedchannel/channelinvite.go | 33 +- .../services/sharedchannel/membership.go | 422 +----------------- .../services/sharedchannel/membership_recv.go | 133 ++---- .../sharedchannel/membership_recv_test.go | 87 ++-- .../sharedchannel/membership_send_test.go | 362 +++++++++++++++ .../services/sharedchannel/sync_recv.go | 2 + .../services/sharedchannel/sync_send.go | 32 +- .../sharedchannel/sync_send_remote.go | 171 ++++++- .../services/sharedchannel/sync_send_test.go | 77 ++++ .../platform/services/sharedchannel/util.go | 6 - server/public/model/shared_channel.go | 15 +- 30 files changed, 1059 insertions(+), 826 deletions(-) create mode 100644 server/platform/services/sharedchannel/membership_send_test.go create mode 100644 server/platform/services/sharedchannel/sync_send_test.go diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index 79f0c795982..97e33795827 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -1772,7 +1772,7 @@ func (a *App) addUserToChannel(rctx request.CTX, user *model.User, channel *mode // Synchronize membership change for shared channels if channel.IsShared() { if scs := a.Srv().Platform().GetSharedChannelService(); scs != nil { - scs.HandleMembershipChange(channel.Id, user.Id, true, user.GetRemoteID()) + scs.NotifyMembershipChanged(channel.Id, user.GetRemoteID()) } } @@ -2844,9 +2844,8 @@ func (a *App) removeUserFromChannel(rctx request.CTX, userIDToRemove string, rem // Synchronize membership change for shared channels if 
channel.IsShared() { - // isAdd=false, empty remoteId means locally initiated if scs := a.Srv().Platform().GetSharedChannelService(); scs != nil { - scs.HandleMembershipChange(channel.Id, userIDToRemove, false, "") + scs.NotifyMembershipChanged(channel.Id, "") } } diff --git a/server/channels/app/platform/shared_channel_service_iface.go b/server/channels/app/platform/shared_channel_service_iface.go index 2db9152dea3..ab76329f728 100644 --- a/server/channels/app/platform/shared_channel_service_iface.go +++ b/server/channels/app/platform/shared_channel_service_iface.go @@ -24,7 +24,7 @@ type SharedChannelServiceIFace interface { CheckChannelNotShared(channelID string) error CheckChannelIsShared(channelID string) error CheckCanInviteToSharedChannel(channelId string) error - HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) + NotifyMembershipChanged(channelID string, originRemoteID string) TransformMentionsOnReceiveForTesting(rctx request.CTX, post *model.Post, targetChannel *model.Channel, rc *model.RemoteCluster, mentionTransforms map[string]string) } @@ -81,7 +81,7 @@ func (mrcs *mockSharedChannelService) NumInvitations() int { return mrcs.numInvitations } -func (mrcs *mockSharedChannelService) HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) { +func (mrcs *mockSharedChannelService) NotifyMembershipChanged(channelID string, originRemoteID string) { // This is a mock implementation - it doesn't need to do anything } diff --git a/server/channels/app/shared_channel_membership_sync_self_referential_test.go b/server/channels/app/shared_channel_membership_sync_self_referential_test.go index afab4acfcb1..542fdb7add9 100644 --- a/server/channels/app/shared_channel_membership_sync_self_referential_test.go +++ b/server/channels/app/shared_channel_membership_sync_self_referential_test.go @@ -29,7 +29,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { ss := th.App.Srv().Store() - // Get the shared 
channel service and cast to concrete type to access SyncAllChannelMembers + // Get the shared channel service and cast to concrete type scsInterface := th.App.Srv().GetSharedChannelSyncService() service, ok := scsInterface.(*sharedchannel.Service) require.True(t, ok, "Expected sharedchannel.Service concrete type") @@ -61,7 +61,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { t.Run("Test 1: Automatic sync on membership changes", func(t *testing.T) { // This test verifies that membership sync happens automatically when users are added or removed from a shared channel. - // The sync is triggered by HandleMembershipChange which is called automatically by AddUserToChannel and RemoveUserFromChannel. + // The sync is triggered by NotifyMembershipChanged which is called automatically by AddUserToChannel and RemoveUserFromChannel. // The test ensures that sync messages are sent asynchronously after a minimum delay for both add and remove operations. EnsureCleanState(t, th, ss) // Track sync messages received @@ -132,7 +132,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { _, _, appErr = th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, user.Id, th.BasicUser.Id) require.Nil(t, appErr) - // Add user to channel - this triggers HandleMembershipChange automatically + // Add user to channel - this triggers NotifyMembershipChanged automatically _, appErr = th.App.AddUserToChannel(th.Context, user, channel, false) require.Nil(t, appErr) @@ -307,8 +307,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { mu.Unlock() } - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for batch messages to be received with more robust checking require.Eventually(t, func() bool { @@ -343,11 +342,10 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } mu.Unlock() - // Verify exact batch count - 
assert.Equal(t, expectedBatches, actualBatches, fmt.Sprintf("Should have exactly %d batches with batch size %d", expectedBatches, batchSize)) - - // Verify total synced users - assert.Equal(t, expectedTotal, totalSynced, "All users including bots and system admins should be synced") + // With inclusive cursor (GtOrEq), boundary rows may be re-fetched across + // batch boundaries, so batch count and total may exceed the minimum. + assert.GreaterOrEqual(t, actualBatches, expectedBatches, fmt.Sprintf("Should have at least %d batches with batch size %d", expectedBatches, batchSize)) + assert.GreaterOrEqual(t, totalSynced, expectedTotal, "All users including bots and system admins should be synced") // Verify that bot and system admin WERE synced assert.Contains(t, allSyncedUserIDs, bot.UserId, "Bot should be synced") @@ -468,8 +466,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } // First sync - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for first sync to complete require.Eventually(t, func() bool { @@ -497,8 +494,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.Nil(t, appErr) // Second sync - should only sync user3 - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for second sync to complete require.Eventually(t, func() bool { @@ -518,11 +514,14 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { return secondSyncCursor > firstSyncCursor }, 10*time.Second, 100*time.Millisecond, "Cursor should advance after second sync") - // Verify incremental sync + // Verify incremental sync: first sync includes initial users, second + // sync must include user3. 
The inclusive cursor (GtOrEq) may re-fetch + // events at the boundary timestamp, so users from the first batch can + // appear again. This is the expected trade-off to avoid data loss at + // batch boundaries without requiring a composite cursor schema change. + // The receiver is idempotent so duplicates are harmless. assert.GreaterOrEqual(t, len(syncedInFirstCall), 2, "First sync should include initial users") - assert.Contains(t, syncedInSecondCall, user3.Id, "Second sync should include only new user") - assert.NotContains(t, syncedInSecondCall, user1.Id, "Second sync should not re-sync existing users") - assert.NotContains(t, syncedInSecondCall, user2.Id, "Second sync should not re-sync existing users") + assert.Contains(t, syncedInSecondCall, user3.Id, "Second sync must include the new user") }) t.Run("Test 4: Sync failure and recovery", func(t *testing.T) { t.Skip("MM-64687") @@ -618,8 +617,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.Nil(t, appErr) // First sync attempt - will fail - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for first sync attempt with more robust checking require.Eventually(t, func() bool { @@ -635,8 +633,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { failureMode.Store(false) // Second sync attempt - should succeed - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for successful sync with more robust checking require.Eventually(t, func() bool { @@ -648,7 +645,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { assert.Greater(t, finalAttempts, initialAttempts, "Should have retried after recovery") }) t.Run("Test 5: Manual sync with cursor management", func(t *testing.T) { - // This test verifies manual sync using SyncAllChannelMembers 
with complete cursor management: + // This test verifies manual sync using NotifyMembershipChanged with complete cursor management: // 1. Initial sync of 10 users with cursor tracking // 2. Mixed operations: remove 3 users and add 5 new users // 3. Verifies all operations are properly synced and cursor is updated correctly @@ -676,8 +673,8 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { var frame model.RemoteClusterFrame if json.Unmarshal(bodyBytes, &frame) == nil { var syncMsg model.SyncMsg - if json.Unmarshal(frame.Msg.Payload, &syncMsg) == nil && frame.Msg.Topic == "sharedchannel_membership" { - // Count membership changes from the unified field + if json.Unmarshal(frame.Msg.Payload, &syncMsg) == nil { + // Count membership changes (now sent via TopicSync) for _, change := range syncMsg.MembershipChanges { if change.IsAdd { atomic.AddInt32(&addOperations, 1) @@ -762,8 +759,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { initialCursor := initialScr.LastMembersSyncAt // Initial sync - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for initial sync to complete require.Eventually(t, func() bool { @@ -804,8 +800,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { // Sync mixed changes previousMessages := atomic.LoadInt32(&totalSyncMessages) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for mixed changes sync to complete require.Eventually(t, func() bool { @@ -963,28 +958,25 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.Nil(t, addErr) } - // Sync to all clusters - need to sync each one individually - for _, cluster := range clusters { - err = service.SyncAllChannelMembers(channel.Id, cluster.RemoteId, nil) - require.NoError(t, err) - } + // Sync 
to all clusters via NotifyMembershipChanged + service.NotifyMembershipChanged(channel.Id, "") // Wait for syncs to complete require.Eventually(t, func() bool { - // Each cluster should receive at least 5 sync messages (one per user) + // Each cluster should receive at least 1 sync message (membership changes are batched) for _, countPtr := range syncMessagesPerCluster { - if atomic.LoadInt32(countPtr) < 5 { + if atomic.LoadInt32(countPtr) < 1 { return false } } return true }, 10*time.Second, 100*time.Millisecond, "All clusters should receive sync messages") - // Verify each cluster received messages + // Verify each cluster received messages (membership changes are batched, so >= 1) for name, countPtr := range syncMessagesPerCluster { finalCount := atomic.LoadInt32(countPtr) - assert.GreaterOrEqual(t, finalCount, int32(5), - "Cluster %s should receive at least 5 sync messages", name) + assert.GreaterOrEqual(t, finalCount, int32(1), + "Cluster %s should receive at least 1 sync message", name) } // Part 2: Test propagation from one cluster through another @@ -1061,11 +1053,8 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { appErr = th.App.RemoveUserFromChannel(th.Context, users[0].Id, th.SystemAdminUser.Id, channel) require.Nil(t, appErr) - // Sync removal to all clusters - for _, cluster := range clusters { - err = service.SyncAllChannelMembers(channel.Id, cluster.RemoteId, nil) - require.NoError(t, err) - } + // Sync removal to all clusters via NotifyMembershipChanged + service.NotifyMembershipChanged(channel.Id, "") // Wait for removal sync require.Eventually(t, func() bool { @@ -1088,7 +1077,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { t.Run("Test 7: Feature flag disabled", func(t *testing.T) { // This test verifies that the shared channel membership sync functionality respects the feature flag. // It tests two scenarios: - // 1. 
When the feature flag is disabled, no sync messages should be sent even when SyncAllChannelMembers is called + // 1. When the feature flag is disabled, no sync messages should be sent even when NotifyMembershipChanged is called // 2. When the feature flag is enabled, sync messages should be sent as expected // This ensures that the feature can be safely disabled in production without triggering unintended syncs EnsureCleanState(t, th, ss) @@ -1163,8 +1152,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } atomic.StoreInt32(&syncMessageCount, 0) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Verify no sync messages were sent require.Never(t, func() bool { @@ -1177,8 +1165,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { }) atomic.StoreInt32(&syncMessageCount, 0) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Verify sync messages were sent require.Eventually(t, func() bool { @@ -1270,8 +1257,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { require.NoError(t, err) // Trigger membership sync as would happen when connection is restored - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Verify sync task was created and executed with more generous timeout require.Eventually(t, func() bool { @@ -1387,8 +1373,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { syncHandler = NewSelfReferentialSyncHandler(t, service, selfCluster) // First sync should succeed - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") // Wait for first sync with more generous timeout 
require.Eventually(t, func() bool { @@ -1416,8 +1401,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { } // Second sync should fail (server goes offline) - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) // Method itself shouldn't error + service.NotifyMembershipChanged(channel.Id, "") // Wait for second sync attempt with more generous timeout require.Eventually(t, func() bool { @@ -1571,8 +1555,7 @@ func TestSharedChannelMembershipSyncSelfReferential(t *testing.T) { // Sync memberships for each channel separately for _, channel := range []*model.Channel{channel1, channel2, channel3} { - err = service.SyncAllChannelMembers(channel.Id, selfCluster.RemoteId, nil) - require.NoError(t, err) + service.NotifyMembershipChanged(channel.Id, "") } // Ensure the sync handler is ready by waiting for the first message diff --git a/server/channels/app/shared_channel_service_iface.go b/server/channels/app/shared_channel_service_iface.go index 2a6ecfa682c..52e2b635eb6 100644 --- a/server/channels/app/shared_channel_service_iface.go +++ b/server/channels/app/shared_channel_service_iface.go @@ -27,7 +27,7 @@ type SharedChannelServiceIFace interface { CheckChannelNotShared(channelID string) error CheckChannelIsShared(channelID string) error CheckCanInviteToSharedChannel(channelId string) error - HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) + NotifyMembershipChanged(channelID string, originRemoteID string) IsRemoteClusterDirectlyConnected(remoteId string) bool TransformMentionsOnReceiveForTesting(rctx request.CTX, post *model.Post, targetChannel *model.Channel, rc *model.RemoteCluster, mentionTransforms map[string]string) } @@ -37,6 +37,7 @@ func NewMockSharedChannelService(service SharedChannelServiceIFace) *mockSharedC SharedChannelServiceIFace: service, channelNotifications: []string{}, userProfileNotifications: []string{}, + membershipNotifications: []string{}, 
numInvitations: 0, } return mrcs @@ -46,6 +47,7 @@ type mockSharedChannelService struct { SharedChannelServiceIFace channelNotifications []string userProfileNotifications []string + membershipNotifications []string numInvitations int } @@ -96,9 +98,10 @@ func (mrcs *mockSharedChannelService) NumInvitations() int { return mrcs.numInvitations } -func (mrcs *mockSharedChannelService) HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) { +func (mrcs *mockSharedChannelService) NotifyMembershipChanged(channelID string, originRemoteID string) { + mrcs.membershipNotifications = append(mrcs.membershipNotifications, channelID) if mrcs.SharedChannelServiceIFace != nil { - mrcs.SharedChannelServiceIFace.HandleMembershipChange(channelID, userID, isAdd, remoteID) + mrcs.SharedChannelServiceIFace.NotifyMembershipChanged(channelID, originRemoteID) } } diff --git a/server/channels/app/shared_channel_sync_self_referential_utils_test.go b/server/channels/app/shared_channel_sync_self_referential_utils_test.go index d0944f66229..1b4beeeb420 100644 --- a/server/channels/app/shared_channel_sync_self_referential_utils_test.go +++ b/server/channels/app/shared_channel_sync_self_referential_utils_test.go @@ -129,8 +129,10 @@ func (h *SelfReferentialSyncHandler) HandleRequest(w http.ResponseWriter, r *htt if h.OnBatchSync != nil { h.OnBatchSync(batch, currentCall) } - if len(batch) == 1 && h.OnIndividualSync != nil { - h.OnIndividualSync(batch[0], currentCall) + if h.OnIndividualSync != nil { + for _, uid := range batch { + h.OnIndividualSync(uid, currentCall) + } } } } diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 1978557da27..5fdd6bb9f17 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -3869,6 +3869,27 @@ func (s *RetryLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(star } +func (s 
*RetryLayerChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + + tries := 0 + for { + result, err := s.ChannelMemberHistoryStore.GetMembershipChanges(channelID, since, limit) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerChannelMemberHistoryStore) GetUsersInChannelDuring(startTime int64, endTime int64, channelID []string) ([]*model.ChannelMemberHistoryResult, error) { tries := 0 @@ -12707,27 +12728,6 @@ func (s *RetryLayerSharedChannelStore) GetSingleUser(userID string, channelID st } -func (s *RetryLayerSharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - - tries := 0 - for { - result, err := s.SharedChannelStore.GetUserChanges(userID, channelID, afterTime) - if err == nil { - return result, nil - } - if !isRepeatableError(err) { - return result, err - } - tries++ - if tries >= 3 { - err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return result, err - } - timepkg.Sleep(100 * timepkg.Millisecond) - } - -} - func (s *RetryLayerSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { tries := 0 @@ -13001,27 +13001,6 @@ func (s *RetryLayerSharedChannelStore) UpdateRemoteMembershipCursor(id string, s } -func (s *RetryLayerSharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - - tries := 0 - for { - err := s.SharedChannelStore.UpdateUserLastMembershipSyncAt(userID, channelID, remoteID, syncTime) - if err == nil { - return nil - } - if !isRepeatableError(err) { - return err - } - tries++ - if tries >= 3 { 
- err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return err - } - timepkg.Sleep(100 * timepkg.Millisecond) - } - -} - func (s *RetryLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { tries := 0 diff --git a/server/channels/store/sqlstore/channel_member_history_store.go b/server/channels/store/sqlstore/channel_member_history_store.go index 7665297c379..b4f8e6c9753 100644 --- a/server/channels/store/sqlstore/channel_member_history_store.go +++ b/server/channels/store/sqlstore/channel_member_history_store.go @@ -295,6 +295,40 @@ func (s SqlChannelMemberHistoryStore) PermanentDeleteBatch(endTime int64, limit return rowsAffected, nil } +// GetMembershipChanges returns all membership events (joins and leaves) for a channel since the given timestamp. +// Uses inclusive comparison (>=) so events at the cursor timestamp are re-fetched rather than lost at batch +// boundaries. This may cause redundant re-sends when consecutive batches share a boundary timestamp, but the +// receiver is idempotent so duplicates are harmless. A composite cursor (timestamp + ID) like posts use would +// eliminate duplicates, but would require a schema change to SharedChannelRemotes; the current trade-off avoids that. +func (s SqlChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + query, args, err := s.getQueryBuilder(). + Select("ChannelId", "UserId", "JoinTime", "LeaveTime"). + From("ChannelMemberHistory"). + Where(sq.And{ + sq.Eq{"ChannelId": channelID}, + sq.Or{ + sq.GtOrEq{"JoinTime": since}, + sq.And{ + sq.NotEq{"LeaveTime": nil}, + sq.GtOrEq{"LeaveTime": since}, + }, + }, + }). + OrderBy("GREATEST(JoinTime, COALESCE(LeaveTime, 0)) ASC", "UserId ASC"). + Limit(uint64(limit)). 
+ ToSql() + if err != nil { + return nil, errors.Wrap(err, "channel_member_history_to_sql") + } + + histories := []*model.ChannelMemberHistory{} + if err := s.GetReplica().Select(&histories, query, args...); err != nil { + return nil, errors.Wrapf(err, "GetMembershipChanges channelId=%s since=%d limit=%d", channelID, since, limit) + } + + return histories, nil +} + // GetChannelsLeftSince returns list of channels that the user has left after a given time, // but has not rejoined again. func (s SqlChannelMemberHistoryStore) GetChannelsLeftSince(userID string, since int64) ([]string, error) { diff --git a/server/channels/store/sqlstore/shared_channel_store.go b/server/channels/store/sqlstore/shared_channel_store.go index b6047b5c3c3..f748a438c20 100644 --- a/server/channels/store/sqlstore/shared_channel_store.go +++ b/server/channels/store/sqlstore/shared_channel_store.go @@ -711,7 +711,6 @@ func sharedChannelUserFields(prefix string) []string { prefix + "RemoteId", prefix + "CreateAt", prefix + "LastSyncAt", - prefix + "LastMembershipSyncAt", } } @@ -724,7 +723,7 @@ func (s SqlSharedChannelStore) SaveUser(scUser *model.SharedChannelUser) (*model query, args, err := s.getQueryBuilder().Insert("SharedChannelUsers"). Columns(sharedChannelUserFields("")...). - Values(scUser.Id, scUser.UserId, scUser.ChannelId, scUser.RemoteId, scUser.CreateAt, scUser.LastSyncAt, scUser.LastMembershipSyncAt). + Values(scUser.Id, scUser.UserId, scUser.ChannelId, scUser.RemoteId, scUser.CreateAt, scUser.LastSyncAt). ToSql() if err != nil { return nil, errors.Wrapf(err, "savesharedchanneluser_tosql") @@ -852,25 +851,6 @@ func (s SqlSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID str return nil } -// UpdateUserLastMembershipSyncAt updates the LastMembershipSyncAt timestamp for the specified SharedChannelUser using the provided sync time. 
-func (s SqlSharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - query := s.getQueryBuilder(). - Update("SharedChannelUsers AS scu"). - Set("LastMembershipSyncAt", sq.Expr("GREATEST(scu.LastMembershipSyncAt, ?)", syncTime)). - Where(sq.Eq{ - "scu.UserId": userID, - "scu.ChannelId": channelID, - "scu.RemoteId": remoteID, - }) - - _, err := s.GetMaster().ExecBuilder(query) - if err != nil { - return fmt.Errorf("failed to update LastMembershipSyncAt for SharedChannelUser with userId=%s, channelId=%s, remoteId=%s: %w", - userID, channelID, remoteID, err) - } - return nil -} - func sharedChannelAttachementFields(prefix string) []string { if prefix != "" && !strings.HasSuffix(prefix, ".") { prefix = prefix + "." diff --git a/server/channels/store/sqlstore/shared_channel_store_membership.go b/server/channels/store/sqlstore/shared_channel_store_membership.go index 111fbb85e01..0ce4ee44d58 100644 --- a/server/channels/store/sqlstore/shared_channel_store_membership.go +++ b/server/channels/store/sqlstore/shared_channel_store_membership.go @@ -4,10 +4,8 @@ package sqlstore import ( - "database/sql" "fmt" - "github.com/mattermost/mattermost/server/public/model" sq "github.com/mattermost/squirrel" "github.com/pkg/errors" ) @@ -38,29 +36,3 @@ func (s SqlSharedChannelStore) UpdateRemoteMembershipCursor(id string, syncTime return nil } - -// GetUserChanges gets all SharedChannelUser changes for a given user, channel after a specific time. -// This is used to detect if there are conflicting membership changes. -func (s SqlSharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - squery, args, err := s.getQueryBuilder(). - Select(sharedChannelUserFields("")...). - From("SharedChannelUsers"). - Where(sq.Eq{"SharedChannelUsers.UserId": userID}). - Where(sq.Eq{"SharedChannelUsers.ChannelId": channelID}). 
- Where(sq.Gt{"SharedChannelUsers.LastSyncAt": afterTime}). - ToSql() - - if err != nil { - return nil, errors.Wrapf(err, "getsharedchanneluserchanges_tosql") - } - - users := []*model.SharedChannelUser{} - if err := s.GetReplica().Select(&users, squery, args...); err != nil { - if err == sql.ErrNoRows { - return make([]*model.SharedChannelUser, 0), nil - } - return nil, errors.Wrapf(err, "failed to find shared channel user changes with UserId=%s, ChannelId=%s, afterTime=%d", - userID, channelID, afterTime) - } - return users, nil -} diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 789b3236a1d..ae5a3825081 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -334,6 +334,7 @@ type ChannelMemberHistoryStore interface { DeleteOrphanedRows(limit int) (deleted int64, err error) PermanentDeleteBatch(endTime int64, limit int64) (int64, error) GetChannelsLeftSince(userID string, since int64) ([]string, error) + GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) } type ThreadStore interface { GetThreadFollowers(threadID string, fetchOnlyActive bool) ([]string, error) @@ -1040,9 +1041,7 @@ type SharedChannelStore interface { GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) - GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error - UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) diff --git 
a/server/channels/store/storetest/channel_member_history_store.go b/server/channels/store/storetest/channel_member_history_store.go index aaae0e34ab1..05b339e9c88 100644 --- a/server/channels/store/storetest/channel_member_history_store.go +++ b/server/channels/store/storetest/channel_member_history_store.go @@ -26,6 +26,7 @@ func TestChannelMemberHistoryStore(t *testing.T, rctx request.CTX, ss store.Stor t.Run("TestPermanentDeleteBatchForRetentionPolicies", func(t *testing.T) { testPermanentDeleteBatchForRetentionPolicies(t, rctx, ss) }) t.Run("TestGetChannelsLeftSince", func(t *testing.T) { testGetChannelsLeftSince(t, rctx, ss) }) t.Run("TestDeleteOrphanedRows", func(t *testing.T) { testDeleteOrphanedRows(t, rctx, ss) }) + t.Run("TestGetMembershipChanges", func(t *testing.T) { testGetMembershipChanges(t, rctx, ss) }) } func testLogJoinEvent(t *testing.T, rctx request.CTX, ss store.Store) { @@ -689,3 +690,124 @@ func testDeleteOrphanedRows(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) require.Equal(t, int64(0), deletedCount, "No rows should be deleted when no orphans exist") } + +func testGetMembershipChanges(t *testing.T, rctx request.CTX, ss store.Store) { + ch := model.Channel{ + TeamId: model.NewId(), + DisplayName: "GetMembershipChanges", + Name: NewTestID(), + Type: model.ChannelTypeOpen, + } + channel, nErr := ss.Channel().Save(rctx, &ch, 100) + require.NoError(t, nErr) + + user1 := model.NewId() + user2 := model.NewId() + user3 := model.NewId() + + // Set up timeline: + // t=1000: user1 joins + // t=2000: user2 joins + // t=3000: user3 joins + // t=4000: user1 leaves + // t=5000: user2 leaves + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user1, channel.Id, 1000)) + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user2, channel.Id, 2000)) + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user3, channel.Id, 3000)) + require.NoError(t, ss.ChannelMemberHistory().LogLeaveEvent(user1, channel.Id, 
4000)) + require.NoError(t, ss.ChannelMemberHistory().LogLeaveEvent(user2, channel.Id, 5000)) + + t.Run("returns all events since timestamp zero", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + assert.Len(t, results, 3, "should return all 3 users' history rows") + }) + + t.Run("filters by since timestamp - joins only", func(t *testing.T) { + // since=2500: user3 joined at 3000 (>=2500), user1 left at 4000 (>=2500), user2 left at 5000 (>=2500) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 2500, 100) + require.NoError(t, err) + assert.Len(t, results, 3, "user3 join>=2500, user1 leave>=2500, user2 leave>=2500") + }) + + t.Run("inclusive boundary includes events at exact cursor timestamp", func(t *testing.T) { + // since=3000: user3 joined at exactly 3000, should be included (GtOrEq) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 3000, 100) + require.NoError(t, err) + userIDs := make(map[string]bool) + for _, r := range results { + userIDs[r.UserId] = true + } + assert.True(t, userIDs[user3], "user3 joined at exactly since=3000, should be included") + assert.True(t, userIDs[user1], "user1 left at 4000 >= 3000") + assert.True(t, userIDs[user2], "user2 left at 5000 >= 3000") + }) + + t.Run("filters by since timestamp - leaves only", func(t *testing.T) { + // since=3500: user1 left at 4000, user2 left at 5000; user3 joined at 3000 (not > 3500) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 3500, 100) + require.NoError(t, err) + assert.Len(t, results, 2, "only user1 and user2 have events after 3500") + }) + + t.Run("respects limit parameter", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 2) + require.NoError(t, err) + assert.Len(t, results, 2, "should respect limit of 2") + }) + + t.Run("returns empty for unknown channel", func(t *testing.T) { 
+ results, err := ss.ChannelMemberHistory().GetMembershipChanges(model.NewId(), 0, 100) + require.NoError(t, err) + assert.Empty(t, results) + }) + + t.Run("returns empty when since is beyond all events", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 999999, 100) + require.NoError(t, err) + assert.Empty(t, results) + }) + + t.Run("results are ordered by greatest event time ascending", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + require.Len(t, results, 3) + + // Effective event times: user1=max(1000,4000)=4000, user2=max(2000,5000)=5000, user3=max(3000,0)=3000 + // Sorted ASC: user3(3000), user1(4000), user2(5000) + assert.Equal(t, user3, results[0].UserId, "user3 should be first (effective time 3000)") + assert.Equal(t, user1, results[1].UserId, "user1 should be second (effective time 4000)") + assert.Equal(t, user2, results[2].UserId, "user2 should be third (effective time 5000)") + }) + + t.Run("null LeaveTime handled correctly", func(t *testing.T) { + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + + for _, r := range results { + if r.UserId == user3 { + assert.Nil(t, r.LeaveTime, "user3 should have nil LeaveTime") + } + if r.UserId == user1 { + require.NotNil(t, r.LeaveTime) + assert.Equal(t, int64(4000), *r.LeaveTime) + } + } + }) + + t.Run("join-leave-rejoin cycle returns latest row", func(t *testing.T) { + // user1 rejoins at t=6000 + require.NoError(t, ss.ChannelMemberHistory().LogJoinEvent(user1, channel.Id, 6000)) + + // since=0 should now return 4 rows (user1 has 2 history rows, user2 has 1, user3 has 1) + results, err := ss.ChannelMemberHistory().GetMembershipChanges(channel.Id, 0, 100) + require.NoError(t, err) + assert.Len(t, results, 4, "user1 now has 2 history rows") + + // The rejoined row (JoinTime=6000, LeaveTime=nil) should be the last one + 
lastResult := results[len(results)-1] + assert.Equal(t, user1, lastResult.UserId) + assert.Equal(t, int64(6000), lastResult.JoinTime) + assert.Nil(t, lastResult.LeaveTime, "rejoin should have nil LeaveTime") + }) +} diff --git a/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go b/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go index 280090253a1..c24baf67b98 100644 --- a/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go +++ b/server/channels/store/storetest/mocks/ChannelMemberHistoryStore.go @@ -102,6 +102,36 @@ func (_m *ChannelMemberHistoryStore) GetChannelsWithActivityDuring(startTime int return r0, r1 } +// GetMembershipChanges provides a mock function with given fields: channelID, since, limit +func (_m *ChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + ret := _m.Called(channelID, since, limit) + + if len(ret) == 0 { + panic("no return value specified for GetMembershipChanges") + } + + var r0 []*model.ChannelMemberHistory + var r1 error + if rf, ok := ret.Get(0).(func(string, int64, int) ([]*model.ChannelMemberHistory, error)); ok { + return rf(channelID, since, limit) + } + if rf, ok := ret.Get(0).(func(string, int64, int) []*model.ChannelMemberHistory); ok { + r0 = rf(channelID, since, limit) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelMemberHistory) + } + } + + if rf, ok := ret.Get(1).(func(string, int64, int) error); ok { + r1 = rf(channelID, since, limit) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetUsersInChannelDuring provides a mock function with given fields: startTime, endTime, channelID func (_m *ChannelMemberHistoryStore) GetUsersInChannelDuring(startTime int64, endTime int64, channelID []string) ([]*model.ChannelMemberHistoryResult, error) { ret := _m.Called(startTime, endTime, channelID) diff --git a/server/channels/store/storetest/mocks/ChannelStore.go 
b/server/channels/store/storetest/mocks/ChannelStore.go index b430c1c33d3..455c68a0448 100644 --- a/server/channels/store/storetest/mocks/ChannelStore.go +++ b/server/channels/store/storetest/mocks/ChannelStore.go @@ -7,8 +7,9 @@ package mocks import ( model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // ChannelStore is an autogenerated mock type for the ChannelStore type diff --git a/server/channels/store/storetest/mocks/PostStore.go b/server/channels/store/storetest/mocks/PostStore.go index ff5291cb754..d8ee218923c 100644 --- a/server/channels/store/storetest/mocks/PostStore.go +++ b/server/channels/store/storetest/mocks/PostStore.go @@ -7,8 +7,9 @@ package mocks import ( model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // PostStore is an autogenerated mock type for the PostStore type diff --git a/server/channels/store/storetest/mocks/SharedChannelStore.go b/server/channels/store/storetest/mocks/SharedChannelStore.go index dbecc0f5b43..d0906484491 100644 --- a/server/channels/store/storetest/mocks/SharedChannelStore.go +++ b/server/channels/store/storetest/mocks/SharedChannelStore.go @@ -368,36 +368,6 @@ func (_m *SharedChannelStore) GetSingleUser(userID string, channelID string, rem return r0, r1 } -// GetUserChanges provides a mock function with given fields: userID, channelID, afterTime -func (_m *SharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - ret := 
_m.Called(userID, channelID, afterTime) - - if len(ret) == 0 { - panic("no return value specified for GetUserChanges") - } - - var r0 []*model.SharedChannelUser - var r1 error - if rf, ok := ret.Get(0).(func(string, string, int64) ([]*model.SharedChannelUser, error)); ok { - return rf(userID, channelID, afterTime) - } - if rf, ok := ret.Get(0).(func(string, string, int64) []*model.SharedChannelUser); ok { - r0 = rf(userID, channelID, afterTime) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*model.SharedChannelUser) - } - } - - if rf, ok := ret.Get(1).(func(string, string, int64) error); ok { - r1 = rf(userID, channelID, afterTime) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetUsersForSync provides a mock function with given fields: filter func (_m *SharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { ret := _m.Called(filter) @@ -748,24 +718,6 @@ func (_m *SharedChannelStore) UpdateRemoteMembershipCursor(id string, syncTime i return r0 } -// UpdateUserLastMembershipSyncAt provides a mock function with given fields: userID, channelID, remoteID, syncTime -func (_m *SharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - ret := _m.Called(userID, channelID, remoteID, syncTime) - - if len(ret) == 0 { - panic("no return value specified for UpdateUserLastMembershipSyncAt") - } - - var r0 error - if rf, ok := ret.Get(0).(func(string, string, string, int64) error); ok { - r0 = rf(userID, channelID, remoteID, syncTime) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // UpdateUserLastSyncAt provides a mock function with given fields: userID, channelID, remoteID func (_m *SharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { ret := _m.Called(userID, channelID, remoteID) diff --git a/server/channels/store/storetest/mocks/Store.go b/server/channels/store/storetest/mocks/Store.go index 
f77ba092ec9..e12d1d54782 100644 --- a/server/channels/store/storetest/mocks/Store.go +++ b/server/channels/store/storetest/mocks/Store.go @@ -5,13 +5,16 @@ package mocks import ( - sql "database/sql" - time "time" + mlog "github.com/mattermost/mattermost/server/public/shared/mlog" + mock "github.com/stretchr/testify/mock" model "github.com/mattermost/mattermost/server/public/model" - mlog "github.com/mattermost/mattermost/server/public/shared/mlog" + + sql "database/sql" + store "github.com/mattermost/mattermost/server/v8/channels/store" - mock "github.com/stretchr/testify/mock" + + time "time" ) // Store is an autogenerated mock type for the Store type diff --git a/server/channels/store/storetest/mocks/ThreadStore.go b/server/channels/store/storetest/mocks/ThreadStore.go index 4a3aa04edb0..e2d222da345 100644 --- a/server/channels/store/storetest/mocks/ThreadStore.go +++ b/server/channels/store/storetest/mocks/ThreadStore.go @@ -7,8 +7,9 @@ package mocks import ( model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + store "github.com/mattermost/mattermost/server/v8/channels/store" ) // ThreadStore is an autogenerated mock type for the ThreadStore type diff --git a/server/channels/store/storetest/mocks/UserStore.go b/server/channels/store/storetest/mocks/UserStore.go index fa0670cda2a..0845e1e23ab 100644 --- a/server/channels/store/storetest/mocks/UserStore.go +++ b/server/channels/store/storetest/mocks/UserStore.go @@ -8,9 +8,11 @@ import ( context "context" model "github.com/mattermost/mattermost/server/public/model" - request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" mock "github.com/stretchr/testify/mock" + + request "github.com/mattermost/mattermost/server/public/shared/request" + + 
store "github.com/mattermost/mattermost/server/v8/channels/store" ) // UserStore is an autogenerated mock type for the UserStore type diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 439445b25a5..2575e55bc4d 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -3224,6 +3224,22 @@ func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(star return result, err } +func (s *TimerLayerChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { + start := time.Now() + + result, err := s.ChannelMemberHistoryStore.GetMembershipChanges(channelID, since, limit) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetMembershipChanges", success, elapsed) + } + return result, err +} + func (s *TimerLayerChannelMemberHistoryStore) GetUsersInChannelDuring(startTime int64, endTime int64, channelID []string) ([]*model.ChannelMemberHistoryResult, error) { start := time.Now() @@ -10065,22 +10081,6 @@ func (s *TimerLayerSharedChannelStore) GetSingleUser(userID string, channelID st return result, err } -func (s *TimerLayerSharedChannelStore) GetUserChanges(userID string, channelID string, afterTime int64) ([]*model.SharedChannelUser, error) { - start := time.Now() - - result, err := s.SharedChannelStore.GetUserChanges(userID, channelID, afterTime) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetUserChanges", success, elapsed) - } - return result, err -} - func (s *TimerLayerSharedChannelStore) GetUsersForSync(filter 
model.GetUsersForSyncFilter) ([]*model.User, error) { start := time.Now() @@ -10289,22 +10289,6 @@ func (s *TimerLayerSharedChannelStore) UpdateRemoteMembershipCursor(id string, s return err } -func (s *TimerLayerSharedChannelStore) UpdateUserLastMembershipSyncAt(userID string, channelID string, remoteID string, syncTime int64) error { - start := time.Now() - - err := s.SharedChannelStore.UpdateUserLastMembershipSyncAt(userID, channelID, remoteID, syncTime) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateUserLastMembershipSyncAt", success, elapsed) - } - return err -} - func (s *TimerLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { start := time.Now() diff --git a/server/platform/services/sharedchannel/channelinvite.go b/server/platform/services/sharedchannel/channelinvite.go index f21a393316f..1581d56c700 100644 --- a/server/platform/services/sharedchannel/channelinvite.go +++ b/server/platform/services/sharedchannel/channelinvite.go @@ -135,7 +135,6 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc } curTime := model.GetMillis() - var sharedChannelRemote *model.SharedChannelRemote if existingScr != nil { if existingScr.DeleteAt == 0 && existingScr.IsInviteConfirmed { // the shared channel remote exists and is not @@ -155,7 +154,6 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error confirming channel invite for %s: %v", rc.DisplayName, sErr)) return } - sharedChannelRemote = existingScr } else { // the shared channel remote doesn't exists, so we create it scr := &model.SharedChannelRemote{ @@ -172,20 +170,13 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc 
scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error confirming channel invite for %s: %v", rc.DisplayName, err)) return } - sharedChannelRemote = scr } scs.NotifyChannelChanged(sc.ChannelId) scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("`%s` has been added to channel.", rc.DisplayName)) - // Sync all channel members to the remote now that the remote entry exists - if syncErr := scs.SyncAllChannelMembers(sc.ChannelId, rc.RemoteId, sharedChannelRemote); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync channel members after invite confirmation", - mlog.String("channel_id", sc.ChannelId), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - } + // Trigger membership sync via the normal sync pipeline (reads from ChannelMemberHistory) + scs.NotifyMembershipChanged(sc.ChannelId, "") } if rc.IsPlugin() { @@ -326,14 +317,8 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model return fmt.Errorf("cannot restore deleted shared channel remote (channel_id=%s): %w", invite.ChannelId, err) } - // Sync local channel members to the remote after restoring the shared channel - if syncErr := scs.SyncAllChannelMembers(channel.Id, rc.RemoteId, existingScr); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync local channel members after restoring shared channel", - mlog.String("channel_id", channel.Id), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - } + // Trigger membership sync via the normal sync pipeline (reads from ChannelMemberHistory) + scs.NotifyMembershipChanged(channel.Id, "") } else { creatorID := channel.CreatorId if creatorID == "" { @@ -361,14 +346,8 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model return fmt.Errorf("cannot create shared channel remote (channel_id=%s): %w", invite.ChannelId, err) } - // Sync local channel members to the remote after accepting the invitation - if 
syncErr := scs.SyncAllChannelMembers(channel.Id, rc.RemoteId, scr); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync local channel members after accepting invitation", - mlog.String("channel_id", channel.Id), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - } + // Trigger membership sync via the normal sync pipeline (reads from ChannelMemberHistory) + scs.NotifyMembershipChanged(channel.Id, "") } return nil } diff --git a/server/platform/services/sharedchannel/membership.go b/server/platform/services/sharedchannel/membership.go index fc8cb1dd373..0cbe19858b8 100644 --- a/server/platform/services/sharedchannel/membership.go +++ b/server/platform/services/sharedchannel/membership.go @@ -4,14 +4,10 @@ package sharedchannel import ( - "context" - "encoding/json" - "fmt" "time" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" - "github.com/mattermost/mattermost/server/v8/platform/services/remotecluster" ) // isChannelMemberSyncEnabled checks if the feature flag is enabled and remote cluster service is available @@ -21,226 +17,25 @@ func (scs *Service) isChannelMemberSyncEnabled() bool { return featureFlagEnabled && remoteClusterService != nil } -// queueMembershipSyncTask creates and queues a task to synchronize channel membership changes -func (scs *Service) queueMembershipSyncTask(channelID, userID, remoteID string, syncMsg *model.SyncMsg, retryMsg *model.SyncMsg) { - task := newSyncTask(channelID, userID, remoteID, syncMsg, retryMsg) +// NotifyMembershipChanged is called when users are added or removed from a shared channel. +// It triggers a sync for the channel so that membership changes are picked up from +// ChannelMemberHistory at sync time, following the same pattern as posts and reactions. +// originRemoteID identifies the remote that initiated the change, so it can be skipped +// during sync to prevent echo-back. 
+func (scs *Service) NotifyMembershipChanged(channelID string, originRemoteID string) { + if !scs.isChannelMemberSyncEnabled() { + return + } + task := newSyncTask(channelID, "", "", nil, nil) + task.originRemoteID = originRemoteID task.schedule = time.Now().Add(NotifyMinimumDelay) - scs.addTask(task) } -// HandleMembershipChange is called when users are added or removed from a shared channel. -// It creates a task to notify all remote clusters about the membership change. -func (scs *Service) HandleMembershipChange(channelID, userID string, isAdd bool, remoteID string) { - if !scs.isChannelMemberSyncEnabled() { - return - } - - // Create timestamp for consistent usage - changeTime := model.GetMillis() - - // Create membership change info - syncMsg := model.NewSyncMsg(channelID) - syncMsg.MembershipChanges = []*model.MembershipChangeMsg{ - { - ChannelId: channelID, - UserId: userID, - IsAdd: isAdd, - RemoteId: remoteID, // which remote initiated this change - ChangeTime: changeTime, - }, - } - - // Queue the membership change task - scs.queueMembershipSyncTask(channelID, userID, "", syncMsg, nil) -} - -// HandleMembershipBatchChange is called to process a batch of membership changes for a shared channel. -// It creates a task to notify all remote clusters about the batch membership changes. 
-func (scs *Service) HandleMembershipBatchChange(channelID string, userIDs []string, isAdd bool, remoteID string) { - if !scs.isChannelMemberSyncEnabled() { - return - } - - if len(userIDs) == 0 { - return - } - - // Create timestamp for consistent usage - changeTime := model.GetMillis() - - // Create sync message with membership changes - syncMsg := model.NewSyncMsg(channelID) - syncMsg.MembershipChanges = make([]*model.MembershipChangeMsg, 0, len(userIDs)) - - // Add each user to the batch - for _, userID := range userIDs { - syncMsg.MembershipChanges = append(syncMsg.MembershipChanges, &model.MembershipChangeMsg{ - ChannelId: channelID, - UserId: userID, - IsAdd: isAdd, - RemoteId: remoteID, - ChangeTime: changeTime, - }) - } - - // Queue the batch membership sync task - scs.queueMembershipSyncTask(channelID, "", "", syncMsg, nil) -} - -// SyncAllChannelMembers synchronizes channel members to a specific remote. -// This is typically called when a channel is first shared with a remote cluster. -// If remote is provided, it will be used instead of fetching from the database. -// -// When LastMembersSyncAt is non-zero, only members updated after that cursor are synced (i.e., the delta). -// When LastMembersSyncAt is zero (initial share), all current members are synced. -// -// Limitation: this function only detects membership additions and modifications, not removals. -// Channel member rows are hard-deleted from the database, so members removed while a remote was -// offline cannot be detected by this query. Removals rely on real-time HandleMembershipChange events. 
-func (scs *Service) SyncAllChannelMembers(channelID string, remoteID string, remote *model.SharedChannelRemote) error { - if !scs.isChannelMemberSyncEnabled() { - return nil - } - - // Verify the channel exists and is shared - if _, err := scs.server.GetStore().SharedChannel().Get(channelID); err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Failed to get shared channel", - mlog.String("channel_id", channelID), - mlog.Err(err), - ) - return fmt.Errorf("failed to get shared channel %s: %w", channelID, err) - } - - // Get the remote to ensure it exists (if not provided) - if remote == nil { - var err error - remote, err = scs.server.GetStore().SharedChannel().GetRemoteByIds(channelID, remoteID) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Failed to get remote", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - mlog.Err(err), - ) - return fmt.Errorf("failed to get remote for channel %s: %w", channelID, err) - } - } - - // Use offset-based pagination to handle channels with many members - // This ensures we don't skip members when multiple members have the same LastUpdateAt timestamp - maxPerPage := scs.GetMemberSyncBatchSize() - var allMembers model.ChannelMembers - lastSyncAt := remote.LastMembersSyncAt - offset := 0 - - // Process members incrementally with offset-based pagination - for { - opts := model.ChannelMembersGetOptions{ - ChannelID: channelID, - UpdatedAfter: lastSyncAt, - Limit: maxPerPage, - Offset: offset, - } - - members, err1 := scs.server.GetStore().Channel().GetMembers(opts) - if err1 != nil { - return fmt.Errorf("failed to get members for channel %s: %w", channelID, err1) - } - - if len(members) == 0 { - break // No more members to process - } - - // Add to our collection - allMembers = append(allMembers, members...) 
- - // Log progress when processing large channels - if len(allMembers)%1000 == 0 { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Processing channel members in batches", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - mlog.Int("processed_so_far", len(allMembers)), - ) - } - - if len(members) < maxPerPage { - break // Last page - } - - // Move to next page - offset += maxPerPage - } - - if len(allMembers) == 0 { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "No members to sync for channel", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - ) - return nil - } - - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Syncing all channel members", - mlog.String("channel_id", channelID), - mlog.String("remote_id", remoteID), - mlog.Int("member_count", len(allMembers)), - ) - - // Get batch size from config - batchSize := scs.GetMemberSyncBatchSize() - - // For small channels, queue individual membership changes - if len(allMembers) <= batchSize { - return scs.syncMembersIndividually(channelID, remoteID, allMembers, remote) - } - - // For larger channels, use batch processing - return scs.syncMembersInBatches(channelID, remoteID, allMembers, remote) -} - -// syncMembersIndividually processes each member individually -// This is more efficient for small channels -func (scs *Service) syncMembersIndividually(channelID, remoteID string, members model.ChannelMembers, remote *model.SharedChannelRemote) error { - // Queue individual membership changes for each member - for _, member := range members { - // Queue membership change for this user (isAdd=true) - scs.HandleMembershipChange(channelID, member.UserId, true, "") - } - - return nil -} - -// syncMembersInBatches processes members in batches for greater efficiency -// This is better for channels with many members -func (scs *Service) syncMembersInBatches(channelID, remoteID string, members model.ChannelMembers, remote 
*model.SharedChannelRemote) error { - // Get batch size from config - batchSize := scs.GetMemberSyncBatchSize() - - for i := 0; i < len(members); i += batchSize { - end := min(i+batchSize, len(members)) - - // Create a batch of members - batchMembers := members[i:end] - - // Extract user IDs from the batch - userIDs := make([]string, len(batchMembers)) - for j, member := range batchMembers { - userIDs[j] = member.UserId - } - - // Use the batch handling function to queue the changes - scs.HandleMembershipBatchChange(channelID, userIDs, true, "") - } - - return nil -} - -// ForceMembershipSyncForRemote syncs channel membership for all channels shared with the -// specified remote. Called when a remote comes back online to catch up on any membership -// changes that occurred while it was offline. -// -// Note: SyncAllChannelMembers uses the LastMembersSyncAt cursor on each SharedChannelRemote -// record, so only the membership delta since the last successful sync is sent, not the full -// member list. See SyncAllChannelMembers for known limitations regarding removed members. +// ForceMembershipSyncForRemote triggers a sync for all channels shared with the specified remote. +// Called when a remote comes back online to catch up on any membership changes that occurred +// while it was offline. The sync pipeline will read ChannelMemberHistory using the +// LastMembersSyncAt cursor to detect both additions and removals. 
func (scs *Service) ForceMembershipSyncForRemote(rc *model.RemoteCluster) { if !scs.isChannelMemberSyncEnabled() { return @@ -260,191 +55,6 @@ func (scs *Service) ForceMembershipSyncForRemote(rc *model.RemoteCluster) { } for _, scr := range scrs { - if syncErr := scs.SyncAllChannelMembers(scr.ChannelId, rc.RemoteId, scr); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to sync channel members for reconnected remote", - mlog.String("channel_id", scr.ChannelId), - mlog.String("remote", rc.DisplayName), - mlog.String("remoteId", rc.RemoteId), - mlog.Err(syncErr), - ) - } - } -} - -// processMembershipChange processes a channel membership change task. -// It determines which remotes should receive the update and creates tasks for each. -func (scs *Service) processMembershipChange(syncMsg *model.SyncMsg) { - if len(syncMsg.MembershipChanges) == 0 { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Invalid membership change task - no membership changes", - mlog.String("channel_id", syncMsg.ChannelId), - ) - return - } - - // Get the shared channel (to verify it exists) - _, err := scs.server.GetStore().SharedChannel().Get(syncMsg.ChannelId) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to get shared channel for membership change", - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Int("change_count", len(syncMsg.MembershipChanges)), - mlog.Err(err), - ) - return - } - - // Get all remotes for this channel - remotes, err := scs.server.GetStore().SharedChannel().GetRemotes(0, 999999, model.SharedChannelRemoteFilterOpts{ - ChannelId: syncMsg.ChannelId, - }) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to get shared channel remotes for membership change", - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - return - } - - // Always use batch processing for consistency (works for single or multiple changes) - 
scs.syncMembershipBatchToRemotes(syncMsg, remotes) -} - -// syncMembershipBatchToRemotes synchronizes membership changes (single or batch) with remote clusters. -func (scs *Service) syncMembershipBatchToRemotes(syncMsg *model.SyncMsg, remotes []*model.SharedChannelRemote) { - if len(syncMsg.MembershipChanges) == 0 { - return - } - - // Get the initiating remote ID from the first change (all should be the same) - initiatingRemoteId := "" - if len(syncMsg.MembershipChanges) > 0 { - initiatingRemoteId = syncMsg.MembershipChanges[0].RemoteId - } - - // Send to all remotes except the one that initiated this change - for _, remote := range remotes { - // Skip the remote that initiated this change to prevent loops - if remote.RemoteId == initiatingRemoteId { - continue - } - - // Get the remote cluster - rc, err := scs.server.GetStore().RemoteCluster().Get(remote.RemoteId, false) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to get remote cluster for batch membership sync", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - continue - } - - // Create a copy of the sync message to potentially add user profiles - enrichedSyncMsg := &model.SyncMsg{ - Id: syncMsg.Id, - ChannelId: syncMsg.ChannelId, - MembershipChanges: syncMsg.MembershipChanges, - Users: make(map[string]*model.User), - } - - // Add user profiles for all users being added - for _, change := range syncMsg.MembershipChanges { - if change.IsAdd { - user, pErr := scs.server.GetStore().User().Get(context.Background(), change.UserId) - if pErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Failed to get user for batch membership sync", - mlog.String("user_id", change.UserId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.String("remote_id", remote.RemoteId), - mlog.Err(pErr), - ) - continue - } - - // Check if user profile needs to be synced - doSync, _, sErr := scs.shouldUserSync(user, 
syncMsg.ChannelId, rc) - if sErr == nil && doSync { - enrichedSyncMsg.Users[user.Id] = sanitizeUserForSyncSafe(user) - } - } - } - - // Send message using the existing remote cluster framework - payload, err := json.Marshal(enrichedSyncMsg) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to marshal batch membership message", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - continue - } - - msg := model.RemoteClusterMsg{ - Id: model.NewId(), - Topic: TopicChannelMembership, - CreateAt: model.GetMillis(), - Payload: payload, - } - - ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) - defer cancel() - - // Define a callback function - callback := func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Error sending batch membership changes to remote", - mlog.String("remote", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Int("change_count", len(syncMsg.MembershipChanges)), - mlog.Err(err), - ) - return - } - - if resp != nil && resp.Err != "" { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Remote error when processing batch membership changes", - mlog.String("remote", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.String("remote_error", resp.Err), - ) - return - } - - // Update sync timestamps - for _, change := range syncMsg.MembershipChanges { - if err := scs.server.GetStore().SharedChannel().UpdateUserLastMembershipSyncAt(change.UserId, change.ChannelId, remote.RemoteId, change.ChangeTime); err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update user membership sync timestamp in batch", - mlog.String("user_id", change.UserId), - mlog.Err(err), - ) - } - } - - // Update the cursor with the latest change time - var 
maxChangeTime int64 - for _, change := range syncMsg.MembershipChanges { - if change.ChangeTime > maxChangeTime { - maxChangeTime = change.ChangeTime - } - } - - if err := scs.updateMembershipSyncCursor(syncMsg.ChannelId, remote.RemoteId, maxChangeTime); err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update membership sync cursor for batch", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - } - } - - err = scs.server.GetRemoteClusterService().SendMsg(ctx, msg, rc, callback) - - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to send batch membership changes to remote", - mlog.String("remote_id", remote.RemoteId), - mlog.String("channel_id", syncMsg.ChannelId), - mlog.Err(err), - ) - } + scs.NotifyChannelChanged(scr.ChannelId) } } diff --git a/server/platform/services/sharedchannel/membership_recv.go b/server/platform/services/sharedchannel/membership_recv.go index 60d50605360..4ea92537ace 100644 --- a/server/platform/services/sharedchannel/membership_recv.go +++ b/server/platform/services/sharedchannel/membership_recv.go @@ -5,7 +5,6 @@ package sharedchannel import ( "fmt" - "strings" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" @@ -13,36 +12,21 @@ import ( "github.com/mattermost/mattermost/server/v8/platform/services/remotecluster" ) -// checkMembershipConflict checks if there are newer changes that would conflict with this one -// Returns true if this change should be skipped due to a conflict -func (scs *Service) checkMembershipConflict(userID, channelID string, changeTime int64) (bool, error) { - conflicts, err := scs.server.GetStore().SharedChannel().GetUserChanges(userID, channelID, changeTime) - if err != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to check for membership change conflicts", - mlog.String("user_id", userID), - 
mlog.String("channel_id", channelID), - mlog.Err(err), - ) - return false, err - } +const ( + // Error IDs returned by the app layer that indicate an idempotent no-op. + errIDAddUserToChannelFailed = "api.channel.add_user.to.channel.failed.app_error" + errIDSaveMemberExists = "app.channel.save_member.exists.app_error" + errIDGetChannelMemberMissing = "app.channel.get_member.missing.app_error" +) - // If there are conflicting operations, the latest one wins - for _, conflict := range conflicts { - if conflict.LastMembershipSyncAt > changeTime { - scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Ignoring older membership change due to conflict", - mlog.String("user_id", userID), - mlog.String("channel_id", channelID), - mlog.Int("change_time", int(changeTime)), - mlog.Int("conflicting_time", int(conflict.LastMembershipSyncAt)), - ) - return true, nil - } - } - - return false, nil -} - -// onReceiveMembershipChanges processes channel membership changes from a remote cluster +// onReceiveMembershipChanges processes channel membership changes from a remote cluster. +// In the new model, the sender derives the authoritative net state from ChannelMemberHistory. +// Both processMemberAdd and processMemberRemove are idempotent: +// - processMemberAdd ignores "already added" errors +// - processMemberRemove ignores "not found" errors +// +// Out-of-order messages resolve naturally: if an old "add" arrives after a newer "remove", +// the sender's next sync cycle will send a corrective "remove" because the history shows the user left. 
func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { // Check if feature flag is enabled if !scs.server.Config().FeatureFlags.EnableSharedChannelsMemberSync { @@ -65,16 +49,8 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model return fmt.Errorf("cannot get shared channel for membership changes: %w", err) } - // Calculate the maximum ChangeTime from all changes in the batch - var maxChangeTime int64 - for _, change := range syncMsg.MembershipChanges { - if change.ChangeTime > maxChangeTime { - maxChangeTime = change.ChangeTime - } - } - // Process each change - var successCount, skipCount, failCount int + var failCount int for _, change := range syncMsg.MembershipChanges { if change.ChannelId != syncMsg.ChannelId { @@ -87,13 +63,6 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model continue } - // Check for conflicts - shouldSkip, _ := scs.checkMembershipConflict(change.UserId, change.ChannelId, change.ChangeTime) - if shouldSkip { - skipCount++ - continue - } - // Process the membership change based on whether it's an add or remove var processErr error if change.IsAdd { @@ -102,14 +71,14 @@ func (scs *Service) onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model mlog.String("channel_id", change.ChannelId), mlog.String("remote_id", rc.RemoteId), ) - processErr = scs.processMemberAdd(change, channel, rc, maxChangeTime, syncMsg) + processErr = scs.processMemberAdd(change, channel, rc, syncMsg) } else { scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Removing user from channel from remote cluster", mlog.String("user_id", change.UserId), mlog.String("channel_id", change.ChannelId), mlog.String("remote_id", rc.RemoteId), ) - processErr = scs.processMemberRemove(change, rc, maxChangeTime) + processErr = scs.processMemberRemove(change, rc) } if processErr != nil { @@ -123,15 +92,21 @@ func (scs *Service) 
onReceiveMembershipChanges(syncMsg *model.SyncMsg, rc *model failCount++ continue } + } - successCount++ + if failCount > 0 { + scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Some membership changes failed", + mlog.String("channel_id", syncMsg.ChannelId), + mlog.Int("total", len(syncMsg.MembershipChanges)), + mlog.Int("failed", failCount), + ) } return nil } -// processMemberAdd handles adding a user to a channel as part of batch processing -func (scs *Service) processMemberAdd(change *model.MembershipChangeMsg, channel *model.Channel, rc *model.RemoteCluster, maxChangeTime int64, syncMsg *model.SyncMsg) error { +// processMemberAdd handles adding a user to a channel +func (scs *Service) processMemberAdd(change *model.MembershipChangeMsg, channel *model.Channel, rc *model.RemoteCluster, syncMsg *model.SyncMsg) error { rctx := request.EmptyContext(scs.server.Log()) var user *model.User var err error @@ -166,30 +141,18 @@ func (scs *Service) processMemberAdd(change *model.MembershipChangeMsg, channel // Skip team member check (true) since we already handled team membership above _, appErr := scs.app.AddUserToChannel(rctx, user, channel, true) if appErr != nil { - // Skip "already added" errors - if appErr.Error() != "api.channel.add_user.to_channel.failed.app_error" && - !strings.Contains(appErr.Error(), "channel_member_exists") { + // Skip "already added" errors — idempotent + if appErr.Id != errIDAddUserToChannelFailed && + appErr.Id != errIDSaveMemberExists { return fmt.Errorf("cannot add user to channel: %w", appErr) } - // User is already in the channel, which is fine - } - - // Update the sync status - if syncErr := scs.server.GetStore().SharedChannel().UpdateUserLastMembershipSyncAt(change.UserId, change.ChannelId, rc.RemoteId, maxChangeTime); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update user LastMembershipSyncAt after batch member add", - mlog.String("user_id", change.UserId), - 
mlog.String("channel_id", change.ChannelId), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - // Continue despite the error - this is not critical } return nil } -// processMemberRemove handles removing a user from a channel as part of batch processing -func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *model.RemoteCluster, maxChangeTime int64) error { +// processMemberRemove handles removing a user from a channel +func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *model.RemoteCluster) error { // Get channel so we can use app layer methods properly channel, err := scs.server.GetStore().Channel().Get(change.ChannelId, true) if err != nil { @@ -198,7 +161,7 @@ func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *m mlog.String("user_id", change.UserId), mlog.Err(err), ) - // Continue anyway to update sync status - the channel might be deleted + return nil // channel might be deleted, nothing to do } rctx := request.EmptyContext(scs.server.Log()) @@ -210,32 +173,14 @@ func (scs *Service) processMemberRemove(change *model.MembershipChangeMsg, rc *m return fmt.Errorf("membership remove sync failed: %w", ErrRemoteIDMismatch) } - // Use the app layer's remove user method if channel still exists - if channel != nil { - appErr := scs.app.RemoveUserFromChannel(rctx, change.UserId, "", channel) - if appErr != nil { - // Ignore "not found" errors - the user might already be removed - if !strings.Contains(appErr.Error(), "store.sql_channel.remove_member.missing.app_error") { - scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Error removing user from channel", - mlog.String("channel_id", change.ChannelId), - mlog.String("user_id", change.UserId), - mlog.Err(appErr), - ) - // Continue anyway to update sync status - don't return error here - // to ensure sync status still gets updated - } + // Use the app layer's remove user method + appErr := 
scs.app.RemoveUserFromChannel(rctx, change.UserId, "", channel) + if appErr != nil { + // Ignore "not found" errors - the user might already be removed + if appErr.Id == errIDGetChannelMemberMissing { + return nil } - } - - // Update the sync status - if syncErr := scs.server.GetStore().SharedChannel().UpdateUserLastMembershipSyncAt(change.UserId, change.ChannelId, rc.RemoteId, maxChangeTime); syncErr != nil { - scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update user LastMembershipSyncAt after batch member remove", - mlog.String("user_id", change.UserId), - mlog.String("channel_id", change.ChannelId), - mlog.String("remote_id", rc.RemoteId), - mlog.Err(syncErr), - ) - // Continue despite the error - this is not critical + return fmt.Errorf("cannot remove user from channel: %w", appErr) } return nil diff --git a/server/platform/services/sharedchannel/membership_recv_test.go b/server/platform/services/sharedchannel/membership_recv_test.go index a0478ac6160..cdfedeb2e2b 100644 --- a/server/platform/services/sharedchannel/membership_recv_test.go +++ b/server/platform/services/sharedchannel/membership_recv_test.go @@ -4,6 +4,7 @@ package sharedchannel import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -73,9 +74,6 @@ func TestOnReceiveMembershipChanges_ChannelIdMismatch(t *testing.T) { err := scs.onReceiveMembershipChanges(syncMsg, rc, nil) require.NoError(t, err) // function returns nil even on individual failures - - // The conflict check should never have been called since the mismatch was caught first - mockSharedChannelStore.AssertNotCalled(t, "GetUserChanges", mock.Anything, mock.Anything, mock.Anything) } func TestProcessMemberAdd_RejectsLocalUser(t *testing.T) { @@ -94,9 +92,6 @@ func TestProcessMemberAdd_RejectsLocalUser(t *testing.T) { localUser := &model.User{Id: localUserId} mockUserStore.On("Get", mockTypeContext, localUserId).Return(localUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", 
localUserId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -132,9 +127,6 @@ func TestProcessMemberAdd_RejectsOtherRemoteUser(t *testing.T) { otherRemoteUser := &model.User{Id: userId, RemoteId: &otherRemoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(otherRemoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -169,12 +161,8 @@ func TestProcessMemberAdd_AllowsOwnRemoteUser(t *testing.T) { remoteUser := &model.User{Id: userId, RemoteId: &remoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - // Expect the add to proceed mockApp.On("AddUserToChannel", mockTypeReqContext, mockTypeUser, mockTypeChannel, true).Return(&model.ChannelMember{}, nil) - mockSharedChannelStore.On("UpdateUserLastMembershipSyncAt", userId, channelId, remoteId, int64(1000)).Return(nil) syncMsg := &model.SyncMsg{ ChannelId: channelId, @@ -210,9 +198,6 @@ func TestProcessMemberRemove_RejectsLocalUser(t *testing.T) { localUser := &model.User{Id: localUserId} mockUserStore.On("Get", mockTypeContext, localUserId).Return(localUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", localUserId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -248,9 +233,6 @@ func TestProcessMemberRemove_RejectsOtherRemoteUser(t *testing.T) { otherRemoteUser := &model.User{Id: userId, RemoteId: &otherRemoteId} 
mockUserStore.On("Get", mockTypeContext, userId).Return(otherRemoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - syncMsg := &model.SyncMsg{ ChannelId: channelId, MembershipChanges: []*model.MembershipChangeMsg{ @@ -285,12 +267,8 @@ func TestProcessMemberRemove_AllowsOwnRemoteUser(t *testing.T) { remoteUser := &model.User{Id: userId, RemoteId: &remoteId} mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) - // No conflict - mockSharedChannelStore.On("GetUserChanges", userId, channelId, mock.AnythingOfType("int64")).Return([]*model.SharedChannelUser{}, nil) - // Expect the remove to proceed mockApp.On("RemoveUserFromChannel", mockTypeReqContext, userId, "", channel).Return(nil) - mockSharedChannelStore.On("UpdateUserLastMembershipSyncAt", userId, channelId, remoteId, int64(1000)).Return(nil) syncMsg := &model.SyncMsg{ ChannelId: channelId, @@ -333,7 +311,7 @@ func TestProcessMemberAdd_RejectsLocalUser_ErrorMessage(t *testing.T) { ChannelId: channelId, } - err := scs.processMemberAdd(change, channel, rc, 1000, syncMsg) + err := scs.processMemberAdd(change, channel, rc, syncMsg) require.Error(t, err) assert.ErrorIs(t, err, ErrRemoteIDMismatch) assert.Contains(t, err.Error(), "membership add sync failed") @@ -360,8 +338,67 @@ func TestProcessMemberRemove_RejectsLocalUser_ErrorMessage(t *testing.T) { ChangeTime: 1000, } - err := scs.processMemberRemove(change, rc, 1000) + err := scs.processMemberRemove(change, rc) require.Error(t, err) assert.ErrorIs(t, err, ErrRemoteIDMismatch) assert.Contains(t, err.Error(), "membership remove sync failed") } + +func TestProcessMemberAdd_IdempotentForExistingMember(t *testing.T) { + scs, _, mockApp, _, _, _, mockUserStore := setupMembershipTest(t) + + channelId := model.NewId() + remoteId := model.NewId() + userId := model.NewId() + rc := &model.RemoteCluster{RemoteId: remoteId} + channel := 
&model.Channel{Id: channelId, Type: model.ChannelTypeOpen} + + // User belongs to the sending remote + remoteUser := &model.User{Id: userId, RemoteId: &remoteId} + mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) + + // Simulate "already added" error from AddUserToChannel + alreadyAddedErr := model.NewAppError("AddUserToChannel", errIDSaveMemberExists, nil, "", http.StatusBadRequest) + mockApp.On("AddUserToChannel", mockTypeReqContext, mockTypeUser, mockTypeChannel, true).Return(nil, alreadyAddedErr) + + syncMsg := &model.SyncMsg{ChannelId: channelId} + change := &model.MembershipChangeMsg{ + ChannelId: channelId, + UserId: userId, + IsAdd: true, + ChangeTime: 1000, + } + + err := scs.processMemberAdd(change, channel, rc, syncMsg) + require.NoError(t, err, "should succeed for already-added user (idempotent)") +} + +func TestProcessMemberRemove_IdempotentForNonMember(t *testing.T) { + scs, _, mockApp, _, _, mockChannelStore, mockUserStore := setupMembershipTest(t) + + channelId := model.NewId() + remoteId := model.NewId() + userId := model.NewId() + rc := &model.RemoteCluster{RemoteId: remoteId} + channel := &model.Channel{Id: channelId, Type: model.ChannelTypeOpen} + + mockChannelStore.On("Get", channelId, true).Return(channel, nil) + + // User belongs to the sending remote + remoteUser := &model.User{Id: userId, RemoteId: &remoteId} + mockUserStore.On("Get", mockTypeContext, userId).Return(remoteUser, nil) + + // Simulate "not found" error from RemoveUserFromChannel + notFoundErr := model.NewAppError("RemoveUserFromChannel", errIDGetChannelMemberMissing, nil, "", http.StatusNotFound) + mockApp.On("RemoveUserFromChannel", mockTypeReqContext, userId, "", channel).Return(notFoundErr) + + change := &model.MembershipChangeMsg{ + ChannelId: channelId, + UserId: userId, + IsAdd: false, + ChangeTime: 1000, + } + + err := scs.processMemberRemove(change, rc) + require.NoError(t, err, "should succeed for already-removed user (idempotent)") +} diff 
--git a/server/platform/services/sharedchannel/membership_send_test.go b/server/platform/services/sharedchannel/membership_send_test.go new file mode 100644 index 00000000000..9e683787c30 --- /dev/null +++ b/server/platform/services/sharedchannel/membership_send_test.go @@ -0,0 +1,362 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest/mock" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks" +) + +func setupSendTest(t *testing.T, enableMemberSync bool) (*Service, *mocks.Store, *mocks.ChannelMemberHistoryStore, *mocks.SharedChannelStore, *mocks.UserStore, *mocks.RemoteClusterStore) { + t.Helper() + + mockServer := &MockServerIface{} + logger := mlog.CreateConsoleTestLogger(t) + mockServer.On("Log").Return(logger) + mockServer.On("GetMetrics").Return(nil) + mockServer.On("GetRemoteClusterService").Return(nil) + + mockApp := &MockAppIface{} + scs := &Service{ + server: mockServer, + app: mockApp, + } + + mockStore := &mocks.Store{} + mockCMHStore := &mocks.ChannelMemberHistoryStore{} + mockSharedChannelStore := &mocks.SharedChannelStore{} + mockUserStore := &mocks.UserStore{} + mockRCStore := &mocks.RemoteClusterStore{} + + mockStore.On("ChannelMemberHistory").Return(mockCMHStore) + mockStore.On("SharedChannel").Return(mockSharedChannelStore) + mockStore.On("User").Return(mockUserStore) + mockStore.On("RemoteCluster").Return(mockRCStore) + mockServer.On("GetStore").Return(mockStore) + + mockConfig := model.Config{} + mockConfig.SetDefaults() + mockConfig.FeatureFlags.EnableSharedChannelsMemberSync = enableMemberSync + mockServer.On("Config").Return(&mockConfig) + 
+ return scs, mockStore, mockCMHStore, mockSharedChannelStore, mockUserStore, mockRCStore +} + +func TestFetchMembershipsForSync_FeatureFlagDisabled(t *testing.T) { + scs, _, _, _, _, _ := setupSendTest(t, false) + + sd := &syncData{ + task: syncTask{channelID: model.NewId()}, + rc: &model.RemoteCluster{RemoteId: model.NewId()}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.Empty(t, sd.membershipChanges, "should not fetch when feature flag disabled") +} + +func TestFetchMembershipsForSync_NoChanges(t *testing.T) { + scs, _, mockCMHStore, _, _, _ := setupSendTest(t, true) + + channelID := model.NewId() + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{}, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: model.NewId()}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.Empty(t, sd.membershipChanges) + assert.Equal(t, int64(0), sd.resultNextMembershipCursor) +} + +func TestFetchMembershipsForSync_DeduplicatesJoinLeaveRejoin(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + userID := model.NewId() + + // User joined at 1000, left at 2000, rejoined at 3000 + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). 
+ Return([]*model.ChannelMemberHistory{ + {ChannelId: channelID, UserId: userID, JoinTime: 1000, LeaveTime: model.NewPointer(int64(2000))}, + {ChannelId: channelID, UserId: userID, JoinTime: 3000, LeaveTime: nil}, + }, nil) + + // User profile for the add — local user (no RemoteId) so shouldUserSync includes it + user := &model.User{Id: userID} + mockUserStore.On("Get", mock.Anything, userID).Return(user, nil) + + // shouldUserSync needs SharedChannelUser lookup + mockSCStore.On("GetSingleUser", userID, channelID, remoteID). + Return(nil, &notFoundError{}) + mockSCStore.On("SaveUser", mock.AnythingOfType("*model.SharedChannelUser")). + Return(&model.SharedChannelUser{}, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + + // Should deduplicate to 1 entry: user is currently a member (rejoin at 3000 > leave at 2000) + require.Len(t, sd.membershipChanges, 1) + assert.Equal(t, userID, sd.membershipChanges[0].UserId) + assert.True(t, sd.membershipChanges[0].IsAdd, "user rejoined so should be IsAdd=true") + assert.Equal(t, int64(3000), sd.membershipChanges[0].ChangeTime) + assert.Equal(t, int64(3000), sd.resultNextMembershipCursor) +} + +func TestFetchMembershipsForSync_DeduplicatesJoinThenLeave(t *testing.T) { + scs, _, mockCMHStore, _, _, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + userID := model.NewId() + + // User joined at 1000, then left at 2000 + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")).
+ Return([]*model.ChannelMemberHistory{ + {ChannelId: channelID, UserId: userID, JoinTime: 1000, LeaveTime: model.NewPointer(int64(2000))}, + }, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + + require.Len(t, sd.membershipChanges, 1) + assert.False(t, sd.membershipChanges[0].IsAdd, "user left so should be IsAdd=false") + assert.Equal(t, int64(2000), sd.membershipChanges[0].ChangeTime) + assert.Equal(t, int64(2000), sd.resultNextMembershipCursor) +} + +func TestFetchMembershipsForSync_MultipleUsers(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + user1 := model.NewId() + user2 := model.NewId() + user3 := model.NewId() + + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{ + {ChannelId: channelID, UserId: user1, JoinTime: 1000, LeaveTime: nil}, // joined, still member + {ChannelId: channelID, UserId: user2, JoinTime: 2000, LeaveTime: model.NewPointer(int64(3000))}, // joined then left + {ChannelId: channelID, UserId: user3, JoinTime: 4000, LeaveTime: nil}, // joined, still member + }, nil) + + // User profiles for adds (user1 and user3) — users are local (no RemoteId) + // so shouldUserSync won't filter them out for the target remote + for _, uid := range []string{user1, user3} { + u := &model.User{Id: uid} + mockUserStore.On("Get", mock.Anything, uid).Return(u, nil) + mockSCStore.On("GetSingleUser", uid, channelID, remoteID).Return(nil, &notFoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == uid })).
+ Return(&model.SharedChannelUser{}, nil) + } + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + + assert.Len(t, sd.membershipChanges, 3) + + // Build a map for easy lookup + changeMap := make(map[string]*model.MembershipChangeMsg) + for _, mc := range sd.membershipChanges { + changeMap[mc.UserId] = mc + } + + assert.True(t, changeMap[user1].IsAdd, "user1 is still a member") + assert.False(t, changeMap[user2].IsAdd, "user2 left") + assert.True(t, changeMap[user3].IsAdd, "user3 is still a member") + + // Max cursor should be 4000 (user3 join) + assert.Equal(t, int64(4000), sd.resultNextMembershipCursor) + + // Only user1 and user3 should be in the users map (adds only) + assert.Contains(t, sd.users, user1) + assert.NotContains(t, sd.users, user2, "user2 left, should not be in users map") + assert.Contains(t, sd.users, user3) +} + +func TestFetchMembershipsForSync_SetsRepeatWhenLimitHit(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + + // Return exactly as many results as the limit (batch size) + batchSize := scs.GetMemberSyncBatchSize() + histories := make([]*model.ChannelMemberHistory, batchSize) + for i := range batchSize { + uid := model.NewId() + histories[i] = &model.ChannelMemberHistory{ + ChannelId: channelID, + UserId: uid, + JoinTime: int64(1000 + i), + LeaveTime: nil, + } + u := &model.User{Id: uid} + mockUserStore.On("Get", mock.Anything, uid).Return(u, nil) + mockSCStore.On("GetSingleUser", uid, channelID, remoteID).Return(nil, &notFoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == uid })).
+ Return(&model.SharedChannelUser{}, nil) + } + + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), batchSize). + Return(histories, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.True(t, sd.resultRepeat, "should set resultRepeat when limit is hit") +} + +func TestFetchMembershipsForSync_CursorFromSCR(t *testing.T) { + scs, _, mockCMHStore, _, _, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + + // Verify that LastMembersSyncAt from SCR is used as the cursor + mockCMHStore.On("GetMembershipChanges", channelID, int64(5000), mock.AnythingOfType("int")). + Return([]*model.ChannelMemberHistory{}, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 5000}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + mockCMHStore.AssertCalled(t, "GetMembershipChanges", channelID, int64(5000), mock.AnythingOfType("int")) +} + +func TestFetchMembershipsForSync_RepeatedTimestampsAtBoundary(t *testing.T) { + scs, _, mockCMHStore, mockSCStore, mockUserStore, _ := setupSendTest(t, true) + + channelID := model.NewId() + remoteID := model.NewId() + batchSize := scs.GetMemberSyncBatchSize() + + // Create batch where the last entries share the same JoinTime (boundary timestamp) + boundaryTime := int64(5000) + histories := make([]*model.ChannelMemberHistory, batchSize) + for i := range batchSize { + uid := model.NewId() + joinTime := int64(1000 + i) + // Last 3 entries share the same timestamp + if i >= batchSize-3 { + joinTime = boundaryTime + } + histories[i] = &model.ChannelMemberHistory{ + ChannelId: channelID, + UserId: 
uid, + JoinTime: joinTime, + LeaveTime: nil, + } + u := &model.User{Id: uid} + mockUserStore.On("Get", mock.Anything, uid).Return(u, nil) + mockSCStore.On("GetSingleUser", uid, channelID, remoteID).Return(nil, &notFoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == uid })). + Return(&model.SharedChannelUser{}, nil) + } + + // First fetch returns full batch (hits limit) + mockCMHStore.On("GetMembershipChanges", channelID, int64(0), batchSize). + Return(histories, nil) + + sd := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: 0}, + users: make(map[string]*model.User), + } + + err := scs.fetchMembershipsForSync(sd) + require.NoError(t, err) + assert.True(t, sd.resultRepeat, "should signal more data when limit hit") + assert.Len(t, sd.membershipChanges, batchSize, "all entries in batch should be processed") + assert.Equal(t, boundaryTime, sd.resultNextMembershipCursor, "cursor should be at boundary timestamp") + + // Second fetch uses the boundary timestamp as cursor (GtOrEq means events + // AT boundaryTime will be re-fetched, but deduplication handles this) + extraUser := model.NewId() + mockCMHStore.On("GetMembershipChanges", channelID, boundaryTime, batchSize). 
+ Return([]*model.ChannelMemberHistory{ + // Re-fetched rows at boundary (already processed — dedup will handle) + histories[batchSize-3], + histories[batchSize-2], + histories[batchSize-1], + // New row beyond the boundary + {ChannelId: channelID, UserId: extraUser, JoinTime: boundaryTime + 1, LeaveTime: nil}, + }, nil) + + extraUserObj := &model.User{Id: extraUser} + mockUserStore.On("Get", mock.Anything, extraUser).Return(extraUserObj, nil) + mockSCStore.On("GetSingleUser", extraUser, channelID, remoteID).Return(nil, &notFoundError{}) + mockSCStore.On("SaveUser", mock.MatchedBy(func(scu *model.SharedChannelUser) bool { return scu.UserId == extraUser })). + Return(&model.SharedChannelUser{}, nil) + + sd2 := &syncData{ + task: syncTask{channelID: channelID}, + rc: &model.RemoteCluster{RemoteId: remoteID}, + scr: &model.SharedChannelRemote{LastMembersSyncAt: boundaryTime}, + users: make(map[string]*model.User), + } + + err = scs.fetchMembershipsForSync(sd2) + require.NoError(t, err) + // Should have 4 entries (3 re-fetched + 1 new) — dedup at the caller level + // handles duplicate user IDs across batches + assert.Len(t, sd2.membershipChanges, 4, "should include re-fetched boundary rows and new row") + assert.Equal(t, boundaryTime+1, sd2.resultNextMembershipCursor, "cursor should advance past boundary") +} + +// notFoundError implements the errNotFound interface used by the sharedchannel package +type notFoundError struct{} + +func (e *notFoundError) Error() string { return "not found" } +func (e *notFoundError) IsErrNotFound() bool { return true } diff --git a/server/platform/services/sharedchannel/sync_recv.go b/server/platform/services/sharedchannel/sync_recv.go index a182eb34e2d..2ab2b90dadf 100644 --- a/server/platform/services/sharedchannel/sync_recv.go +++ b/server/platform/services/sharedchannel/sync_recv.go @@ -87,6 +87,7 @@ func (scs *Service) processSyncMessage(rctx request.CTX, syncMsg *model.SyncMsg, PostErrors: make([]string, 0), ReactionErrors: 
make([]string, 0), AcknowledgementErrors: make([]string, 0), + MembershipErrors: make([]string, 0), } // Check if feature flag is enabled for membership changes @@ -281,6 +282,7 @@ func (scs *Service) processSyncMessage(rctx request.CTX, syncMsg *model.SyncMsg, mlog.Int("change_count", len(syncMsg.MembershipChanges)), mlog.Err(err), ) + syncResp.MembershipErrors = append(syncResp.MembershipErrors, err.Error()) // Don't fail the entire sync if membership changes fail } } diff --git a/server/platform/services/sharedchannel/sync_send.go b/server/platform/services/sharedchannel/sync_send.go index 6b7496a7f60..602c40e23e2 100644 --- a/server/platform/services/sharedchannel/sync_send.go +++ b/server/platform/services/sharedchannel/sync_send.go @@ -29,6 +29,9 @@ type syncTask struct { retryCount int retryMsg *model.SyncMsg schedule time.Time + // originRemoteID is the remote that initiated this change; it will be + // skipped when syncing to prevent echo-back. + originRemoteID string } func newSyncTask(channelID, userID string, remoteID string, existingMsg, retryMsg *model.SyncMsg) syncTask { @@ -250,6 +253,17 @@ func (scs *Service) addTask(task syncTask) { // if the task was already scheduled, we only update the // existingMsg in case there is new information originalTask.existingMsg = task.existingMsg + + // originRemoteID identifies which remote initiated a change so processTask + // can skip sending back to that remote. When multiple events merge within + // the NotifyMinimumDelay window we can only safely skip a remote if every + // merged event came from that same remote. If the origins differ (e.g. + // remote-A join + remote-B join, or remote join + local join) we must clear + // originRemoteID so the sync fans out to all remotes. The receiver is + // idempotent, so the worst case is a redundant sync to the originating remote. 
+ if task.originRemoteID != originalTask.originRemoteID { + originalTask.originRemoteID = "" + } scs.tasks[task.id] = originalTask } else { scs.tasks[task.id] = task @@ -375,16 +389,6 @@ func (scs *Service) removeOldestTask() (syncTask, bool, time.Duration) { // processTask updates one or more remote clusters with any new channel content. func (scs *Service) processTask(task syncTask) error { - // Check if this is a membership change task - if task.existingMsg != nil && len(task.existingMsg.MembershipChanges) > 0 { - // Check if feature flag is enabled - if !scs.server.Config().FeatureFlags.EnableSharedChannelsMemberSync { - return nil - } - scs.processMembershipChange(task.existingMsg) - return nil - } - // map is used to ensure remotes don't get sync'd twice, such as when // they have the autoinvited flag and have explicitly subscribed to a channel. remotesMap := make(map[string]*model.RemoteCluster) @@ -399,6 +403,10 @@ func (scs *Service) processTask(task syncTask) error { return err } for _, r := range remotes { + // Skip the remote that originated this membership change + if task.originRemoteID != "" && r.RemoteId == task.originRemoteID { + continue + } remotesMap[r.RemoteId] = r } @@ -412,6 +420,10 @@ func (scs *Service) processTask(task syncTask) error { return err } for _, r := range remotesAutoInvited { + // Skip the remote that originated this membership change + if task.originRemoteID != "" && r.RemoteId == task.originRemoteID { + continue + } remotesMap[r.RemoteId] = r } } else { diff --git a/server/platform/services/sharedchannel/sync_send_remote.go b/server/platform/services/sharedchannel/sync_send_remote.go index f28479e121c..9cf04920c2d 100644 --- a/server/platform/services/sharedchannel/sync_send_remote.go +++ b/server/platform/services/sharedchannel/sync_send_remote.go @@ -41,6 +41,9 @@ type syncData struct { attachments []attachment mentionTransforms map[string]string + membershipChanges []*model.MembershipChangeMsg + resultNextMembershipCursor 
int64 + resultRepeat bool resultNextCursor model.GetPostsSinceForSyncCursor GlobalUserSyncLastTimestamp int64 @@ -62,7 +65,7 @@ func newSyncData(task syncTask, rc *model.RemoteCluster, scr *model.SharedChanne } func (sd *syncData) isEmpty() bool { - return len(sd.users) == 0 && len(sd.profileImages) == 0 && len(sd.posts) == 0 && len(sd.reactions) == 0 && len(sd.acknowledgements) == 0 && len(sd.attachments) == 0 + return len(sd.users) == 0 && len(sd.profileImages) == 0 && len(sd.posts) == 0 && len(sd.reactions) == 0 && len(sd.acknowledgements) == 0 && len(sd.attachments) == 0 && len(sd.membershipChanges) == 0 } func (sd *syncData) isCursorChanged() bool { @@ -171,6 +174,11 @@ func (scs *Service) syncForRemote(task syncTask, rc *model.RemoteCluster) error return fmt.Errorf("cannot fetch posts for sync %v: %w", sd, err) } + // fetch membership changes from ChannelMemberHistory + if err := scs.fetchMembershipsForSync(sd); err != nil { + return fmt.Errorf("cannot fetch memberships for sync %v: %w", sd, err) + } + if !rc.IsOnline() { if len(sd.posts) != 0 { scs.notifyRemoteOffline(sd.posts, rc) @@ -342,6 +350,158 @@ func (scs *Service) fetchPostsForSync(sd *syncData) error { return nil } +// fetchMembershipsForSync populates the sync data with membership changes from ChannelMemberHistory. 
+func (scs *Service) fetchMembershipsForSync(sd *syncData) error { + if !scs.server.Config().FeatureFlags.EnableSharedChannelsMemberSync { + return nil + } + + start := time.Now() + defer func() { + if metrics := scs.server.GetMetrics(); metrics != nil { + metrics.ObserveSharedChannelsSyncCollectionStepDuration(sd.rc.RemoteId, "Memberships", time.Since(start).Seconds()) + } + }() + + cursor := sd.scr.LastMembersSyncAt + limit := scs.GetMemberSyncBatchSize() + + histories, err := scs.server.GetStore().ChannelMemberHistory().GetMembershipChanges(sd.task.channelID, cursor, limit) + if err != nil { + return fmt.Errorf("could not fetch membership changes for sync: %w", err) + } + + if len(histories) == 0 { + return nil + } + + // Deduplicate by user — for users with multiple events (join/leave/rejoin cycles), + // resolve to the most recent event to determine current state. + type userState struct { + isAdd bool + eventTime int64 // max of JoinTime/LeaveTime for this user + } + byUser := make(map[string]*userState) + + var maxCursor int64 + for _, h := range histories { + // Track the max event time for the next cursor + if h.JoinTime > maxCursor { + maxCursor = h.JoinTime + } + if h.LeaveTime != nil && *h.LeaveTime > maxCursor { + maxCursor = *h.LeaveTime + } + + // Determine this event's effective time and whether it represents a join or leave + var eventTime int64 + var isAdd bool + if h.LeaveTime == nil || h.JoinTime > *h.LeaveTime { + // User is currently a member (no leave, or rejoined after leaving) + isAdd = true + eventTime = h.JoinTime + } else { + // User left + isAdd = false + eventTime = *h.LeaveTime + } + + // Keep only the most recent state per user + if existing, ok := byUser[h.UserId]; !ok || eventTime > existing.eventTime { + byUser[h.UserId] = &userState{isAdd: isAdd, eventTime: eventTime} + } + } + + // Build MembershipChangeMsg entries + for userID, state := range byUser { + if state.isAdd { + user, userErr := 
scs.server.GetStore().User().Get(context.Background(), userID) + if userErr != nil { + scs.server.Log().Log(mlog.LvlSharedChannelServiceWarn, "Failed to get user for membership sync", + mlog.String("user_id", userID), + mlog.String("channel_id", sd.task.channelID), + mlog.Err(userErr), + ) + continue + } + + sd.membershipChanges = append(sd.membershipChanges, &model.MembershipChangeMsg{ + ChannelId: sd.task.channelID, + UserId: userID, + IsAdd: true, + ChangeTime: state.eventTime, + }) + + doSync, _, syncErr := scs.shouldUserSync(user, sd.task.channelID, sd.rc) + if syncErr == nil && doSync { + sd.users[user.Id] = user + } + } else { + sd.membershipChanges = append(sd.membershipChanges, &model.MembershipChangeMsg{ + ChannelId: sd.task.channelID, + UserId: userID, + IsAdd: false, + ChangeTime: state.eventTime, + }) + } + } + + sd.resultNextMembershipCursor = maxCursor + + // If we hit the limit, there may be more data to fetch + if len(histories) >= limit { + sd.resultRepeat = true + } + + return nil +} + +// sendMembershipSyncData sends the collected membership changes to the remote cluster. +func (scs *Service) sendMembershipSyncData(sd *syncData) error { + start := time.Now() + defer func() { + if metrics := scs.server.GetMetrics(); metrics != nil { + metrics.ObserveSharedChannelsSyncSendStepDuration(sd.rc.RemoteId, "Memberships", time.Since(start).Seconds()) + } + }() + + msg := model.NewSyncMsg(sd.task.channelID) + msg.MembershipChanges = sd.membershipChanges + + // Include only user profiles for users referenced by membership adds, + // not the full sd.users map which may contain post authors and other users. 
+ memberUsers := make(map[string]*model.User) + for _, mc := range sd.membershipChanges { + if mc.IsAdd { + if u, ok := sd.users[mc.UserId]; ok { + memberUsers[mc.UserId] = u + } + } + } + msg.Users = memberUsers + + return scs.sendSyncMsgToRemote(msg, sd.rc, func(syncResp model.SyncResponse, errResp error) { + if len(syncResp.MembershipErrors) != 0 { + scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Response indicates error for membership(s) sync", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Array("membership_errors", syncResp.MembershipErrors), + ) + } + + // Update the membership cursor on success + if errResp == nil && sd.resultNextMembershipCursor > 0 { + if err := scs.updateMembershipSyncCursor(sd.task.channelID, sd.rc.RemoteId, sd.resultNextMembershipCursor); err != nil { + scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Failed to update membership sync cursor", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Err(err), + ) + } + } + }) +} + func appendPosts(dest []*model.Post, posts []*model.Post, postStore store.PostStore, timestamp int64, logger mlog.LoggerIFace) []*model.Post { // Append the posts individually, checking for root posts that might appear later in the list. // This is due to the UpdateAt collision handling algorithm where the order of posts is not based @@ -576,7 +736,7 @@ func (scs *Service) filterPostsForSync(sd *syncData) { // sendSyncData sends all the collected users, posts, reactions, images, and attachments to the // remote cluster. 
-// The order of items sent is important: users -> attachments -> posts -> reactions -> profile images +// The order of items sent is important: users -> memberships -> attachments -> posts -> reactions -> profile images func (scs *Service) sendSyncData(sd *syncData) error { start := time.Now() defer func() { @@ -595,6 +755,13 @@ func (scs *Service) sendSyncData(sd *syncData) error { } } + // send membership changes (after users so the remote has profiles for added members) + if len(sd.membershipChanges) != 0 { + if err := scs.sendMembershipSyncData(sd); err != nil { + merr.Append(fmt.Errorf("cannot send membership sync data: %w", err)) + } + } + // send attachments if len(sd.attachments) != 0 { scs.sendAttachmentSyncData(sd) diff --git a/server/platform/services/sharedchannel/sync_send_test.go b/server/platform/services/sharedchannel/sync_send_test.go new file mode 100644 index 00000000000..76016881faf --- /dev/null +++ b/server/platform/services/sharedchannel/sync_send_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. 
+ +package sharedchannel + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestService() *Service { + return &Service{ + changeSignal: make(chan struct{}, 1), + tasks: make(map[string]syncTask), + } +} + +func TestAddTask_OriginRemoteIDMerge(t *testing.T) { + tests := []struct { + name string + firstOrigin string + secondOrigin string + expectedOrigin string + }{ + { + name: "same remote origin is preserved", + firstOrigin: "remote-A", + secondOrigin: "remote-A", + expectedOrigin: "remote-A", + }, + { + name: "local then remote clears origin", + firstOrigin: "", + secondOrigin: "remote-A", + expectedOrigin: "", + }, + { + name: "remote then local clears origin", + firstOrigin: "remote-A", + secondOrigin: "", + expectedOrigin: "", + }, + { + name: "different remotes clears origin", + firstOrigin: "remote-A", + secondOrigin: "remote-B", + expectedOrigin: "", + }, + { + name: "both local stays empty", + firstOrigin: "", + secondOrigin: "", + expectedOrigin: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + scs := newTestService() + channelID := "channel-1" + + first := newSyncTask(channelID, "", "", nil, nil) + first.originRemoteID = tc.firstOrigin + scs.addTask(first) + + second := newSyncTask(channelID, "", "", nil, nil) + second.originRemoteID = tc.secondOrigin + scs.addTask(second) + + merged, ok := scs.tasks[first.id] + require.True(t, ok, "task should exist") + assert.Equal(t, tc.expectedOrigin, merged.originRemoteID) + }) + } +} diff --git a/server/platform/services/sharedchannel/util.go b/server/platform/services/sharedchannel/util.go index 07f7cf16713..1c027462af3 100644 --- a/server/platform/services/sharedchannel/util.go +++ b/server/platform/services/sharedchannel/util.go @@ -48,12 +48,6 @@ func sanitizeUserForSync(user *model.User) *model.User { return user } -func sanitizeUserForSyncSafe(user *model.User) *model.User { - // Create a copy to avoid 
modifying the original user object - userCopy := *user - return sanitizeUserForSync(&userCopy) -} - const MungUsernameSeparator = "-" // mungUsername creates a new username by combining username and remote cluster name, plus diff --git a/server/public/model/shared_channel.go b/server/public/model/shared_channel.go index cccba8ea013..f81392f9918 100644 --- a/server/public/model/shared_channel.go +++ b/server/public/model/shared_channel.go @@ -172,13 +172,12 @@ type SharedChannelRemoteStatus struct { // SharedChannelUser stores a lastSyncAt timestamp on behalf of a remote cluster for // each user that has been synchronized. type SharedChannelUser struct { - Id string `json:"id"` - UserId string `json:"user_id"` - ChannelId string `json:"channel_id"` - RemoteId string `json:"remote_id"` - CreateAt int64 `json:"create_at"` - LastSyncAt int64 `json:"last_sync_at"` - LastMembershipSyncAt int64 `json:"last_membership_sync_at"` + Id string `json:"id"` + UserId string `json:"user_id"` + ChannelId string `json:"channel_id"` + RemoteId string `json:"remote_id"` + CreateAt int64 `json:"create_at"` + LastSyncAt int64 `json:"last_sync_at"` } func (scu *SharedChannelUser) PreSave() { @@ -336,6 +335,8 @@ type SyncResponse struct { AcknowledgementErrors []string `json:"acknowledgement_errors"` StatusErrors []string `json:"status_errors"` // user IDs for which the status sync failed + + MembershipErrors []string `json:"membership_errors,omitempty"` } // RegisterPluginOpts is passed by plugins to the `RegisterPluginForSharedChannels` plugin API