diff --git a/server/channels/api4/websocket_test.go b/server/channels/api4/websocket_test.go index 13ee33a28d0..ecd9bb38cdf 100644 --- a/server/channels/api4/websocket_test.go +++ b/server/channels/api4/websocket_test.go @@ -441,7 +441,12 @@ func TestWebSocketPresence(t *testing.T) { require.Nil(t, resp.Error) require.Equal(t, resp.SeqReply, wsClient.Sequence-1, "bad sequence number") - wsClient.UpdateActiveThread("threadID") + wsClient.UpdateActiveThread(true, "threadID") + resp = <-wsClient.ResponseChannel + require.Nil(t, resp.Error) + require.Equal(t, resp.SeqReply, wsClient.Sequence-1, "bad sequence number") + + wsClient.UpdateActiveThread(false, "threadID") resp = <-wsClient.ResponseChannel require.Nil(t, resp.Error) require.Equal(t, resp.SeqReply, wsClient.Sequence-1, "bad sequence number") diff --git a/server/channels/app/platform/web_conn.go b/server/channels/app/platform/web_conn.go index 179614a2528..bbb62205e37 100644 --- a/server/channels/app/platform/web_conn.go +++ b/server/channels/app/platform/web_conn.go @@ -105,16 +105,17 @@ type WebConn struct { // a reused connection. // It's theoretically possible for this number to wrap around. But we // leave that as an edge-case. - reuseCount int - sessionToken atomic.Value - session atomic.Pointer[model.Session] - connectionID atomic.Value - activeChannelID atomic.Value - activeTeamID atomic.Value - activeThreadChannelID atomic.Value - endWritePump chan struct{} - pumpFinished chan struct{} - pluginPosted chan pluginWSPostedHook + reuseCount int + sessionToken atomic.Value + session atomic.Pointer[model.Session] + connectionID atomic.Value + activeChannelID atomic.Value + activeTeamID atomic.Value + activeRHSThreadChannelID atomic.Value + activeThreadViewThreadChannelID atomic.Value + endWritePump chan struct{} + pumpFinished chan struct{} + pluginPosted chan pluginWSPostedHook // These counters are to suppress spammy websocket.slow // and websocket.full logs which happen continuously, if they @@ -237,7 +238,8 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn wc.SetConnectionID(cfg.ConnectionID) wc.SetActiveChannelID("") wc.SetActiveTeamID("") - wc.SetActiveThreadChannelID("") + wc.SetActiveRHSThreadChannelID("") + wc.SetActiveThreadViewThreadChannelID("") ps.Go(func() { runner.RunMultiHook(func(hooks plugin.Hooks) bool { @@ -316,14 +318,24 @@ func (wc *WebConn) GetActiveTeamID() string { return wc.activeTeamID.Load().(string) } -// GetActiveThreadChannelID returns the channel id of the active thread of the connection. -func (wc *WebConn) GetActiveThreadChannelID() string { - return wc.activeThreadChannelID.Load().(string) +// GetActiveRHSThreadChannelID returns the channel id of the active thread of the connection. +func (wc *WebConn) GetActiveRHSThreadChannelID() string { + return wc.activeRHSThreadChannelID.Load().(string) } -// SetActiveThreadChannelID sets the channel id of the active thread of the connection. -func (wc *WebConn) SetActiveThreadChannelID(id string) { - wc.activeThreadChannelID.Store(id) +// SetActiveRHSThreadChannelID sets the channel id of the active thread of the connection. +func (wc *WebConn) SetActiveRHSThreadChannelID(id string) { + wc.activeRHSThreadChannelID.Store(id) +} + +// GetActiveThreadViewThreadChannelID returns the channel id of the active thread of the connection. +func (wc *WebConn) GetActiveThreadViewThreadChannelID() string { + return wc.activeThreadViewThreadChannelID.Load().(string) +} + +// SetActiveThreadViewThreadChannelID sets the channel id of the active thread of the connection. +func (wc *WebConn) SetActiveThreadViewThreadChannelID(id string) { + wc.activeThreadViewThreadChannelID.Store(id) } // areAllInactive returns whether all of the connections diff --git a/server/channels/app/platform/websocket_router.go b/server/channels/app/platform/websocket_router.go index c9abfd57b65..2a5482d5305 100644 --- a/server/channels/app/platform/websocket_router.go +++ b/server/channels/app/platform/websocket_router.go @@ -84,7 +84,11 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque } if thChannelID, ok := r.Data["thread_channel_id"].(string); ok { // Set the channelID of the active thread. - conn.SetActiveThreadChannelID(thChannelID) + if isThreadView, ok := r.Data["is_thread_view"].(bool); ok && isThreadView { + conn.SetActiveThreadViewThreadChannelID(thChannelID) + } else { + conn.SetActiveRHSThreadChannelID(thChannelID) + } } resp := model.NewWebSocketResponse(model.StatusOk, r.Seq, nil) diff --git a/server/public/model/websocket_client.go b/server/public/model/websocket_client.go index 8a3319beae7..559218e2a5c 100644 --- a/server/public/model/websocket_client.go +++ b/server/public/model/websocket_client.go @@ -343,9 +343,10 @@ func (wsc *WebSocketClient) UpdateActiveTeam(teamID string) { } // UpdateActiveThread sets the channel id of the current thread that the user is in. -func (wsc *WebSocketClient) UpdateActiveThread(channelID string) { +func (wsc *WebSocketClient) UpdateActiveThread(isThreadView bool, channelID string) { data := map[string]any{ "thread_channel_id": channelID, + "is_thread_view": isThreadView, } wsc.SendMessage(string(WebsocketPresenceIndicator), data) } diff --git a/webapp/channels/src/components/threading/thread_viewer/thread_viewer.test.tsx b/webapp/channels/src/components/threading/thread_viewer/thread_viewer.test.tsx index ff9eb9713a0..b39c9ba970e 100644 --- a/webapp/channels/src/components/threading/thread_viewer/thread_viewer.test.tsx +++ b/webapp/channels/src/components/threading/thread_viewer/thread_viewer.test.tsx @@ -67,6 +67,7 @@ describe('components/threading/ThreadViewer', () => { postIds: [post.id], appsEnabled: true, rootPostId: post.id, + isThreadView: true, }; test('should match snapshot', async () => { diff --git a/webapp/channels/src/components/threading/thread_viewer/thread_viewer.tsx b/webapp/channels/src/components/threading/thread_viewer/thread_viewer.tsx index d427b8e9792..bed15968dd8 100644 --- a/webapp/channels/src/components/threading/thread_viewer/thread_viewer.tsx +++ b/webapp/channels/src/components/threading/thread_viewer/thread_viewer.tsx @@ -49,7 +49,7 @@ export type Props = Attrs & { postIds: string[]; highlightedPostId?: Post['id']; selectedPostFocusedAt?: number; - isThreadView?: boolean; + isThreadView: boolean; inputPlaceholder?: string; rootPostId: string; fromSuppressed?: boolean; @@ -81,7 +81,7 @@ export default class ThreadViewer extends React.PureComponent { } public componentWillUnmount() { - WebSocketClient.updateActiveThread(''); + WebSocketClient.updateActiveThread(this.props.isThreadView, ''); } public componentDidUpdate(prevProps: Props) { @@ -184,7 +184,7 @@ export default class ThreadViewer extends React.PureComponent { } if (this.props.channel) { - WebSocketClient.updateActiveThread(this.props.channel?.id); + WebSocketClient.updateActiveThread(this.props.isThreadView, this.props.channel?.id); } this.setState({isLoading: false}); }; diff --git a/webapp/platform/client/src/websocket.ts b/webapp/platform/client/src/websocket.ts index 0fb567b6e4a..bb8c6ebd8fb 100644 --- a/webapp/platform/client/src/websocket.ts +++ b/webapp/platform/client/src/websocket.ts @@ -396,9 +396,10 @@ export default class WebSocketClient { this.sendMessage('presence', data, callback); } - updateActiveThread(channelId: string, callback?: (msg: any) => void) { + updateActiveThread(isThreadView: boolean, channelId: string, callback?: (msg: any) => void) { const data = { thread_channel_id: channelId, + is_thread_view: isThreadView, }; this.sendMessage('presence', data, callback); }