MM-61904: Make reliable websockets work in HA (#29489)

We do a cluster request to get the active and dead queues
from other nodes in the cluster to sync any missing
information.

We check the dead queue in the other nodes to see
if there's been any message loss or not. Accordingly,
we send just the active queue or both active and dead queues.

There's still an edge case that is left out: a client
could have connected and reconnected to multiple nodes,
leaving an active queue on each of them. We don't handle
this scenario because we would potentially need to create
a slice of sendQueueSize * number_of_nodes. And then
this can happen again, leading to unbounded growth
in sendQueueSize.

We leave this edge-case to Redis, acknowledging
a limitation in our architecture.

In this PR, when there's no message loss, we just
take the active queue from the last node it connected
to.

And if there's message loss where the client's
seqNum is within the last node's dead queue, we also
handle that.

But if there's severe message loss where the client's
seqNum falls within the dead queue of another node, then
we just send the data from that node to reconstruct the
data as much as possible. It would be possible to set
a new connection ID in this case, but that would always
require transferring more data from all nodes and
recomputing the state in the requesting node.

https://mattermost.atlassian.net/browse/MM-61904

```release-note
NONE
```

Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Agniva De Sarker 2025-01-17 11:11:32 +05:30 committed by GitHub
parent 921604cf39
commit cb75a20c54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 404 additions and 46 deletions

View file

@ -486,7 +486,6 @@ type AppIface interface {
CheckUserMfa(rctx request.CTX, user *model.User, token string) *model.AppError
CheckUserPostflightAuthenticationCriteria(rctx request.CTX, user *model.User) *model.AppError
CheckUserPreflightAuthenticationCriteria(rctx request.CTX, user *model.User, mfaToken string) *model.AppError
CheckWebConn(userID, connectionID string) *platform.CheckConnResult
CleanUpAfterPostDeletion(c request.CTX, post *model.Post, deleteByID string) *model.AppError
CleanupReportChunks(format string, prefix string, numberOfChunks int) *model.AppError
ClearChannelMembersCache(c request.CTX, channelID string) error

View file

@ -162,3 +162,6 @@ func (c *ClusterMock) HealthScore() int { return 0 }
// WebConnCountForUser is a stub: it ignores userID and always reports zero
// connections with a nil error.
func (c *ClusterMock) WebConnCountForUser(userID string) (int, *model.AppError) {
return 0, nil
}
// GetWSQueues is a stub: it ignores its arguments and always returns a nil
// map together with a nil error.
func (c *ClusterMock) GetWSQueues(userID, connectionID string, seqNum int64) (map[string]*model.WSQueues, error) {
return nil, nil
}

View file

@ -1438,23 +1438,6 @@ func (a *OpenTracingAppLayer) CheckUserPreflightAuthenticationCriteria(rctx requ
return resultVar0
}
// CheckWebConn calls the wrapped app's CheckWebConn inside a tracing span
// named "app.CheckWebConn" and returns its result unchanged.
func (a *OpenTracingAppLayer) CheckWebConn(userID string, connectionID string) *platform.CheckConnResult {
// Keep the current context so it can be restored once the call returns.
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CheckWebConn")
// Point both the layer and the store at the span's context for the duration of the call.
a.ctx = newCtx
a.app.Srv().Store().SetContext(newCtx)
defer func() {
a.app.Srv().Store().SetContext(origCtx)
a.ctx = origCtx
}()
// Deferred calls run last-in-first-out, so the span is finished before the contexts are restored.
defer span.Finish()
resultVar0 := a.app.CheckWebConn(userID, connectionID)
return resultVar0
}
func (a *OpenTracingAppLayer) CleanUpAfterPostDeletion(c request.CTX, post *model.Post, deleteByID string) *model.AppError {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CleanUpAfterPostDeletion")

View file

@ -164,3 +164,6 @@ func (c *ClusterMock) HealthScore() int { return 0 }
// WebConnCountForUser is a stub: it ignores userID and always reports zero
// connections with a nil error.
func (c *ClusterMock) WebConnCountForUser(userID string) (int, *model.AppError) {
return 0, nil
}
// GetWSQueues is a stub: it ignores its arguments and always returns a nil
// map together with a nil error.
func (c *ClusterMock) GetWSQueues(userID, connectionID string, seqNum int64) (map[string]*model.WSQueues, error) {
return nil, nil
}

View file

@ -74,7 +74,7 @@ type WebConnConfig struct {
XForwardedFor string
// These aren't necessary to be exported to api layer.
sequence int
sequence int64
activeQueue chan model.WebSocketMessage
deadQueue []*model.WebSocketEvent
deadQueuePointer int
@ -161,9 +161,20 @@ func (ps *PlatformService) PopulateWebConnConfig(s *model.Session, cfg *WebConnC
return nil, fmt.Errorf("invalid connection id: %s", cfg.ConnectionID)
}
// Sequence_number must be sent with connection id.
// A client must be either non-compliant or fully compliant.
if seqVal == "" {
return nil, errors.New("sequence number not present in websocket request")
}
seqNum, err := strconv.ParseInt(seqVal, 10, 0)
if err != nil {
return nil, fmt.Errorf("invalid sequence number %s in query param: %w", seqVal, err)
}
// This does not handle reconnect requests across nodes in a cluster.
// It falls back to the non-reliable case in that scenario.
res := ps.CheckWebConn(s.UserId, cfg.ConnectionID)
res := ps.CheckWebConn(s.UserId, cfg.ConnectionID, seqNum)
if res == nil {
// If the connection is not present, then we assume either timeout,
// or server restart. In that case, we set a new one.
@ -175,17 +186,7 @@ func (ps *PlatformService) PopulateWebConnConfig(s *model.Session, cfg *WebConnC
cfg.deadQueuePointer = res.DeadQueuePointer
cfg.Active = false
cfg.ReuseCount = res.ReuseCount
// Now we get the sequence number
if seqVal == "" {
// Sequence_number must be sent with connection id.
// A client must be either non-compliant or fully compliant.
return nil, errors.New("sequence number not present in websocket request")
}
var err error
cfg.sequence, err = strconv.Atoi(seqVal)
if err != nil || cfg.sequence < 0 {
return nil, fmt.Errorf("invalid sequence number %s in query param: %v", seqVal, err)
}
cfg.sequence = seqNum
}
return cfg, nil
}
@ -235,7 +236,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn
send: cfg.activeQueue,
deadQueue: cfg.deadQueue,
deadQueuePointer: cfg.deadQueuePointer,
Sequence: int64(cfg.sequence),
Sequence: cfg.sequence,
WebSocket: cfg.WebSocket,
lastUserActivityAt: model.GetMillis(),
UserId: cfg.Session.UserId,
@ -655,34 +656,48 @@ func (wc *WebConn) addToDeadQueue(msg *model.WebSocketEvent) {
// hasMsgLoss reports whether messages were lost for this connection.
// It delegates to _hasMsgLoss, passing the connection's dead queue, its dead
// queue pointer and the next wanted sequence number (wc.Sequence). When the
// latest element in the dead queue is immediately followed by the wanted
// sequence, there is no message loss.
func (wc *WebConn) hasMsgLoss() bool {
return _hasMsgLoss(wc.deadQueue, wc.deadQueuePointer, wc.Sequence)
}
// isInDeadQueue reports whether the given sequence number is present in this
// connection's dead queue and, if it is, also returns its index.
// It delegates to _isInDeadQueue with the connection's dead queue.
func (wc *WebConn) isInDeadQueue(seq int64) (bool, int) {
return _isInDeadQueue(wc.deadQueue, seq)
}
// _hasMsgLoss is called from 2 places: wc.hasMsgLoss and ps.GetWSQueues.
// It is done this way because it is difficult to call wc.hasMsgLoss from inside
// ps.GetWSQueues
func _hasMsgLoss(deadQueue []*model.WebSocketEvent, deadQueuePtr int, seq int64) bool {
var index int
// deadQueuePointer = 0 means either no msg written or the pointer
// has rolled over to its starting position.
if wc.deadQueuePointer == 0 {
// If last entry is nil, it means no msg is written.
if wc.deadQueue[deadQueueSize-1] == nil {
if deadQueuePtr == 0 {
// If first entry is nil, it means no msg is written.
if deadQueue[0] == nil {
return false
}
// If it's not nil, that means it has rolled over to start, and we
// check the last position.
index = deadQueueSize - 1
} else { // deadQueuePointer != 0 means it's somewhere in the middle.
index = wc.deadQueuePointer - 1
index = deadQueuePtr - 1
}
if wc.deadQueue[index].GetSequence() == wc.Sequence-1 {
if deadQueue[index].GetSequence() == seq-1 {
return false
}
return true
}
// isInDeadQueue checks whether a given sequence number is in the dead queue or not.
// And if it is, it returns that index.
func (wc *WebConn) isInDeadQueue(seq int64) (bool, int) {
// _isInDeadQueue is called from 2 places: wc.isInDeadQueue and ps.GetWSQueues.
// It is done this way because it is difficult to call wc.isInDeadQueue from inside
// ps.GetWSQueues
func _isInDeadQueue(deadQueue []*model.WebSocketEvent, seq int64) (bool, int) {
// Can be optimized to traverse backwards from deadQueuePointer
// Hopefully, traversing 128 elements is not too much overhead.
for i := 0; i < deadQueueSize; i++ {
elem := wc.deadQueue[i]
elem := deadQueue[i]
if elem == nil {
return false, 0
}

View file

@ -254,7 +254,84 @@ func (ps *PlatformService) SessionIsRegistered(session model.Session) bool {
return false
}
func (ps *PlatformService) CheckWebConn(userID, connectionID string) *CheckConnResult {
func (ps *PlatformService) CheckWebConn(userID, connectionID string, seqNum int64) *CheckConnResult {
if ps.Cluster() == nil || seqNum == 0 {
hub := ps.GetHubForUserId(userID)
if hub != nil {
return hub.CheckConn(userID, connectionID)
}
return nil
}
// We need some extra care for HA
// Check other nodes
// If any nodes return with an aq and/or dq, use that.
// If all nodes return empty, proceed with local case.
// We have to do this because a client might reconnect with an older seq num to a node
// which it had connected before. So checking its local queue will lead the server to believe
// that there is no msg loss, whereas there is actually loss.
queueMap, err := ps.Cluster().GetWSQueues(userID, connectionID, seqNum)
if err != nil {
// If there is an error we do not have enough data to say anything reliably.
// Fall back to unreliable case.
ps.Log().Error("Error while getting websocket queues",
mlog.String("connection_id", connectionID),
mlog.String("user_id", userID),
mlog.Int("sequence_number", seqNum),
mlog.Err(err))
return nil
}
connRes := &CheckConnResult{
ConnectionID: connectionID,
UserID: userID,
}
for _, queues := range queueMap {
if queues == nil || queues.ActiveQ == nil {
continue
}
// parse the activeq
aq := make(chan model.WebSocketMessage, sendQueueSize)
for _, aqItem := range queues.ActiveQ {
item, err := ps.UnmarshalAQItem(aqItem)
if err != nil {
ps.Log().Error("Error while unmarshalling websocket message from active queue",
mlog.String("connection_id", connectionID),
mlog.String("user_id", userID),
mlog.Err(err))
return nil
}
// This cannot block because all send queues are of sendQueueSize at max.
// TODO: There could be a case where there's severe message loss, and to
// reliably get the messages, we need to get send queues from multiple nodes.
// We leave that case for Redis.
aq <- item
}
connRes.ActiveQueue = aq
connRes.ReuseCount = queues.ReuseCount
// parse the dq, wc.addToDeadQ()
if queues.DeadQ != nil {
dq, dqPtr, err := ps.UnmarshalDQ(queues.DeadQ)
if err != nil {
ps.Log().Error("Error while unmarshalling websocket message from dead queue",
mlog.String("connection_id", connectionID),
mlog.String("user_id", userID),
mlog.Err(err))
return nil
}
if dqPtr > 0 {
connRes.DeadQueue = dq
connRes.DeadQueuePointer = dqPtr
}
}
return connRes
}
// Now we check local queue
hub := ps.GetHubForUserId(userID)
if hub != nil {
return hub.CheckConn(userID, connectionID)

View file

@ -0,0 +1,162 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package platform
import (
"bytes"
"encoding/json"
"fmt"
"github.com/mattermost/mattermost/server/public/model"
)
// GetWSQueues returns the websocket queues this node holds for the given user
// and connection, selected according to the sequence number the client asks
// to resume from.
//
// The result is:
//   - nil, nil when there is no hub or no connection for the pair, when nothing
//     has been written to the connection's dead queue, or when seqNum matches
//     neither the tail of the dead queue nor any entry inside it;
//   - the active queue only, when the latest dead queue entry is seqNum-1
//     (no message loss);
//   - the active queue plus the dead queue entries from seqNum onwards, when
//     seqNum is found elsewhere in the dead queue.
//
// A non-nil error is returned only when marshaling one of the queues fails.
// Note that the active queue channel is closed before it is drained.
func (ps *PlatformService) GetWSQueues(userID, connectionID string, seqNum int64) (*model.WSQueues, error) {
	hub := ps.GetHubForUserId(userID)
	if hub == nil {
		return nil, nil
	}

	connRes := hub.CheckConn(userID, connectionID)
	if connRes == nil {
		return nil, nil
	}

	aq := connRes.ActiveQueue
	dq := connRes.DeadQueue
	dqPtr := connRes.DeadQueuePointer

	// Nothing was written on this server. Early return.
	// The length check mirrors marshalDQ and keeps a nil or empty dead queue
	// from panicking on dq[0].
	if len(dq) == 0 || dq[0] == nil {
		return nil, nil
	}

	// Check if seq_num-1 == last value in the dead queue.
	perfectMatch := !_hasMsgLoss(dq, dqPtr, seqNum)

	// Otherwise, check if seq_num is somewhere else in the dead queue.
	dqIndex := 0
	if !perfectMatch {
		found, index := _isInDeadQueue(dq, seqNum)
		if !found {
			// Nothing matched.
			return nil, nil
		}
		dqIndex = index
	}

	// Closing the channel lets marshalAQ's range loop terminate once the
	// buffered messages are drained.
	close(aq)
	aqSlice, err := ps.marshalAQ(aq, connectionID, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get from active queue: %w", err)
	}

	queues := &model.WSQueues{
		ActiveQ:    aqSlice,
		ReuseCount: connRes.ReuseCount,
	}
	if perfectMatch {
		// send only aq
		return queues, nil
	}

	dqSlice, err := ps.marshalDQ(dq, dqIndex, dqPtr)
	if err != nil {
		return nil, fmt.Errorf("failed to get from dead queue: %w", err)
	}

	// send aq + drainedDq.
	queues.DeadQ = dqSlice
	return queues, nil
}
// marshalAQ receives from aq until the channel is closed and converts every
// message into a model.ActiveQueueItem holding its JSON encoding and a type
// tag: model.WebSocketMsgTypeEvent for *model.WebSocketEvent values and
// model.WebSocketMsgTypeResponse for everything else.
//
// connID and userID are only used to annotate the error returned when a
// message fails to encode. The caller must close aq, otherwise this blocks.
func (ps *PlatformService) marshalAQ(aq <-chan model.WebSocketMessage, connID, userID string) ([]model.ActiveQueueItem, error) {
	// Start from an empty, non-nil slice, exactly like the original.
	// NOTE(review): presumably this keeps an empty queue distinguishable from
	// a missing one on the receiving side — confirm against the cluster transport.
	items := []model.ActiveQueueItem{}

	for {
		msg, open := <-aq
		if !open {
			return items, nil
		}

		kind := model.WebSocketMsgTypeResponse
		switch msg.(type) {
		case *model.WebSocketEvent:
			kind = model.WebSocketMsgTypeEvent
		}

		encoded, err := msg.ToJSON()
		if err != nil {
			return nil, fmt.Errorf("failed to marshal websocket event: %w, connection_id=%s, user_id=%s", err, connID, userID)
		}

		items = append(items, model.ActiveQueueItem{
			Type: kind,
			Buf:  json.RawMessage(encoded),
		})
	}
}
// UnmarshalAQItem decodes a serialized active queue item back into a
// websocket message. The item's Type selects the decoder: events are parsed
// with model.WebSocketEventFromJSON and responses with
// model.WebSocketResponseFromJSON. Any other Type yields an error.
func (ps *PlatformService) UnmarshalAQItem(aqItem model.ActiveQueueItem) (model.WebSocketMessage, error) {
	switch aqItem.Type {
	case model.WebSocketMsgTypeEvent:
		return model.WebSocketEventFromJSON(bytes.NewReader(aqItem.Buf))
	case model.WebSocketMsgTypeResponse:
		return model.WebSocketResponseFromJSON(bytes.NewReader(aqItem.Buf))
	default:
		return nil, fmt.Errorf("unknown websocket message type: %q", aqItem.Type)
	}
}
// marshalDQ encodes dead queue entries, starting at index, into a slice of
// raw JSON messages. It follows the same traversal as drainDeadQueue, but
// collects the encoded messages instead of writing them to the network.
// To be refactored into a single method.
//
// dqPtr is the queue's write pointer. If the slot at dqPtr is still nil, the
// queue has not rolled over and entries [index, dqPtr) are encoded. Otherwise
// entries are encoded circularly from index until the sequence number drops,
// which marks the rollover point. An empty or never-written queue yields nil.
func (ps *PlatformService) marshalDQ(dq []*model.WebSocketEvent, index, dqPtr int) ([]json.RawMessage, error) {
	if len(dq) == 0 || dq[0] == nil {
		return nil, nil
	}

	var scratch bytes.Buffer
	encoder := json.NewEncoder(&scratch)
	out := make([]json.RawMessage, 0)

	// appendEncoded encodes one event through the shared scratch buffer and
	// stores a copy of the bytes, since the buffer is reused for the next event.
	appendEncoded := func(ev *model.WebSocketEvent) error {
		scratch.Reset()
		if err := ev.Encode(encoder, &scratch); err != nil {
			return fmt.Errorf("error in encoding websocket message in dead queue: %w", err)
		}
		out = append(out, bytes.Clone(scratch.Bytes()))
		return nil
	}

	// The write pointer still points at an empty slot: no rollover yet, so
	// everything of interest lies in [index, dqPtr).
	if dq[dqPtr] == nil {
		for pos := index; pos < dqPtr; pos++ {
			if err := appendEncoded(dq[pos]); err != nil {
				return nil, err
			}
		}
		return out, nil
	}

	// Rolled over: walk the ring until the next sequence number is smaller
	// than the one just emitted.
	pos := index
	for {
		if err := appendEncoded(dq[pos]); err != nil {
			return nil, err
		}

		emittedSeq := dq[pos].GetSequence()
		pos = (pos + 1) % deadQueueSize
		if emittedSeq > dq[pos].GetSequence() {
			break
		}
	}
	return out, nil
}
// UnmarshalDQ rebuilds a dead queue from its serialized form.
//
// It returns a slice of length deadQueueSize whose first len(buf) slots hold
// the decoded events in order, together with the number of decoded events,
// which callers use as the dead queue pointer.
//
// An error is returned if buf holds more than deadQueueSize items or if any
// item fails to decode. buf originates from another node, so its length is
// validated rather than trusted; previously an over-long buf caused an
// index-out-of-range panic.
//
// NOTE(review): when len(buf) == deadQueueSize the returned pointer equals
// deadQueueSize, which is not a valid index into the returned slice — confirm
// that callers wrap or otherwise handle that value.
func (ps *PlatformService) UnmarshalDQ(buf []json.RawMessage) ([]*model.WebSocketEvent, int, error) {
	if len(buf) > deadQueueSize {
		return nil, 0, fmt.Errorf("dead queue has %d items, which exceeds the maximum of %d", len(buf), deadQueueSize)
	}

	dq := make([]*model.WebSocketEvent, deadQueueSize)
	for i, dqItem := range buf {
		item, err := model.WebSocketEventFromJSON(bytes.NewReader(dqItem))
		if err != nil {
			return nil, 0, fmt.Errorf("failed to unmarshal dead queue item %d: %w", i, err)
		}
		dq[i] = item
	}
	return dq, len(buf), nil
}

View file

@ -0,0 +1,67 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package platform
import (
"testing"
"github.com/mattermost/mattermost/server/public/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMarshalAQ checks that messages drained from an active queue by
// marshalAQ are restored, unchanged and in order, by UnmarshalAQItem.
func TestMarshalAQ(t *testing.T) {
	var ps PlatformService

	want := []model.WebSocketMessage{
		model.NewWebSocketEvent(model.WebsocketEventPosted, "t1", "c1", "u1", nil, ""),
		model.NewWebSocketEvent(model.WebsocketEventReactionAdded, "t2", "c1", "u1", nil, ""),
		model.NewWebSocketEvent(model.WebsocketEventReactionRemoved, "t3", "c1", "u1", nil, ""),
		model.NewWebSocketResponse("hi", 10, nil),
	}

	// Fill a buffered channel and close it so marshalAQ can drain it fully.
	ch := make(chan model.WebSocketMessage, 10)
	for i := range want {
		ch <- want[i]
	}
	close(ch)

	items, err := ps.marshalAQ(ch, "connID", "u1")
	require.NoError(t, err)
	assert.Len(t, items, 4)

	var got []model.WebSocketMessage
	for i := range items {
		msg, unmarshalErr := ps.UnmarshalAQItem(items[i])
		require.NoError(t, unmarshalErr)
		got = append(got, msg)
	}

	assert.Equal(t, want, got)
}
// TestMarshalDQ checks that marshalDQ returns nothing for an empty dead queue
// and that a partially filled queue that has not rolled over survives a
// marshalDQ / UnmarshalDQ round trip.
func TestMarshalDQ(t *testing.T) {
	var ps PlatformService

	// Nothing in case of dead queue is empty
	encoded, err := ps.marshalDQ([]*model.WebSocketEvent{}, 0, 0)
	require.NoError(t, err)
	require.Nil(t, encoded)

	// Three written entries followed by empty slots; the write pointer is 3.
	queue := []*model.WebSocketEvent{
		model.NewWebSocketEvent(model.WebsocketEventPosted, "t1", "c1", "u1", nil, ""),
		model.NewWebSocketEvent(model.WebsocketEventReactionAdded, "t2", "c1", "u1", nil, "").SetSequence(1),
		model.NewWebSocketEvent(model.WebsocketEventReactionRemoved, "t3", "c1", "u1", nil, "").SetSequence(2),
		nil,
		nil,
	}

	encoded, err = ps.marshalDQ(queue, 0, 3)
	require.NoError(t, err)
	require.Len(t, encoded, 3)

	decoded, ptr, err := ps.UnmarshalDQ(encoded)
	require.NoError(t, err)
	assert.Equal(t, 3, ptr)
	assert.Equal(t, queue[:3], decoded[:3])
}

View file

@ -53,7 +53,3 @@ func (a *App) UpdateWebConnUserActivity(session model.Session, activityAt int64)
// SessionIsRegistered forwards to the platform service's SessionIsRegistered
// and returns its result for the given session.
func (a *App) SessionIsRegistered(session model.Session) bool {
return a.Srv().Platform().SessionIsRegistered(session)
}
// CheckWebConn forwards to the platform service's CheckWebConn for the given
// user and connection ID and returns its result unchanged.
func (a *App) CheckWebConn(userID, connectionID string) *platform.CheckConnResult {
return a.Srv().Platform().CheckWebConn(userID, connectionID)
}

View file

@ -112,3 +112,7 @@ func (c *FakeClusterInterface) ClearMessages() {
// WebConnCountForUser is a stub: it ignores userID and always reports zero
// connections with a nil error.
func (c *FakeClusterInterface) WebConnCountForUser(userID string) (int, *model.AppError) {
return 0, nil
}
// GetWSQueues is a stub: it ignores its arguments and always returns a nil
// map together with a nil error.
func (c *FakeClusterInterface) GetWSQueues(userID, connectionID string, seqNum int64) (map[string]*model.WSQueues, error) {
return nil, nil
}

View file

@ -34,4 +34,7 @@ type ClusterInterface interface {
// WebConnCountForUser returns the number of active webconn connections
// for a given userID.
WebConnCountForUser(userID string) (int, *model.AppError)
// GetWSQueues returns the necessary websocket queues from a cluster for a given
// connectionID and sequence number.
GetWSQueues(userID, connectionID string, seqNum int64) (map[string]*model.WSQueues, error)
}

View file

@ -222,6 +222,36 @@ func (_m *ClusterInterface) GetPluginStatuses() (model.PluginStatuses, *model.Ap
return r0, r1
}
// GetWSQueues provides a mock function with given fields: userID, connectionID, seqNum
func (_m *ClusterInterface) GetWSQueues(userID string, connectionID string, seqNum int64) (map[string]*model.WSQueues, error) {
// Record the call and fetch the return values configured for it.
ret := _m.Called(userID, connectionID, seqNum)
if len(ret) == 0 {
panic("no return value specified for GetWSQueues")
}
var r0 map[string]*model.WSQueues
var r1 error
// A single function producing both results takes precedence over per-value settings.
if rf, ok := ret.Get(0).(func(string, string, int64) (map[string]*model.WSQueues, error)); ok {
return rf(userID, connectionID, seqNum)
}
// First return value: either computed by a function or taken as a fixed value.
if rf, ok := ret.Get(0).(func(string, string, int64) map[string]*model.WSQueues); ok {
r0 = rf(userID, connectionID, seqNum)
} else {
// A nil value is left as the zero value to avoid a failing type assertion.
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]*model.WSQueues)
}
}
// Second return value: either computed by a function or read as an error.
if rf, ok := ret.Get(1).(func(string, string, int64) error); ok {
r1 = rf(userID, connectionID, seqNum)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// HealthScore provides a mock function with given fields:
func (_m *ClusterInterface) HealthScore() int {
ret := _m.Called()

View file

@ -58,6 +58,8 @@ const (
ClusterGossipEventResponseSaveConfig = "gossip_response_save_config"
ClusterGossipEventRequestWebConnCount = "gossip_request_webconn_count"
ClusterGossipEventResponseWebConnCount = "gossip_response_webconn_count"
ClusterGossipEventRequestWSQueues = "gossip_request_ws_queues"
ClusterGossipEventResponseWSQueues = "gossip_response_ws_queues"
// SendTypes for ClusterMessage.
ClusterSendBestEffort = "best_effort"

View file

@ -94,8 +94,22 @@ const (
WebsocketScheduledPostCreated WebsocketEventType = "scheduled_post_created"
WebsocketScheduledPostUpdated WebsocketEventType = "scheduled_post_updated"
WebsocketScheduledPostDeleted WebsocketEventType = "scheduled_post_deleted"
WebSocketMsgTypeResponse = "response"
WebSocketMsgTypeEvent = "event"
)
// ActiveQueueItem is one active queue message in serialized form.
// Type tells the receiver how to decode Buf (see WebSocketMsgTypeEvent and
// WebSocketMsgTypeResponse); Buf holds the message's raw JSON.
type ActiveQueueItem struct {
Type string `json:"type"` // websocket event or websocket response
Buf json.RawMessage `json:"buf"`
}
// WSQueues carries the serialized websocket queues of a connection.
// ActiveQ holds the active queue items, DeadQ the raw JSON of dead queue
// events, and ReuseCount the connection's reuse count.
type WSQueues struct {
ActiveQ []ActiveQueueItem `json:"active_queue"` // websocketEvent|websocketResponse
DeadQ []json.RawMessage `json:"dead_queue"` // websocketEvent
ReuseCount int `json:"reuse_count"`
}
type WebSocketMessage interface {
ToJSON() ([]byte, error)
IsValid() bool