mattermost/server/channels/app/platform/websocket_reliable.go

163 lines
4.6 KiB
Go
Raw Permalink Normal View History

MM-61904: Make reliable websockets work in HA (#29489) We make a cluster request to get the active and dead queues from the other nodes in the cluster to sync any missing information. We check the dead queue on the other nodes to see whether there has been any message loss. Accordingly, we send either just the active queue or both the active and dead queues. One edge case is still left out: a client could have connected and reconnected to multiple nodes, leaving multiple active queues on multiple nodes. We don't handle this scenario because it would potentially require a slice of sendQueueSize * number_of_nodes, and since this can happen again, it could lead to an unbounded increase in sendQueueSize. We leave this edge case to Redis, acknowledging a limitation in our architecture. In this PR, when there's no message loss, we just take the active queue from the last node the client connected to. If there's message loss and the client's seqNum is within the last node's dead queue, we handle that too. But if there's severe message loss where the client's seqNum falls within the dead queue of another node, we just send the data from that node to reconstruct as much of the data as possible. It would be possible to set a new connection ID in this case, but that would always involve more data transfer from all nodes and recomputing the state on the requesting node. https://mattermost.atlassian.net/browse/MM-61904 ```release-note NONE ``` Co-authored-by: Mattermost Build <build@mattermost.com>
2025-01-17 00:41:32 -05:00
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package platform
import (
"bytes"
"encoding/json"
"fmt"
"github.com/mattermost/mattermost/server/public/model"
)
// GetWSQueues returns the queued websocket messages held on this node for the
// given user's connection, so that a reconnecting client can resume from seqNum.
//
// The connection is looked up through the user's hub via CheckConn. The result is:
//   - nil, nil when there is no hub, no matching connection, or nothing was
//     ever written to the connection's dead queue on this server;
//   - only the active queue when seqNum-1 matches the last value in the dead
//     queue (no message loss);
//   - the active queue plus the dead-queue entries starting at seqNum's
//     position when seqNum is found elsewhere in the dead queue;
//   - nil, nil when seqNum matches neither case.
//
// Whenever a non-nil result is returned, the connection's active queue channel
// has been closed and fully drained.
func (ps *PlatformService) GetWSQueues(userID, connectionID string, seqNum int64) (*model.WSQueues, error) {
	hub := ps.GetHubForUserId(userID)
	if hub == nil {
		return nil, nil
	}

	connRes := hub.CheckConn(userID, connectionID)
	if connRes == nil {
		return nil, nil
	}

	aq := connRes.ActiveQueue
	dq := connRes.DeadQueue
	dqPtr := connRes.DeadQueuePointer

	// Nothing was written on this server. Early return.
	// The length check (same guard as marshalDQ) avoids an index panic on an
	// empty dead queue.
	if len(dq) == 0 || dq[0] == nil {
		return nil, nil
	}

	// Perfect match: seq_num-1 == last value in the dead queue, so the active
	// queue alone is enough for the client to resume.
	perfectMatch := !_hasMsgLoss(dq, dqPtr, seqNum)

	// Otherwise seq_num has to be somewhere else in the dead queue for this node
	// to be able to help.
	var index int
	if !perfectMatch {
		var found bool
		found, index = _isInDeadQueue(dq, seqNum)
		if !found {
			// Nothing matched.
			return nil, nil
		}
	}

	// Closing aq lets marshalAQ's range loop end once the buffered messages
	// have been drained.
	close(aq)
	aqSlice, err := ps.marshalAQ(aq, connectionID, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get from active queue: %w", err)
	}

	queues := &model.WSQueues{
		ActiveQ:    aqSlice,
		ReuseCount: connRes.ReuseCount,
	}
	if perfectMatch {
		// send only aq
		return queues, nil
	}

	dqSlice, err := ps.marshalDQ(dq, index, dqPtr)
	if err != nil {
		return nil, fmt.Errorf("failed to get from dead queue: %w", err)
	}
	// send aq + drainedDq.
	queues.DeadQ = dqSlice
	return queues, nil
}
// marshalAQ drains the active queue aq and converts every message into a
// model.ActiveQueueItem that holds the message's JSON encoding and its type.
//
// The range loop only terminates once aq is closed, so callers must close the
// channel before calling this. Messages of type *model.WebSocketEvent are tagged
// model.WebSocketMsgTypeEvent; every other message is tagged
// model.WebSocketMsgTypeResponse. connID and userID are used only to give
// context to errors. On success the returned slice is non-nil (possibly empty).
func (ps *PlatformService) marshalAQ(aq <-chan model.WebSocketMessage, connID, userID string) ([]model.ActiveQueueItem, error) {
	// For a closed channel, len reports the number of still-buffered messages,
	// which is exactly how many items will be collected.
	aqSlice := make([]model.ActiveQueueItem, 0, len(aq))
	for msg := range aq {
		evtType := model.WebSocketMsgTypeResponse
		if _, isEvent := msg.(*model.WebSocketEvent); isEvent {
			evtType = model.WebSocketMsgTypeEvent
		}

		buf, err := msg.ToJSON()
		if err != nil {
			return nil, fmt.Errorf("failed to marshal websocket event: %w, connection_id=%s, user_id=%s", err, connID, userID)
		}

		aqSlice = append(aqSlice, model.ActiveQueueItem{
			Buf:  json.RawMessage(buf),
			Type: evtType,
		})
	}
	return aqSlice, nil
}
// UnmarshalAQItem decodes an ActiveQueueItem (as produced by marshalAQ) back
// into a model.WebSocketMessage, picking the concrete type from aqItem.Type.
//
// It returns an error when aqItem.Type is neither model.WebSocketMsgTypeEvent
// nor model.WebSocketMsgTypeResponse, or when decoding fails. On error the
// returned message is always a nil interface rather than an interface wrapping
// a typed nil pointer, so callers checking the message against nil are not
// misled.
func (ps *PlatformService) UnmarshalAQItem(aqItem model.ActiveQueueItem) (model.WebSocketMessage, error) {
	switch aqItem.Type {
	case model.WebSocketMsgTypeEvent:
		evt, err := model.WebSocketEventFromJSON(bytes.NewReader(aqItem.Buf))
		if err != nil {
			return nil, err
		}
		return evt, nil
	case model.WebSocketMsgTypeResponse:
		resp, err := model.WebSocketResponseFromJSON(bytes.NewReader(aqItem.Buf))
		if err != nil {
			return nil, err
		}
		return resp, nil
	default:
		return nil, fmt.Errorf("unknown websocket message type: %q", aqItem.Type)
	}
}
// marshalDQ is the same as drainDeadQueue, except it writes to a byte slice
// instead of the network. To be refactored into a single method.
//
// dq is treated as a ring buffer whose next write position is dqPtr. Starting
// at index, the events are JSON-encoded in order:
//   - if dq[dqPtr] is nil the buffer has not rolled over yet, so the entries
//     index..dqPtr-1 are encoded;
//   - otherwise the walk wraps around the buffer and stops after the entry
//     whose successor has a smaller sequence number (the newest event). The
//     walk is bounded to one lap of the buffer so it cannot spin forever if
//     sequence numbers ever fail to decrease.
//
// An empty dead queue, or one whose first slot is nil, yields (nil, nil).
func (ps *PlatformService) marshalDQ(dq []*model.WebSocketEvent, index, dqPtr int) ([]json.RawMessage, error) {
	if len(dq) == 0 || dq[0] == nil {
		return nil, nil
	}

	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	dqSlice := make([]json.RawMessage, 0, len(dq))

	// appendEvent encodes ev into the shared buffer and appends a private copy of
	// the bytes, since buf is reset and reused for the next event.
	appendEvent := func(ev *model.WebSocketEvent) error {
		buf.Reset()
		if err := ev.Encode(enc, &buf); err != nil {
			return fmt.Errorf("error in encoding websocket message in dead queue: %w", err)
		}
		dqSlice = append(dqSlice, bytes.Clone(buf.Bytes()))
		return nil
	}

	// This means pointer hasn't rolled over.
	if dq[dqPtr] == nil {
		// Clear till the end of queue.
		for i := index; i < dqPtr; i++ {
			if err := appendEvent(dq[i]); err != nil {
				return nil, err
			}
		}
		return dqSlice, nil
	}

	// We go on until next sequence number is smaller than previous one.
	// Which means it has rolled over. With strictly increasing sequence numbers
	// the break always fires within one lap, so the bound never cuts a normal
	// walk short.
	currPtr := index
	for visited := 0; visited < len(dq); visited++ {
		if err := appendEvent(dq[currPtr]); err != nil {
			return nil, err
		}
		oldSeq := dq[currPtr].GetSequence()
		currPtr = (currPtr + 1) % len(dq)
		if dq[currPtr].GetSequence() < oldSeq {
			break
		}
	}
	return dqSlice, nil
}
func (ps *PlatformService) UnmarshalDQ(buf []json.RawMessage) ([]*model.WebSocketEvent, int, error) {
dqPtr := 0
dq := make([]*model.WebSocketEvent, deadQueueSize)
for _, dqItem := range buf {
item, err := model.WebSocketEventFromJSON(bytes.NewReader(dqItem))
if err != nil {
return nil, 0, err
}
// Same as active queue, this can never be out of bounds because all dead queues
// are of deadQueueSize.
dq[dqPtr] = item
dqPtr = (dqPtr + 1) % deadQueueSize
MM-61904: Make reliable websockets work in HA (#29489) We do a cluster request to get the active and dead queues from other nodes in the cluster to sync any missing information. We check the dead queue in the other nodes to see if there's been any message loss or not. Accordingly, we send just the active queue or both active and dead queues. There's still an edge case that is left out where a client could have potentially connected and reconnected to multiple nodes leaving multiple active queues in multiple nodes. We don't handle this scenario because then potentially we need to create a slice of sendQueueSize * number_of_nodes. And then this can happen again, leading to an infinite increase in sendQueueSize. We leave this edge-case to Redis, acknowledging a limitation in our architecture. In this PR, when there's no message loss, we just take the active queue from the last node it connected to. And if there's message loss where the client's seqNum is within the last node's dead queue, we also handle that. But if there's severe message loss where the client's seqNum falls within the dead queue of another node, then we just send the data from that node to reconstruct the data as much as possible. It could be possible to set a new connection ID in this case, but this involves more data transfer always from all nodes and recomputing the state in the requestor node. https://mattermost.atlassian.net/browse/MM-61904 ```release-note NONE ``` Co-authored-by: Mattermost Build <build@mattermost.com>
2025-01-17 00:41:32 -05:00
}
return dq, dqPtr, nil
}