diff --git a/.gitignore b/.gitignore index 9f8cad86692..81c735fd371 100644 --- a/.gitignore +++ b/.gitignore @@ -18,11 +18,16 @@ web/static/js/libs*.js config/active.dat config/config.json +config/logging.json /plugins # Enterprise imports file imports/imports.go +#license files +*.license +*.mattermost-license + # Build Targets .prebuild .npminstall diff --git a/Makefile b/Makefile index af2f50af9fd..094eec20907 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build package run stop run-client run-server run-haserver stop-client stop-server restart restart-server restart-client restart-haserver start-docker clean-dist clean nuke check-style check-client-style check-server-style check-unit-tests test dist prepare-enteprise run-client-tests setup-run-client-tests cleanup-run-client-tests test-client build-linux build-osx build-windows internal-test-web-client vet run-server-for-web-client-tests diff-config prepackaged-plugins prepackaged-binaries test-server test-server-ee test-server-quick test-server-race start-docker-check migrations-bindata new-migration migration-prereqs +.PHONY: build package run stop run-client run-server run-haserver stop-haserver stop-client stop-server restart restart-server restart-client restart-haserver start-docker clean-dist clean nuke check-style check-client-style check-server-style check-unit-tests test dist prepare-enteprise run-client-tests setup-run-client-tests cleanup-run-client-tests test-client build-linux build-osx build-windows internal-test-web-client vet run-server-for-web-client-tests diff-config prepackaged-plugins prepackaged-binaries test-server test-server-ee test-server-quick test-server-race start-docker-check migrations-bindata new-migration migration-prereqs ROOT := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) @@ -58,6 +58,15 @@ else BUILD_CLIENT = false endif +# We need current user's UID for `run-haserver` so docker compose does not run server +# as root and mess up file permissions for devs. When running like this HOME will be blank +# and docker will add '/', so we need to set the go-build cache location or we'll get +# permission errors on build as it tries to create a cache in filesystem root. +export CURRENT_UID = $(shell id -u):$(shell id -g) +ifeq ($(HOME),/) + export XDG_CACHE_HOME = /tmp/go-cache/ +endif + # Go Flags GOFLAGS ?= $(GOFLAGS:) # We need to export GOBIN to allow it to be set @@ -170,13 +179,17 @@ ifneq (,$(findstring mysql-read-replica,$(ENABLED_DOCKER_SERVICES))) endif endif -run-haserver: run-client +run-haserver: ifeq ($(BUILD_ENTERPRISE_READY),true) - @echo Starting mattermost in an HA topology + @echo Starting mattermost in an HA topology '(3 node cluster)' - docker-compose -f docker-compose.yaml up haproxy + docker-compose -f docker-compose.yaml up --remove-orphans haproxy endif +stop-haserver: + @echo Stopping docker containers for HA topology + docker-compose stop + stop-docker: ## Stops the docker containers for local development. ifeq ($(MM_NO_DOCKER),true) @echo No Docker Enabled: skipping docker stop @@ -294,6 +307,11 @@ searchengine-mocks: ## Creates mock files for searchengines. $(GO) get -modfile=go.tools.mod github.com/vektra/mockery/... $(GOBIN)/mockery -dir services/searchengine -all -output services/searchengine/mocks -note 'Regenerate this file using `make searchengine-mocks`.' +sharedchannel-mocks: ## Creates mock files for shared channels. + $(GO) get -modfile=go.tools.mod github.com/vektra/mockery/... + $(GOBIN)/mockery -dir=./services/sharedchannel -name=ServerIface -output=./services/sharedchannel -inpkg -outpkg=sharedchannel -testonly -note 'Regenerate this file using `make sharedchannel-mocks`.' + $(GOBIN)/mockery -dir=./services/sharedchannel -name=AppIface -output=./services/sharedchannel -inpkg -outpkg=sharedchannel -testonly -note 'Regenerate this file using `make sharedchannel-mocks`.' + pluginapi: ## Generates api and hooks glue code for plugins $(GO) generate $(GOFLAGS) ./plugin @@ -497,6 +515,7 @@ restart-server: | stop-server run-server ## Restarts the mattermost server to pi restart-haserver: @echo Restarting mattermost in an HA topology + docker-compose restart follower2 docker-compose restart follower docker-compose restart leader docker-compose restart haproxy diff --git a/api4/api.go b/api4/api.go index 0a01d7e03e4..46b57ae7031 100644 --- a/api4/api.go +++ b/api4/api.go @@ -125,8 +125,12 @@ type Routes struct { Cloud *mux.Router // 'api/v4/cloud' Imports *mux.Router // 'api/v4/imports' + Exports *mux.Router // 'api/v4/exports' Export *mux.Router // 'api/v4/exports/{export_name:.+\\.zip}' + + RemoteCluster *mux.Router // 'api/v4/remotecluster' + SharedChannels *mux.Router // 'api/v4/sharedchannels' } type API struct { @@ -243,6 +247,9 @@ func Init(configservice configservice.ConfigService, globalOptionsFunc app.AppOp api.BaseRoutes.Exports = api.BaseRoutes.ApiRoot.PathPrefix("/exports").Subrouter() api.BaseRoutes.Export = api.BaseRoutes.Exports.PathPrefix("/{export_name:.+\\.zip}").Subrouter() + api.BaseRoutes.RemoteCluster = api.BaseRoutes.ApiRoot.PathPrefix("/remotecluster").Subrouter() + api.BaseRoutes.SharedChannels = api.BaseRoutes.ApiRoot.PathPrefix("/sharedchannels").Subrouter() + api.InitUser() api.InitBot() api.InitTeam() @@ -280,6 +287,8 @@ func Init(configservice configservice.ConfigService, globalOptionsFunc app.AppOp api.InitAction() api.InitCloud() api.InitImport() + api.InitRemoteCluster() + api.InitSharedChannels() api.InitExport() root.Handle("/api/v4/{anything:.*}", http.HandlerFunc(api.Handle404)) diff --git a/api4/handlers.go b/api4/handlers.go index aa6fa996663..255837afad4 100644 --- a/api4/handlers.go +++ b/api4/handlers.go @@ -72,6 +72,26 @@ func (api *API) CloudApiKeyRequired(h func(*Context, http.ResponseWriter, *http. } +// RemoteClusterTokenRequired provides a handler for remote cluster requests to /remotecluster endpoints. +func (api *API) RemoteClusterTokenRequired(h func(*Context, http.ResponseWriter, *http.Request)) http.Handler { + handler := &web.Handler{ + GetGlobalAppOptions: api.GetGlobalAppOptions, + HandleFunc: h, + HandlerName: web.GetHandlerName(h), + RequireSession: false, + RequireCloudKey: false, + RequireRemoteClusterToken: true, + TrustRequester: false, + RequireMfa: false, + IsStatic: false, + IsLocal: false, + } + if *api.ConfigService.Config().ServiceSettings.WebserverMode == "gzip" { + return gziphandler.GzipHandler(handler) + } + return handler +} + // ApiSessionRequiredMfa provides a handler for API endpoints which require a logged-in user session but when accessed, // if MFA is enabled, the MFA process is not yet complete, and therefore the requirement to have completed the MFA // authentication must be waived. diff --git a/api4/post_test.go b/api4/post_test.go index 98c272e4478..e045b3207d6 100644 --- a/api4/post_test.go +++ b/api4/post_test.go @@ -618,7 +618,7 @@ func TestCreatePostCheckOnlineStatus(t *testing.T) { } case <-timeout: // We just skip the test instead of failing because waiting for more than 5 seconds - // to get a response does not make sense, and it will unncessarily slow down + // to get a response does not make sense, and it will unnecessarily slow down // the tests further in an already congested CI environment. t.Skip("timed out waiting for event") } @@ -2035,7 +2035,7 @@ func TestDeletePostMessage(t *testing.T) { } case <-timeout: // We just skip the test instead of failing because waiting for more than 5 seconds - // to get a response does not make sense, and it will unncessarily slow down + // to get a response does not make sense, and it will unnecessarily slow down // the tests further in an already congested CI environment. t.Skip("timed out waiting for event") } diff --git a/api4/remote_cluster.go b/api4/remote_cluster.go new file mode 100644 index 00000000000..6add7d8431f --- /dev/null +++ b/api4/remote_cluster.go @@ -0,0 +1,214 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package api4 + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/mattermost/mattermost-server/v5/audit" + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" +) + +func (api *API) InitRemoteCluster() { + api.BaseRoutes.RemoteCluster.Handle("/ping", api.RemoteClusterTokenRequired(remoteClusterPing)).Methods("POST") + api.BaseRoutes.RemoteCluster.Handle("/msg", api.RemoteClusterTokenRequired(remoteClusterAcceptMessage)).Methods("POST") + api.BaseRoutes.RemoteCluster.Handle("/confirm_invite", api.RemoteClusterTokenRequired(remoteClusterConfirmInvite)).Methods("POST") + api.BaseRoutes.RemoteCluster.Handle("/upload/{upload_id:[A-Za-z0-9]+}", api.RemoteClusterTokenRequired(uploadRemoteData)).Methods("POST") +} + +func remoteClusterPing(c *Context, w http.ResponseWriter, r *http.Request) { + // make sure remote cluster service is enabled. + if _, appErr := c.App.GetRemoteClusterService(); appErr != nil { + c.Err = appErr + return + } + + frame, appErr := model.RemoteClusterFrameFromJSON(r.Body) + if appErr != nil { + c.Err = appErr + return + } + + if appErr = frame.IsValid(); appErr != nil { + c.Err = appErr + return + } + + remoteId := c.GetRemoteID(r) + if remoteId != frame.RemoteId { + c.SetInvalidRemoteIdError(frame.RemoteId) + return + } + + rc, err := c.App.GetRemoteCluster(frame.RemoteId) + if err != nil { + c.SetInvalidRemoteIdError(frame.RemoteId) + return + } + + ping, err := model.RemoteClusterPingFromRawJSON(frame.Msg.Payload) + if err != nil { + c.SetInvalidParam("msg.payload") + return + } + ping.RecvAt = model.GetMillis() + + if metrics := c.App.Metrics(); metrics != nil { + metrics.IncrementRemoteClusterMsgReceivedCounter(rc.RemoteId) + } + + resp, _ := json.Marshal(ping) + w.Write(resp) +} + +func remoteClusterAcceptMessage(c *Context, w http.ResponseWriter, r *http.Request) { + // make sure remote cluster service is running. + service, appErr := c.App.GetRemoteClusterService() + if appErr != nil { + c.Err = appErr + return + } + + frame, appErr := model.RemoteClusterFrameFromJSON(r.Body) + if appErr != nil { + c.Err = appErr + return + } + + if appErr = frame.IsValid(); appErr != nil { + c.Err = appErr + return + } + + auditRec := c.MakeAuditRecord("remoteClusterAcceptMessage", audit.Fail) + defer c.LogAuditRec(auditRec) + + remoteId := c.GetRemoteID(r) + if remoteId != frame.RemoteId { + c.SetInvalidRemoteIdError(frame.RemoteId) + return + } + + rc, err := c.App.GetRemoteCluster(frame.RemoteId) + if err != nil { + c.SetInvalidRemoteIdError(frame.RemoteId) + return + } + auditRec.AddMeta("remoteCluster", rc) + + // pass message to Remote Cluster Service and write response + resp := service.ReceiveIncomingMsg(rc, frame.Msg) + + b, errMarshall := json.Marshal(resp) + if errMarshall != nil { + c.Err = model.NewAppError("remoteClusterAcceptMessage", "api.marshal_error", nil, errMarshall.Error(), http.StatusInternalServerError) + return + } + w.Write(b) +} + +func remoteClusterConfirmInvite(c *Context, w http.ResponseWriter, r *http.Request) { + // make sure remote cluster service is running. + if _, appErr := c.App.GetRemoteClusterService(); appErr != nil { + c.Err = appErr + return + } + + frame, appErr := model.RemoteClusterFrameFromJSON(r.Body) + if appErr != nil { + c.Err = appErr + return + } + + if appErr = frame.IsValid(); appErr != nil { + c.Err = appErr + return + } + + auditRec := c.MakeAuditRecord("remoteClusterAcceptInvite", audit.Fail) + defer c.LogAuditRec(auditRec) + + remoteId := c.GetRemoteID(r) + if remoteId != frame.RemoteId { + c.SetInvalidRemoteIdError(frame.RemoteId) + return + } + + rc, err := c.App.GetRemoteCluster(frame.RemoteId) + if err != nil { + c.SetInvalidRemoteIdError(frame.RemoteId) + return + } + auditRec.AddMeta("remoteCluster", rc) + + if time.Since(model.GetTimeForMillis(rc.CreateAt)) > remotecluster.InviteExpiresAfter { + c.Err = model.NewAppError("remoteClusterAcceptMessage", "api.context.invitation_expired.error", nil, "", http.StatusBadRequest) + return + } + + confirm, appErr := model.RemoteClusterInviteFromRawJSON(frame.Msg.Payload) + if appErr != nil { + c.Err = appErr + return + } + + rc.RemoteTeamId = confirm.RemoteTeamId + rc.SiteURL = confirm.SiteURL + rc.RemoteToken = confirm.Token + + if _, err := c.App.UpdateRemoteCluster(rc); err != nil { + c.Err = err + return + } + + auditRec.Success() + ReturnStatusOK(w) +} + +func uploadRemoteData(c *Context, w http.ResponseWriter, r *http.Request) { + if !*c.App.Config().FileSettings.EnableFileAttachments { + c.Err = model.NewAppError("uploadRemoteData", "api.file.attachments.disabled.app_error", + nil, "", http.StatusNotImplemented) + return + } + + c.RequireUploadId() + if c.Err != nil { + return + } + + auditRec := c.MakeAuditRecord("uploadRemoteData", audit.Fail) + defer c.LogAuditRec(auditRec) + auditRec.AddMeta("upload_id", c.Params.UploadId) + + us, err := c.App.GetUploadSession(c.Params.UploadId) + if err != nil { + c.Err = err + return + } + + if us.RemoteId != c.GetRemoteID(r) { + c.Err = model.NewAppError("uploadRemoteData", "api.context.remote_id_mismatch.app_error", + nil, "", http.StatusUnauthorized) + return + } + + info, err := doUploadData(c, us, r) + if err != nil { + c.Err = err + return + } + + auditRec.Success() + + if info == nil { + w.WriteHeader(http.StatusNoContent) + return + } + + w.Write([]byte(info.ToJson())) +} diff --git a/api4/shared_channel.go b/api4/shared_channel.go new file mode 100644 index 00000000000..3baf6931e0d --- /dev/null +++ b/api4/shared_channel.go @@ -0,0 +1,76 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package api4 + +import ( + "encoding/json" + "net/http" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func (api *API) InitSharedChannels() { + api.BaseRoutes.SharedChannels.Handle("/{team_id:[A-Za-z0-9]+}", api.ApiSessionRequired(getSharedChannels)).Methods("GET") + api.BaseRoutes.SharedChannels.Handle("/remote_info/{remote_id:[A-Za-z0-9]+}", api.ApiSessionRequired(getRemoteClusterInfo)).Methods("GET") +} + +func getSharedChannels(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireTeamId() + if c.Err != nil { + return + } + + // make sure remote cluster service is enabled. + if _, appErr := c.App.GetRemoteClusterService(); appErr != nil { + c.Err = appErr + return + } + + opts := model.SharedChannelFilterOpts{ + TeamId: c.Params.TeamId, + } + + channels, appErr := c.App.GetSharedChannels(c.Params.Page, c.Params.PerPage, opts) + if appErr != nil { + c.Err = appErr + return + } + + b, err := json.Marshal(channels) + if err != nil { + c.SetJSONEncodingError() + return + } + w.Write(b) +} + +func getRemoteClusterInfo(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireRemoteId() + if c.Err != nil { + return + } + + // make sure remote cluster service is enabled. + if _, appErr := c.App.GetRemoteClusterService(); appErr != nil { + c.Err = appErr + return + } + + // GetRemoteClusterForUser will only return a remote if the user is a member of at + // least one channel shared by the remote. All other cases return error. + rc, appErr := c.App.GetRemoteClusterForUser(c.Params.RemoteId, c.App.Session().UserId) + if appErr != nil { + c.Err = appErr + return + } + + remoteInfo := rc.ToRemoteClusterInfo() + + b, err := json.Marshal(remoteInfo) + if err != nil { + c.SetJSONEncodingError() + return + } + w.Write(b) +} diff --git a/api4/shared_channel_test.go b/api4/shared_channel_test.go new file mode 100644 index 00000000000..f6d60b54f44 --- /dev/null +++ b/api4/shared_channel_test.go @@ -0,0 +1,229 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package api4 + +import ( + "fmt" + "math/rand" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/app" + "github.com/mattermost/mattermost-server/v5/model" +) + +var ( + rnd = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func TestGetAllSharedChannels(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + const pages = 3 + const pageSize = 7 + + mockService := app.NewMockRemoteClusterService(nil, app.MockOptionRemoteClusterServiceWithActive(true)) + th.App.Srv().SetRemoteClusterService(mockService) + + savedIds := make([]string, 0, pages*pageSize) + + // make some shared channels + for i := 0; i < pages*pageSize; i++ { + channel := th.CreateChannelWithClientAndTeam(th.Client, model.CHANNEL_OPEN, th.BasicTeam.Id) + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + Home: randomBool(), + ShareName: fmt.Sprintf("test_share_%d", i), + CreatorId: th.BasicChannel.CreatorId, + RemoteId: model.NewId(), + } + _, err := th.App.SaveSharedChannel(sc) + require.NoError(t, err) + savedIds = append(savedIds, channel.Id) + } + sort.Strings(savedIds) + + t.Run("get shared channels paginated", func(t *testing.T) { + channelIds := make([]string, 0, 21) + for i := 0; i < pages; i++ { + channels, resp := th.Client.GetAllSharedChannels(th.BasicTeam.Id, i, pageSize) + CheckNoError(t, resp) + channelIds = append(channelIds, getIds(channels)...) + } + sort.Strings(channelIds) + + // ids lists should now match + assert.Equal(t, savedIds, channelIds, "id lists should match") + }) + + t.Run("get shared channels for invalid team", func(t *testing.T) { + channels, resp := th.Client.GetAllSharedChannels(model.NewId(), 0, 100) + CheckNoError(t, resp) + assert.Empty(t, channels) + }) +} + +func getIds(channels []*model.SharedChannel) []string { + ids := make([]string, 0, len(channels)) + for _, c := range channels { + ids = append(ids, c.ChannelId) + } + return ids +} + +func randomBool() bool { + return rnd.Intn(2) != 0 +} + +func TestGetRemoteClusterById(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + mockService := app.NewMockRemoteClusterService(nil, app.MockOptionRemoteClusterServiceWithActive(true)) + th.App.Srv().SetRemoteClusterService(mockService) + + // for this test we need a user that belongs to a channel that + // is shared with the requested remote id. + + // create a remote cluster + rc := &model.RemoteCluster{ + RemoteId: model.NewId(), + DisplayName: "Test1", + RemoteTeamId: model.NewId(), + SiteURL: model.NewId(), + CreatorId: model.NewId(), + } + rc, appErr := th.App.AddRemoteCluster(rc) + require.Nil(t, appErr) + + // create a shared channel + sc := &model.SharedChannel{ + ChannelId: th.BasicChannel.Id, + TeamId: th.BasicChannel.TeamId, + Home: false, + ShareName: "test_share", + CreatorId: th.BasicChannel.CreatorId, + RemoteId: rc.RemoteId, + } + sc, err := th.App.SaveSharedChannel(sc) + require.NoError(t, err) + + // create a shared channel remote to connect them + scr := &model.SharedChannelRemote{ + Id: model.NewId(), + ChannelId: sc.ChannelId, + CreatorId: sc.CreatorId, + IsInviteAccepted: true, + IsInviteConfirmed: true, + RemoteId: sc.RemoteId, + } + _, err = th.App.SaveSharedChannelRemote(scr) + require.NoError(t, err) + + t.Run("valid remote, user is member", func(t *testing.T) { + rcInfo, resp := th.Client.GetRemoteClusterInfo(rc.RemoteId) + CheckNoError(t, resp) + assert.Equal(t, rc.DisplayName, rcInfo.DisplayName) + }) + + t.Run("invalid remote", func(t *testing.T) { + _, resp := th.Client.GetRemoteClusterInfo(model.NewId()) + CheckNotFoundStatus(t, resp) + }) + +} + +func TestCreateDirectChannelWithRemoteUser(t *testing.T) { + t.Run("creates a local DM channel that is shared", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + Client := th.Client + defer Client.Logout() + + localUser := th.BasicUser + remoteUser := th.CreateUser() + remoteUser.RemoteId = model.NewString(model.NewId()) + remoteUser, err := th.App.UpdateUser(remoteUser, false) + require.Nil(t, err) + + dm, resp := Client.CreateDirectChannel(localUser.Id, remoteUser.Id) + CheckNoError(t, resp) + + channelName := model.GetDMNameFromIds(localUser.Id, remoteUser.Id) + require.Equal(t, channelName, dm.Name, "dm name didn't match") + assert.True(t, dm.IsShared()) + }) + + t.Run("sends a shared channel invitation to the remote", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + Client := th.Client + defer Client.Logout() + + mockService := app.NewMockSharedChannelService(nil, app.MockOptionSharedChannelServiceWithActive(true)) + th.App.Srv().SetSharedChannelSyncService(mockService) + + localUser := th.BasicUser + remoteUser := th.CreateUser() + rc := &model.RemoteCluster{ + DisplayName: "test", + Token: model.NewId(), + CreatorId: localUser.Id, + } + rc, err := th.App.AddRemoteCluster(rc) + require.Nil(t, err) + + remoteUser.RemoteId = model.NewString(rc.RemoteId) + remoteUser, err = th.App.UpdateUser(remoteUser, false) + require.Nil(t, err) + + dm, resp := Client.CreateDirectChannel(localUser.Id, remoteUser.Id) + CheckNoError(t, resp) + + channelName := model.GetDMNameFromIds(localUser.Id, remoteUser.Id) + require.Equal(t, channelName, dm.Name, "dm name didn't match") + require.True(t, dm.IsShared()) + + assert.Equal(t, 1, mockService.NumInvitations()) + }) + + t.Run("does not send a shared channel invitation to the remote when creator is remote", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + Client := th.Client + defer Client.Logout() + + mockService := app.NewMockSharedChannelService(nil, app.MockOptionSharedChannelServiceWithActive(true)) + th.App.Srv().SetSharedChannelSyncService(mockService) + + localUser := th.BasicUser + remoteUser := th.CreateUser() + rc := &model.RemoteCluster{ + DisplayName: "test", + Token: model.NewId(), + CreatorId: localUser.Id, + } + rc, err := th.App.AddRemoteCluster(rc) + require.Nil(t, err) + + remoteUser.RemoteId = model.NewString(rc.RemoteId) + remoteUser, err = th.App.UpdateUser(remoteUser, false) + require.Nil(t, err) + + dm, resp := Client.CreateDirectChannel(remoteUser.Id, localUser.Id) + CheckNoError(t, resp) + + channelName := model.GetDMNameFromIds(localUser.Id, remoteUser.Id) + require.Equal(t, channelName, dm.Name, "dm name didn't match") + require.True(t, dm.IsShared()) + + assert.Zero(t, mockService.NumInvitations()) + }) +} diff --git a/api4/upload.go b/api4/upload.go index c1a2ac0b909..285bdccf78a 100644 --- a/api4/upload.go +++ b/api4/upload.go @@ -33,6 +33,10 @@ func createUpload(c *Context, w http.ResponseWriter, r *http.Request) { return } + // these are not supported for client uploads; shared channels only. + us.RemoteId = "" + us.ReqFileId = "" + auditRec := c.MakeAuditRecord("createUpload", audit.Fail) defer c.LogAuditRec(auditRec) auditRec.AddMeta("upload", us) @@ -119,33 +123,7 @@ func uploadData(c *Context, w http.ResponseWriter, r *http.Request) { } } - boundary, parseErr := parseMultipartRequestHeader(r) - if parseErr != nil && !errors.Is(parseErr, http.ErrNotMultipart) { - c.Err = model.NewAppError("uploadData", "api.upload.upload_data.invalid_content_type", - nil, parseErr.Error(), http.StatusBadRequest) - return - } - - var rd io.Reader - if boundary != "" { - mr := multipart.NewReader(r.Body, boundary) - p, partErr := mr.NextPart() - if partErr != nil { - c.Err = model.NewAppError("uploadData", "api.upload.upload_data.multipart_error", - nil, partErr.Error(), http.StatusBadRequest) - return - } - rd = p - } else { - if r.ContentLength > (us.FileSize - us.FileOffset) { - c.Err = model.NewAppError("uploadData", "api.upload.upload_data.invalid_content_length", - nil, "", http.StatusBadRequest) - return - } - rd = r.Body - } - - info, err := c.App.UploadData(us, rd) + info, err := doUploadData(c, us, r) if err != nil { c.Err = err return @@ -160,3 +138,30 @@ func uploadData(c *Context, w http.ResponseWriter, r *http.Request) { w.Write([]byte(info.ToJson())) } + +func doUploadData(c *Context, us *model.UploadSession, r *http.Request) (*model.FileInfo, *model.AppError) { + boundary, parseErr := parseMultipartRequestHeader(r) + if parseErr != nil && !errors.Is(parseErr, http.ErrNotMultipart) { + return nil, model.NewAppError("uploadData", "api.upload.upload_data.invalid_content_type", + nil, parseErr.Error(), http.StatusBadRequest) + } + + var rd io.Reader + if boundary != "" { + mr := multipart.NewReader(r.Body, boundary) + p, partErr := mr.NextPart() + if partErr != nil { + return nil, model.NewAppError("uploadData", "api.upload.upload_data.multipart_error", + nil, partErr.Error(), http.StatusBadRequest) + } + rd = p + } else { + if r.ContentLength > (us.FileSize - us.FileOffset) { + return nil, model.NewAppError("uploadData", "api.upload.upload_data.invalid_content_length", + nil, "", http.StatusBadRequest) + } + rd = r.Body + } + + return c.App.UploadData(us, rd) +} diff --git a/app/app_iface.go b/app/app_iface.go index 85ccb2e7465..7e6de6c3170 100644 --- a/app/app_iface.go +++ b/app/app_iface.go @@ -24,6 +24,7 @@ import ( "github.com/mattermost/mattermost-server/v5/plugin" "github.com/mattermost/mattermost-server/v5/services/httpservice" "github.com/mattermost/mattermost-server/v5/services/imageproxy" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" "github.com/mattermost/mattermost-server/v5/services/searchengine" "github.com/mattermost/mattermost-server/v5/services/timezones" "github.com/mattermost/mattermost-server/v5/shared/filestore" @@ -199,6 +200,8 @@ type AppIface interface { GetTeamSchemeChannelRoles(teamID string) (guestRoleName string, userRoleName string, adminRoleName string, err *model.AppError) // GetTotalUsersStats is used for the DM list total GetTotalUsersStats(viewRestrictions *model.ViewUsersRestrictions) (*model.UsersStats, *model.AppError) + // HasRemote returns whether a given channelID is present in the channel remotes or not. + HasRemote(channelID string, remoteID string) (bool, error) // HubRegister registers a connection to a hub. HubRegister(webConn *WebConn) // HubStart starts all the hubs. @@ -361,6 +364,7 @@ type AppIface interface { AddDirectChannels(teamID string, user *model.User) *model.AppError AddLdapPrivateCertificate(fileData *multipart.FileHeader) *model.AppError AddLdapPublicCertificate(fileData *multipart.FileHeader) *model.AppError + AddRemoteCluster(rc *model.RemoteCluster) (*model.RemoteCluster, *model.AppError) AddSamlIdpCertificate(fileData *multipart.FileHeader) *model.AppError AddSamlPrivateCertificate(fileData *multipart.FileHeader) *model.AppError AddSamlPublicCertificate(fileData *multipart.FileHeader) *model.AppError @@ -399,6 +403,7 @@ type AppIface interface { ChannelMembersToAdd(since int64, channelID *string) ([]*model.UserChannelIDPair, *model.AppError) ChannelMembersToRemove(teamID *string) ([]*model.ChannelMember, *model.AppError) CheckAndSendUserLimitWarningEmails() *model.AppError + CheckCanInviteToSharedChannel(channelId string) error CheckForClientSideCert(r *http.Request) (string, string, string) CheckMandatoryS3Fields(settings *model.FileSettings) *model.AppError CheckPasswordAndAllCriteria(user *model.User, password string, mfaToken string) *model.AppError @@ -483,7 +488,10 @@ type AppIface interface { DeletePostFiles(post *model.Post) DeletePreferences(userID string, preferences model.Preferences) *model.AppError DeleteReactionForPost(reaction *model.Reaction) *model.AppError + DeleteRemoteCluster(remoteClusterId string) (bool, *model.AppError) DeleteScheme(schemeId string) (*model.Scheme, *model.AppError) + DeleteSharedChannel(channelID string) (bool, error) + DeleteSharedChannelRemote(id string) (bool, error) DeleteSidebarCategory(userID, teamID, categoryId string) *model.AppError DeleteToken(token *model.Token) *model.AppError DisableAutoResponder(userID string, asAdmin bool) *model.AppError @@ -524,6 +532,7 @@ type AppIface interface { GetAllPublicTeams() ([]*model.Team, *model.AppError) GetAllPublicTeamsPage(offset int, limit int) ([]*model.Team, *model.AppError) GetAllPublicTeamsPageWithCount(offset int, limit int) (*model.TeamsWithCount, *model.AppError) + GetAllRemoteClusters(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, *model.AppError) GetAllRoles() ([]*model.Role, *model.AppError) GetAllStatuses() map[string]*model.Status GetAllTeams() ([]*model.Team, *model.AppError) @@ -626,7 +635,7 @@ type AppIface interface { GetOAuthSignupEndpoint(w http.ResponseWriter, r *http.Request, service, teamID string) (string, *model.AppError) GetOAuthStateToken(token string) (*model.Token, *model.AppError) GetOpenGraphMetadata(requestURL string) *opengraph.OpenGraph - GetOrCreateDirectChannel(userID, otherUserID string) (*model.Channel, *model.AppError) + GetOrCreateDirectChannel(userID, otherUserID string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) GetOutgoingWebhook(hookID string) (*model.OutgoingWebhook, *model.AppError) GetOutgoingWebhooksForChannelPageByUser(channelID string, userID string, page, perPage int) ([]*model.OutgoingWebhook, *model.AppError) GetOutgoingWebhooksForTeamPage(teamID string, page, perPage int) ([]*model.OutgoingWebhook, *model.AppError) @@ -661,6 +670,10 @@ type AppIface interface { GetReactionsForPost(postID string) ([]*model.Reaction, *model.AppError) GetRecentlyActiveUsersForTeam(teamID string) (map[string]*model.User, *model.AppError) GetRecentlyActiveUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError) + GetRemoteCluster(remoteClusterId string) (*model.RemoteCluster, *model.AppError) + GetRemoteClusterForUser(remoteID string, userID string) (*model.RemoteCluster, *model.AppError) + GetRemoteClusterService() (remotecluster.RemoteClusterServiceIFace, *model.AppError) + GetRemoteClusterSession(token string, remoteId string) (*model.Session, *model.AppError) GetRole(id string) (*model.Role, *model.AppError) GetRoleByName(name string) (*model.Role, *model.AppError) GetRolesByNames(names []string) ([]*model.Role, *model.AppError) @@ -676,6 +689,13 @@ type AppIface interface { GetSession(token string) (*model.Session, *model.AppError) GetSessionById(sessionID string) (*model.Session, *model.AppError) GetSessions(userID string) ([]*model.Session, *model.AppError) + GetSharedChannel(channelID string) (*model.SharedChannel, error) + GetSharedChannelRemote(id string) (*model.SharedChannelRemote, error) + GetSharedChannelRemoteByIds(channelID string, remoteID string) (*model.SharedChannelRemote, error) + GetSharedChannelRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) + GetSharedChannelRemotesStatus(channelID string) ([]*model.SharedChannelRemoteStatus, error) + GetSharedChannels(page int, perPage int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, *model.AppError) + GetSharedChannelsCount(opts model.SharedChannelFilterOpts) (int64, error) GetSidebarCategories(userID, teamID string) (*model.OrderedSidebarCategories, *model.AppError) GetSidebarCategory(categoryId string) (*model.SidebarCategoryWithChannels, *model.AppError) GetSidebarCategoryOrder(userID, teamID string) ([]string, *model.AppError) @@ -754,6 +774,7 @@ type AppIface interface { HasPermissionToChannelByPost(askingUserId string, postID string, permission *model.Permission) bool HasPermissionToTeam(askingUserId string, teamID string, permission *model.Permission) bool HasPermissionToUser(askingUserId string, userID string) bool + HasSharedChannel(channelID string) (bool, error) HubStop() ImageProxy() *imageproxy.ImageProxy ImageProxyAdder() func(string) string @@ -884,6 +905,8 @@ type AppIface interface { SaveBrandImage(imageData *multipart.FileHeader) *model.AppError SaveComplianceReport(job *model.Compliance) (*model.Compliance, *model.AppError) SaveReactionForPost(reaction *model.Reaction) (*model.Reaction, *model.AppError) + SaveSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) + SaveSharedChannelRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) SaveUserTermsOfService(userID, termsOfServiceId string, accepted bool) *model.AppError SchemesIterator(scope string, batchSize int) func() []*model.Scheme SearchArchivedChannels(teamID string, term string, userID string) (*model.ChannelList, *model.AppError) @@ -942,6 +965,7 @@ type AppIface interface { SetProfileImage(userID string, imageData *multipart.FileHeader) *model.AppError SetProfileImageFromFile(userID string, file io.Reader) *model.AppError SetProfileImageFromMultiPartFile(userID string, file multipart.File) *model.AppError + SetRemoteClusterLastPingAt(remoteClusterId string) *model.AppError SetRequestId(s string) SetSamlIdpCertificateFromMetadata(data []byte) *model.AppError SetSearchEngine(se *searchengine.Broker) @@ -1009,9 +1033,13 @@ type AppIface interface { UpdatePasswordSendEmail(user *model.User, newPassword, method string) *model.AppError UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model.AppError) UpdatePreferences(userID string, preferences model.Preferences) *model.AppError + UpdateRemoteCluster(rc *model.RemoteCluster) (*model.RemoteCluster, *model.AppError) + UpdateRemoteClusterTopics(remoteClusterId string, topics string) (*model.RemoteCluster, *model.AppError) UpdateRole(role *model.Role) (*model.Role, *model.AppError) UpdateScheme(scheme *model.Scheme) (*model.Scheme, *model.AppError) UpdateSessionsIsGuest(userID string, isGuest bool) + UpdateSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) + UpdateSharedChannelRemoteNextSyncAt(id string, syncTime int64) error UpdateSidebarCategories(userID, teamID string, categories []*model.SidebarCategoryWithChannels) ([]*model.SidebarCategoryWithChannels, *model.AppError) UpdateSidebarCategoryOrder(userID, teamID string, categoryOrder []string) *model.AppError UpdateTeam(team *model.Team) (*model.Team, *model.AppError) diff --git a/app/authentication.go b/app/authentication.go index cfafdefde05..7d9b142c1b8 100644 --- a/app/authentication.go +++ b/app/authentication.go @@ -20,6 +20,7 @@ const ( TokenLocationCookie TokenLocationQueryString TokenLocationCloudHeader + TokenLocationRemoteClusterHeader ) func (tl TokenLocation) String() string { @@ -34,6 +35,8 @@ func (tl TokenLocation) String() string { return "QueryString" case TokenLocationCloudHeader: return "CloudHeader" + case TokenLocationRemoteClusterHeader: + return "RemoteClusterHeader" default: return "Unknown" } @@ -291,5 +294,9 @@ func ParseAuthTokenFromRequest(r *http.Request) (string, TokenLocation) { return token, TokenLocationCloudHeader } + if token := r.Header.Get(model.HEADER_REMOTECLUSTER_TOKEN); token != "" { + return token, TokenLocationRemoteClusterHeader + } + return "", TokenLocationNotFound } diff --git a/app/channel.go b/app/channel.go index 261be278656..3a0f4bb6765 100644 --- a/app/channel.go +++ b/app/channel.go @@ -125,7 +125,6 @@ func (a *App) JoinDefaultChannels(teamID string, user *model.User, shouldBeAdmin message.Add("user_id", user.Id) message.Add("team_id", channel.TeamId) a.Publish(message) - } if nErr != nil { @@ -322,7 +321,7 @@ func (a *App) CreateChannel(channel *model.Channel, addMember bool) (*model.Chan return sc, nil } -func (a *App) GetOrCreateDirectChannel(userID, otherUserID string) (*model.Channel, *model.AppError) { +func (a *App) GetOrCreateDirectChannel(userID, otherUserID string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) { channel, nErr := a.getDirectChannel(userID, otherUserID) if nErr != nil { return nil, nErr @@ -332,7 +331,7 @@ func (a *App) GetOrCreateDirectChannel(userID, otherUserID string) (*model.Chann return channel, nil } - channel, err := a.createDirectChannel(userID, otherUserID) + channel, err := a.createDirectChannel(userID, otherUserID, channelOptions...) if err != nil { if err.Id == store.ChannelExistsError { return channel, nil @@ -381,11 +380,12 @@ func (a *App) handleCreationEvent(userID, otherUserID string, channel *model.Cha } message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_DIRECT_ADDED, "", channel.Id, "", nil) + message.Add("creator_id", userID) message.Add("teammate_id", otherUserID) a.Publish(message) } -func (a *App) createDirectChannel(userID, otherUserID string) (*model.Channel, *model.AppError) { +func (a *App) createDirectChannel(userID string, otherUserID string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) { users, err := a.Srv().Store.User().GetMany(context.Background(), []string{userID, otherUserID}) if err != nil { return nil, model.NewAppError("CreateDirectChannel", "api.channel.create_direct_channel.invalid_user.app_error", nil, err.Error(), http.StatusBadRequest) @@ -415,11 +415,11 @@ func (a *App) createDirectChannel(userID, otherUserID string) (*model.Channel, * user = users[1] otherUser = users[0] } - return a.createDirectChannelWithUser(user, otherUser) + return a.createDirectChannelWithUser(user, otherUser, channelOptions...) } -func (a *App) createDirectChannelWithUser(user, otherUser *model.User) (*model.Channel, *model.AppError) { - channel, nErr := a.Srv().Store.Channel().CreateDirectChannel(user, otherUser) +func (a *App) createDirectChannelWithUser(user, otherUser *model.User, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) { + channel, nErr := a.Srv().Store.Channel().CreateDirectChannel(user, otherUser, channelOptions...) if nErr != nil { var invErr *store.ErrInvalidInput var cErr *store.ErrConflict @@ -460,6 +460,27 @@ func (a *App) createDirectChannelWithUser(user, otherUser *model.User) (*model.C } } + // When the newly created channel is shared and the creator is local + // create a local shared channel record + if channel.IsShared() && !user.IsRemote() { + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + Home: true, + ReadOnly: false, + ShareName: channel.Name, + ShareDisplayName: channel.DisplayName, + SharePurpose: channel.Purpose, + ShareHeader: channel.Header, + CreatorId: user.Id, + Type: channel.Type, + } + + if _, err := a.SaveSharedChannel(sc); err != nil { + return nil, model.NewAppError("CreateDirectChannel", "app.sharedchannel.dm_channel_creation.internal_error", nil, err.Error(), http.StatusInternalServerError) + } + } + return channel, nil } diff --git a/app/command_autocomplete_test.go b/app/command_autocomplete_test.go index bc257c79ace..d099374f610 100644 --- a/app/command_autocomplete_test.go +++ b/app/command_autocomplete_test.go @@ -617,7 +617,7 @@ func TestDynamicListArgsForBuiltin(t *testing.T) { th := Setup(t) defer th.TearDown() - provider := &testProvider{} + provider := &testCommandProvider{} RegisterCommandProvider(provider) command := provider.GetCommand(th.App, nil) @@ -633,18 +633,18 @@ func TestDynamicListArgsForBuiltin(t *testing.T) { t.Run("GetAutoCompleteListItems bad arg", func(t *testing.T) { suggestions := th.App.getSuggestions(emptyCmdArgs, []*model.AutocompleteData{command.AutocompleteData}, "", "bogus --badArg ", model.SYSTEM_ADMIN_ROLE_ID) - assert.Len(t, suggestions, 0) + assert.Empty(t, suggestions) }) } -type testProvider struct { +type testCommandProvider struct { } -func (p *testProvider) GetTrigger() string { +func (p *testCommandProvider) GetTrigger() string { return "bogus" } -func (p *testProvider) GetCommand(a *App, T i18n.TranslateFunc) *model.Command { +func (p *testCommandProvider) GetCommand(a *App, T i18n.TranslateFunc) *model.Command { top := model.NewAutocompleteData(p.GetTrigger(), "[command]", "Just a test.") top.AddNamedDynamicListArgument("dynaArg", "A dynamic list", "builtin:bogus", true) @@ -658,14 +658,14 @@ func (p *testProvider) GetCommand(a *App, T i18n.TranslateFunc) *model.Command { } } -func (p *testProvider) DoCommand(a *App, args *model.CommandArgs, message string) *model.CommandResponse { +func (p *testCommandProvider) DoCommand(a *App, args *model.CommandArgs, message string) *model.CommandResponse { return &model.CommandResponse{ Text: "I do nothing!", ResponseType: model.COMMAND_RESPONSE_TYPE_EPHEMERAL, } } -func (p *testProvider) GetAutoCompleteListItems(a *App, commandArgs *model.CommandArgs, arg *model.AutocompleteArg, parsed, toBeParsed string) ([]model.AutocompleteListItem, error) { +func (p *testCommandProvider) GetAutoCompleteListItems(a *App, commandArgs *model.CommandArgs, arg *model.AutocompleteArg, parsed, toBeParsed string) ([]model.AutocompleteListItem, error) { if arg.Name == "dynaArg" { return []model.AutocompleteListItem{ {Item: "item1", Hint: "this is hint 1", HelpText: "This is help text 1."}, diff --git a/app/helper_test.go b/app/helper_test.go index 9802cc9d9e2..aea7a7fae9f 100644 --- a/app/helper_test.go +++ b/app/helper_test.go @@ -282,15 +282,23 @@ func (th *TestHelper) CreateBot() *model.Bot { return bot } -func (th *TestHelper) CreateChannel(team *model.Team) *model.Channel { - return th.createChannel(team, model.CHANNEL_OPEN) +type ChannelOption func(*model.Channel) + +func WithShared(v bool) ChannelOption { + return func(channel *model.Channel) { + channel.Shared = model.NewBool(v) + } +} + +func (th *TestHelper) CreateChannel(team *model.Team, options ...ChannelOption) *model.Channel { + return th.createChannel(team, model.CHANNEL_OPEN, options...) } func (th *TestHelper) CreatePrivateChannel(team *model.Team) *model.Channel { return th.createChannel(team, model.CHANNEL_PRIVATE) } -func (th *TestHelper) createChannel(team *model.Team, channelType string) *model.Channel { +func (th *TestHelper) createChannel(team *model.Team, channelType string, options ...ChannelOption) *model.Channel { id := model.NewId() channel := &model.Channel{ @@ -301,10 +309,31 @@ func (th *TestHelper) createChannel(team *model.Team, channelType string) *model CreatorId: th.BasicUser.Id, } + for _, option := range options { + option(channel) + } + utils.DisableDebugLogForTest() - var err *model.AppError - if channel, err = th.App.CreateChannel(channel, true); err != nil { - panic(err) + var appErr *model.AppError + if channel, appErr = th.App.CreateChannel(channel, true); appErr != nil { + panic(appErr) + } + + if channel.IsShared() { + id := model.NewId() + _, err := th.App.SaveSharedChannel(&model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + Home: false, + ReadOnly: false, + ShareName: "shared-" + id, + ShareDisplayName: "shared-" + id, + CreatorId: th.BasicUser.Id, + RemoteId: model.NewId(), + }) + if err != nil { + panic(err) + } } utils.EnableDebugLogForTest() return channel diff --git a/app/integration_action.go b/app/integration_action.go index 71875246871..fc4cf3ff80d 100644 --- a/app/integration_action.go +++ b/app/integration_action.go @@ -72,7 +72,7 @@ func (a *App) DoPostActionWithCookie(postID, actionId, userID, selectedOption st // Start all queries here for parallel execution pchan := make(chan store.StoreResult, 1) go func() { - post, err := a.Srv().Store.Post().GetSingle(postID) + post, err := a.Srv().Store.Post().GetSingle(postID, false) pchan <- store.StoreResult{Data: post, NErr: err} close(pchan) }() diff --git a/app/integration_action_test.go b/app/integration_action_test.go index beee2628dc3..c914d303e42 100644 --- a/app/integration_action_test.go +++ b/app/integration_action_test.go @@ -419,7 +419,7 @@ func TestPostActionProps(t *testing.T) { require.Nil(t, err) assert.True(t, len(clientTriggerId) == 26) - newPost, nErr := th.App.Srv().Store.Post().GetSingle(post.Id) + newPost, nErr := th.App.Srv().Store.Post().GetSingle(post.Id, false) require.NoError(t, nErr) assert.True(t, newPost.IsPinned) diff --git a/app/notification.go b/app/notification.go index 98ae79ae799..6c6fc78f739 100644 --- a/app/notification.go +++ b/app/notification.go @@ -423,6 +423,7 @@ func (a *App) SendNotifications(post *model.Post, team *model.Team, channel *mod } a.Publish(message) + // If this is a reply in a thread, notify participants if a.Config().FeatureFlags.CollapsedThreads && *a.Config().ServiceSettings.CollapsedThreads != model.COLLAPSED_THREADS_DISABLED && post.RootId != "" { thread, err := a.Srv().Store.Thread().Get(post.RootId) diff --git a/app/opentracing/opentracing_layer.go b/app/opentracing/opentracing_layer.go index c0873aeb05c..c59bdbf8b55 100644 --- a/app/opentracing/opentracing_layer.go +++ b/app/opentracing/opentracing_layer.go @@ -25,6 +25,7 @@ import ( "github.com/mattermost/mattermost-server/v5/plugin" "github.com/mattermost/mattermost-server/v5/services/httpservice" "github.com/mattermost/mattermost-server/v5/services/imageproxy" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" "github.com/mattermost/mattermost-server/v5/services/searchengine" "github.com/mattermost/mattermost-server/v5/services/timezones" "github.com/mattermost/mattermost-server/v5/services/tracing" @@ -235,6 +236,28 @@ func (a *OpenTracingAppLayer) AddPublicKey(name string, key io.Reader) *model.Ap return resultVar0 } +func (a *OpenTracingAppLayer) AddRemoteCluster(rc *model.RemoteCluster) (*model.RemoteCluster, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.AddRemoteCluster") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.AddRemoteCluster(rc) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) AddSamlIdpCertificate(fileData *multipart.FileHeader) *model.AppError { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.AddSamlIdpCertificate") @@ -1053,6 +1076,28 @@ func (a *OpenTracingAppLayer) CheckAndSendUserLimitWarningEmails() *model.AppErr return resultVar0 } +func (a *OpenTracingAppLayer) CheckCanInviteToSharedChannel(channelId string) error { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CheckCanInviteToSharedChannel") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0 := a.app.CheckCanInviteToSharedChannel(channelId) + + if resultVar0 != nil { + span.LogFields(spanlog.Error(resultVar0)) + ext.Error.Set(span, true) + } + + return resultVar0 +} + func (a *OpenTracingAppLayer) CheckForClientSideCert(r *http.Request) (string, string, string) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CheckForClientSideCert") @@ -3088,6 +3133,28 @@ func (a *OpenTracingAppLayer) DeleteReactionForPost(reaction *model.Reaction) *m return resultVar0 } +func (a *OpenTracingAppLayer) DeleteRemoteCluster(remoteClusterId string) (bool, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DeleteRemoteCluster") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.DeleteRemoteCluster(remoteClusterId) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) DeleteScheme(schemeId string) (*model.Scheme, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DeleteScheme") @@ -3110,6 +3177,50 @@ func (a *OpenTracingAppLayer) DeleteScheme(schemeId string) (*model.Scheme, *mod return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) DeleteSharedChannel(channelID string) (bool, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DeleteSharedChannel") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.DeleteSharedChannel(channelID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) DeleteSharedChannelRemote(id string) (bool, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DeleteSharedChannelRemote") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.DeleteSharedChannelRemote(id) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) DeleteSidebarCategory(userID string, teamID string, categoryId string) *model.AppError { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DeleteSidebarCategory") @@ -4240,6 +4351,28 @@ func (a *OpenTracingAppLayer) GetAllPublicTeamsPageWithCount(offset int, limit i return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetAllRemoteClusters(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetAllRemoteClusters") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetAllRemoteClusters(filter) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetAllRoles() ([]*model.Role, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetAllRoles") @@ -6742,7 +6875,7 @@ func (a *OpenTracingAppLayer) GetOpenGraphMetadata(requestURL string) *opengraph return resultVar0 } -func (a *OpenTracingAppLayer) GetOrCreateDirectChannel(userID string, otherUserID string) (*model.Channel, *model.AppError) { +func (a *OpenTracingAppLayer) GetOrCreateDirectChannel(userID string, otherUserID string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetOrCreateDirectChannel") @@ -6754,7 +6887,7 @@ func (a *OpenTracingAppLayer) GetOrCreateDirectChannel(userID string, otherUserI }() defer span.Finish() - resultVar0, resultVar1 := a.app.GetOrCreateDirectChannel(userID, otherUserID) + resultVar0, resultVar1 := a.app.GetOrCreateDirectChannel(userID, otherUserID, channelOptions...) if resultVar1 != nil { span.LogFields(spanlog.Error(resultVar1)) @@ -7629,6 +7762,94 @@ func (a *OpenTracingAppLayer) GetRecentlyActiveUsersForTeamPage(teamID string, p return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetRemoteCluster(remoteClusterId string) (*model.RemoteCluster, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetRemoteCluster") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetRemoteCluster(remoteClusterId) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetRemoteClusterForUser(remoteID string, userID string) (*model.RemoteCluster, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetRemoteClusterForUser") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetRemoteClusterForUser(remoteID, userID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetRemoteClusterService() (remotecluster.RemoteClusterServiceIFace, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetRemoteClusterService") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetRemoteClusterService() + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetRemoteClusterSession(token string, remoteId string) (*model.Session, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetRemoteClusterSession") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetRemoteClusterSession(token, remoteId) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetRole(id string) (*model.Role, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetRole") @@ -8005,6 +8226,160 @@ func (a *OpenTracingAppLayer) GetSessions(userID string) ([]*model.Session, *mod return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetSharedChannel(channelID string) (*model.SharedChannel, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannel") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannel(channelID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetSharedChannelRemote(id string) (*model.SharedChannelRemote, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannelRemote") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannelRemote(id) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetSharedChannelRemoteByIds(channelID string, remoteID string) (*model.SharedChannelRemote, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannelRemoteByIds") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannelRemoteByIds(channelID, remoteID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetSharedChannelRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannelRemotes") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannelRemotes(opts) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetSharedChannelRemotesStatus(channelID string) ([]*model.SharedChannelRemoteStatus, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannelRemotesStatus") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannelRemotesStatus(channelID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetSharedChannels(page int, perPage int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannels") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannels(page, perPage, opts) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) GetSharedChannelsCount(opts model.SharedChannelFilterOpts) (int64, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSharedChannelsCount") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSharedChannelsCount(opts) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetSidebarCategories(userID string, teamID string) (*model.OrderedSidebarCategories, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSidebarCategories") @@ -9711,6 +10086,50 @@ func (a *OpenTracingAppLayer) HasPermissionToUser(askingUserId string, userID st return resultVar0 } +func (a *OpenTracingAppLayer) HasRemote(channelID string, remoteID string) (bool, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HasRemote") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.HasRemote(channelID, remoteID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) HasSharedChannel(channelID string) (bool, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HasSharedChannel") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.HasSharedChannel(channelID) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) HubRegister(webConn *app.WebConn) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HubRegister") @@ -12710,6 +13129,50 @@ func (a *OpenTracingAppLayer) SaveReactionForPost(reaction *model.Reaction) (*mo return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) SaveSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.SaveSharedChannel") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.SaveSharedChannel(sc) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) SaveSharedChannelRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.SaveSharedChannelRemote") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.SaveSharedChannelRemote(remote) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) SaveUserTermsOfService(userID string, termsOfServiceId string, accepted bool) *model.AppError { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.SaveUserTermsOfService") @@ -13974,6 +14437,28 @@ func (a *OpenTracingAppLayer) SetProfileImageFromMultiPartFile(userID string, fi return resultVar0 } +func (a *OpenTracingAppLayer) SetRemoteClusterLastPingAt(remoteClusterId string) *model.AppError { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.SetRemoteClusterLastPingAt") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0 := a.app.SetRemoteClusterLastPingAt(remoteClusterId) + + if resultVar0 != nil { + span.LogFields(spanlog.Error(resultVar0)) + ext.Error.Set(span, true) + } + + return resultVar0 +} + func (a *OpenTracingAppLayer) SetSamlIdpCertificateFromMetadata(data []byte) *model.AppError { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.SetSamlIdpCertificateFromMetadata") @@ -15380,6 +15865,50 @@ func (a *OpenTracingAppLayer) UpdateProductNotices() *model.AppError { return resultVar0 } +func (a *OpenTracingAppLayer) UpdateRemoteCluster(rc *model.RemoteCluster) (*model.RemoteCluster, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateRemoteCluster") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.UpdateRemoteCluster(rc) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) UpdateRemoteClusterTopics(remoteClusterId string, topics string) (*model.RemoteCluster, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateRemoteClusterTopics") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.UpdateRemoteClusterTopics(remoteClusterId, topics) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) UpdateRole(role *model.Role) (*model.Role, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateRole") @@ -15439,6 +15968,50 @@ func (a *OpenTracingAppLayer) UpdateSessionsIsGuest(userID string, isGuest bool) a.app.UpdateSessionsIsGuest(userID, isGuest) } +func (a *OpenTracingAppLayer) UpdateSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateSharedChannel") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.UpdateSharedChannel(sc) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + +func (a *OpenTracingAppLayer) UpdateSharedChannelRemoteNextSyncAt(id string, syncTime int64) error { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateSharedChannelRemoteNextSyncAt") + + a.ctx = newCtx + a.app.Srv().Store.SetContext(newCtx) + defer func() { + a.app.Srv().Store.SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0 := a.app.UpdateSharedChannelRemoteNextSyncAt(id, syncTime) + + if resultVar0 != nil { + span.LogFields(spanlog.Error(resultVar0)) + ext.Error.Set(span, true) + } + + return resultVar0 +} + func (a *OpenTracingAppLayer) UpdateSidebarCategories(userID string, teamID string, categories []*model.SidebarCategoryWithChannels) ([]*model.SidebarCategoryWithChannels, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateSidebarCategories") diff --git a/app/plugin_hooks_test.go b/app/plugin_hooks_test.go index 98a00e6fefc..0b159355663 100644 --- a/app/plugin_hooks_test.go +++ b/app/plugin_hooks_test.go @@ -181,7 +181,7 @@ func TestHookMessageWillBePosted(t *testing.T) { require.Nil(t, err) assert.Equal(t, "message", post.Message) - retrievedPost, errSingle := th.App.Srv().Store.Post().GetSingle(post.Id) + retrievedPost, errSingle := th.App.Srv().Store.Post().GetSingle(post.Id, false) require.NoError(t, errSingle) assert.Equal(t, "message", retrievedPost.Message) }) @@ -225,7 +225,7 @@ func TestHookMessageWillBePosted(t *testing.T) { require.Nil(t, err) assert.Equal(t, "message_fromplugin", post.Message) - retrievedPost, errSingle := th.App.Srv().Store.Post().GetSingle(post.Id) + retrievedPost, errSingle := th.App.Srv().Store.Post().GetSingle(post.Id, false) require.NoError(t, errSingle) assert.Equal(t, "message_fromplugin", retrievedPost.Message) }) diff --git a/app/post.go b/app/post.go index 4260e5bb7d7..8c388ee5de0 100644 --- a/app/post.go +++ b/app/post.go @@ -610,6 +610,10 @@ func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model return nil, err } + if post.IsRemote() { + oldPost.RemoteId = model.NewString(*post.RemoteId) + } + if pluginsEnvironment := a.GetPluginsEnvironment(); pluginsEnvironment != nil { var rejectionReason string pluginContext := a.PluginContext() @@ -728,7 +732,7 @@ func (a *App) GetPostsSince(options model.GetPostsSinceOptions) (*model.PostList } func (a *App) GetSinglePost(postID string) (*model.Post, *model.AppError) { - post, err := a.Srv().Store.Post().GetSingle(postID) + post, err := a.Srv().Store.Post().GetSingle(postID, false) if err != nil { var nfErr *store.ErrNotFound switch { @@ -1012,7 +1016,7 @@ func (a *App) GetPostsForChannelAroundLastUnread(channelID, userID string, limit } func (a *App) DeletePost(postID, deleteByID string) (*model.Post, *model.AppError) { - post, nErr := a.Srv().Store.Post().GetSingle(postID) + post, nErr := a.Srv().Store.Post().GetSingle(postID, false) if nErr != nil { return nil, model.NewAppError("DeletePost", "app.post.get.app_error", nil, nErr.Error(), http.StatusBadRequest) } @@ -1237,7 +1241,7 @@ func (a *App) GetFileInfosForPostWithMigration(postID string) ([]*model.FileInfo pchan := make(chan store.StoreResult, 1) go func() { - post, err := a.Srv().Store.Post().GetSingle(postID) + post, err := a.Srv().Store.Post().GetSingle(postID, false) pchan <- store.StoreResult{Data: post, NErr: err} close(pchan) }() diff --git a/app/post_test.go b/app/post_test.go index 775da0a01da..6832f790a57 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -2053,3 +2053,87 @@ func TestReplyToPostWithLag(t *testing.T) { require.NotNil(t, reply) }) } + +func TestSharedChannelSyncForPostActions(t *testing.T) { + t.Run("creating a post in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteClusterService := NewMockSharedChannelService(nil) + th.App.srv.sharedChannelService = remoteClusterService + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + user := th.BasicUser + + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + + _, err := th.App.CreatePost(&model.Post{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "Hello folks", + }, channel, false, true) + require.Nil(t, err, "Creating a post should not error") + + assert.Len(t, remoteClusterService.notifications, 1) + assert.Equal(t, channel.Id, remoteClusterService.notifications[0]) + }) + + t.Run("updating a post in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteClusterService := NewMockSharedChannelService(nil) + th.App.srv.sharedChannelService = remoteClusterService + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + user := th.BasicUser + + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + + post, err := th.App.CreatePost(&model.Post{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "Hello folks", + }, channel, false, true) + require.Nil(t, err, "Creating a post should not error") + + _, err = th.App.UpdatePost(post, true) + require.Nil(t, err, "Updating a post should not error") + + assert.Len(t, remoteClusterService.notifications, 2) + assert.Equal(t, channel.Id, remoteClusterService.notifications[0]) + assert.Equal(t, channel.Id, remoteClusterService.notifications[1]) + }) + + t.Run("deleting a post in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteClusterService := NewMockSharedChannelService(nil) + th.App.srv.sharedChannelService = remoteClusterService + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + user := th.BasicUser + + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + + post, err := th.App.CreatePost(&model.Post{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "Hello folks", + }, channel, false, true) + require.Nil(t, err, "Creating a post should not error") + + _, err = th.App.DeletePost(post.Id, user.Id) + require.Nil(t, err, "Deleting a post should not error") + + // one creation and two deletes + assert.Len(t, remoteClusterService.notifications, 3) + assert.Equal(t, channel.Id, remoteClusterService.notifications[0]) + assert.Equal(t, channel.Id, remoteClusterService.notifications[1]) + assert.Equal(t, channel.Id, remoteClusterService.notifications[2]) + }) +} diff --git a/app/reaction_test.go b/app/reaction_test.go new file mode 100644 index 00000000000..9acf3f2d333 --- /dev/null +++ b/app/reaction_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/testlib" +) + +func TestSharedChannelSyncForReactionActions(t *testing.T) { + t.Run("adding a reaction in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { + th := Setup(t).InitBasic() + + sharedChannelService := NewMockSharedChannelService(nil) + th.App.srv.sharedChannelService = sharedChannelService + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + user := th.BasicUser + + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + + post, err := th.App.CreatePost(&model.Post{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "Hello folks", + }, channel, false, true) + require.Nil(t, err, "Creating a post should not error") + + reaction := &model.Reaction{ + UserId: user.Id, + PostId: post.Id, + EmojiName: "+1", + } + + _, err = th.App.SaveReactionForPost(reaction) + require.Nil(t, err, "Adding a reaction should not error") + + th.TearDown() // We need to enforce teardown because reaction instrumentation happens in a goroutine + + assert.Len(t, sharedChannelService.notifications, 2) + assert.Equal(t, channel.Id, sharedChannelService.notifications[0]) + assert.Equal(t, channel.Id, sharedChannelService.notifications[1]) + }) + + t.Run("removing a reaction in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { + th := Setup(t).InitBasic() + + sharedChannelService := NewMockSharedChannelService(nil) + th.App.srv.sharedChannelService = sharedChannelService + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + user := th.BasicUser + + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + + post, err := th.App.CreatePost(&model.Post{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "Hello folks", + }, channel, false, true) + require.Nil(t, err, "Creating a post should not error") + + reaction := &model.Reaction{ + UserId: user.Id, + PostId: post.Id, + EmojiName: "+1", + } + + err = th.App.DeleteReactionForPost(reaction) + require.Nil(t, err, "Adding a reaction should not error") + + th.TearDown() // We need to enforce teardown because reaction instrumentation happens in a goroutine + + assert.Len(t, sharedChannelService.notifications, 2) + assert.Equal(t, channel.Id, sharedChannelService.notifications[0]) + assert.Equal(t, channel.Id, sharedChannelService.notifications[1]) + }) +} diff --git a/app/remote_cluster.go b/app/remote_cluster.go new file mode 100644 index 00000000000..d9f544a789b --- /dev/null +++ b/app/remote_cluster.go @@ -0,0 +1,87 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "net/http" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/store/sqlstore" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func (a *App) AddRemoteCluster(rc *model.RemoteCluster) (*model.RemoteCluster, *model.AppError) { + rc, err := a.Srv().Store.RemoteCluster().Save(rc) + if err != nil { + if sqlstore.IsUniqueConstraintError(errors.Cause(err), []string{sqlstore.RemoteClusterSiteURLUniqueIndex}) { + return nil, model.NewAppError("AddRemoteCluster", "api.remote_cluster.save_not_unique.app_error", nil, err.Error(), http.StatusInternalServerError) + } + + return nil, model.NewAppError("AddRemoteCluster", "api.remote_cluster.save.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return rc, nil +} + +func (a *App) UpdateRemoteCluster(rc *model.RemoteCluster) (*model.RemoteCluster, *model.AppError) { + rc, err := a.Srv().Store.RemoteCluster().Update(rc) + if err != nil { + if sqlstore.IsUniqueConstraintError(errors.Cause(err), []string{sqlstore.RemoteClusterSiteURLUniqueIndex}) { + return nil, model.NewAppError("UpdateRemoteCluster", "api.remote_cluster.update_not_unique.app_error", nil, err.Error(), http.StatusInternalServerError) + } + + return nil, model.NewAppError("UpdateRemoteCluster", "api.remote_cluster.update.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return rc, nil +} + +func (a *App) DeleteRemoteCluster(remoteClusterId string) (bool, *model.AppError) { + deleted, err := a.Srv().Store.RemoteCluster().Delete(remoteClusterId) + if err != nil { + return false, model.NewAppError("DeleteRemoteCluster", "api.remote_cluster.delete.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return deleted, nil +} + +func (a *App) GetRemoteCluster(remoteClusterId string) (*model.RemoteCluster, *model.AppError) { + rc, err := a.Srv().Store.RemoteCluster().Get(remoteClusterId) + if err != nil { + return nil, model.NewAppError("GetRemoteCluster", "api.remote_cluster.get.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return rc, nil +} + +func (a *App) GetAllRemoteClusters(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, *model.AppError) { + list, err := a.Srv().Store.RemoteCluster().GetAll(filter) + if err != nil { + return nil, model.NewAppError("GetAllRemoteClusters", "api.remote_cluster.get.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return list, nil +} + +func (a *App) UpdateRemoteClusterTopics(remoteClusterId string, topics string) (*model.RemoteCluster, *model.AppError) { + rc, err := a.Srv().Store.RemoteCluster().UpdateTopics(remoteClusterId, topics) + if err != nil { + return nil, model.NewAppError("UpdateRemoteClusterTopics", "api.remote_cluster.save.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return rc, nil +} + +func (a *App) SetRemoteClusterLastPingAt(remoteClusterId string) *model.AppError { + err := a.Srv().Store.RemoteCluster().SetLastPingAt(remoteClusterId) + if err != nil { + return model.NewAppError("SetRemoteClusterLastPingAt", "api.remote_cluster.save.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return nil +} + +func (a *App) GetRemoteClusterService() (remotecluster.RemoteClusterServiceIFace, *model.AppError) { + service := a.Srv().GetRemoteClusterService() + if service == nil { + return nil, model.NewAppError("GetRemoteClusterService", "api.remote_cluster.service_not_enabled.app_error", nil, "", http.StatusNotImplemented) + } + return service, nil +} diff --git a/app/remote_cluster_service_mock.go b/app/remote_cluster_service_mock.go new file mode 100644 index 00000000000..0e85e2d8760 --- /dev/null +++ b/app/remote_cluster_service_mock.go @@ -0,0 +1,75 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "context" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" +) + +// MockOptionRemoteClusterService a mock of the remote cluster service +type MockOptionRemoteClusterService func(service *mockRemoteClusterService) + +func MockOptionRemoteClusterServiceWithActive(active bool) MockOptionRemoteClusterService { + return func(mrcs *mockRemoteClusterService) { + mrcs.active = active + } +} + +func NewMockRemoteClusterService(service remotecluster.RemoteClusterServiceIFace, options ...MockOptionRemoteClusterService) *mockRemoteClusterService { + mrcs := &mockRemoteClusterService{service, true} + for _, option := range options { + option(mrcs) + } + return mrcs +} + +type mockRemoteClusterService struct { + remotecluster.RemoteClusterServiceIFace + active bool +} + +func (mrcs *mockRemoteClusterService) Shutdown() error { + return nil +} + +func (mrcs *mockRemoteClusterService) Start() error { + return nil +} + +func (mrcs *mockRemoteClusterService) Active() bool { + return mrcs.active +} + +func (mrcs *mockRemoteClusterService) AddTopicListener(topic string, listener remotecluster.TopicListener) string { + return model.NewId() +} + +func (mrcs *mockRemoteClusterService) RemoveTopicListener(listenerId string) { +} + +func (mrcs *mockRemoteClusterService) AddConnectionStateListener(listener remotecluster.ConnectionStateListener) string { + return model.NewId() +} + +func (mrcs *mockRemoteClusterService) RemoveConnectionStateListener(listenerId string) { +} + +func (mrcs *mockRemoteClusterService) SendMsg(ctx context.Context, msg model.RemoteClusterMsg, rc *model.RemoteCluster, f remotecluster.SendMsgResultFunc) error { + return nil +} + +func (mrcs *mockRemoteClusterService) SendFile(ctx context.Context, us *model.UploadSession, fi *model.FileInfo, rc *model.RemoteCluster, rp remotecluster.ReaderProvider, f remotecluster.SendFileResultFunc) error { + return nil +} + +func (mrcs *mockRemoteClusterService) AcceptInvitation(invite *model.RemoteClusterInvite, name string, creatorId string, teamId string, siteURL string) (*model.RemoteCluster, error) { + return nil, nil +} + +func (mrcs *mockRemoteClusterService) ReceiveIncomingMsg(rc *model.RemoteCluster, msg model.RemoteClusterMsg) remotecluster.Response { + return remotecluster.Response{} +} diff --git a/app/remote_cluster_test.go b/app/remote_cluster_test.go new file mode 100644 index 00000000000..0afe1797f90 --- /dev/null +++ b/app/remote_cluster_test.go @@ -0,0 +1,152 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func TestAddRemoteCluster(t *testing.T) { + t.Run("adding remote cluster with duplicate site url and remote team id", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteCluster := &model.RemoteCluster{ + RemoteTeamId: model.NewId(), + DisplayName: "test", + SiteURL: "http://localhost:8065", + Token: "test", + RemoteToken: "test", + Topics: "", + CreatorId: th.BasicUser.Id, + } + + _, err := th.App.AddRemoteCluster(remoteCluster) + require.Nil(t, err, "Adding a remote cluster should not error") + + remoteCluster.RemoteId = model.NewId() + _, err = th.App.AddRemoteCluster(remoteCluster) + require.Error(t, err, "Adding a duplicate remote cluster should error") + assert.Contains(t, err.Error(), "Remote cluster has already been added.") + }) + + t.Run("adding remote cluster with duplicate site url or remote team id is allowed", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteCluster := &model.RemoteCluster{ + RemoteTeamId: model.NewId(), + DisplayName: "test", + SiteURL: "http://localhost:8065", + Token: "test", + RemoteToken: "test", + Topics: "", + CreatorId: th.BasicUser.Id, + } + + existingRemoteCluster, err := th.App.AddRemoteCluster(remoteCluster) + require.Nil(t, err, "Adding a remote cluster should not error") + + // Same site url but different remote team id + remoteCluster.RemoteId = model.NewId() + remoteCluster.RemoteTeamId = model.NewId() + remoteCluster.SiteURL = existingRemoteCluster.SiteURL + _, err = th.App.AddRemoteCluster(remoteCluster) + assert.Nil(t, err, "Adding a remote cluster should not error") + + // Same remote team id but different site url + remoteCluster.RemoteId = model.NewId() + remoteCluster.RemoteTeamId = existingRemoteCluster.RemoteTeamId + remoteCluster.SiteURL = existingRemoteCluster.SiteURL + "/new" + _, err = th.App.AddRemoteCluster(remoteCluster) + assert.Nil(t, err, "Adding a remote cluster should not error") + }) +} + +func TestUpdateRemoteCluster(t *testing.T) { + t.Run("update remote cluster with an already existing site url and team id", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteCluster := &model.RemoteCluster{ + RemoteTeamId: model.NewId(), + DisplayName: "test", + SiteURL: "http://localhost:8065", + Token: "test", + RemoteToken: "test", + Topics: "", + CreatorId: th.BasicUser.Id, + } + + otherRemoteCluster := &model.RemoteCluster{ + RemoteTeamId: model.NewId(), + DisplayName: "test", + SiteURL: "http://localhost:8066", + Token: "test", + RemoteToken: "test", + Topics: "", + CreatorId: th.BasicUser.Id, + } + + _, err := th.App.AddRemoteCluster(remoteCluster) + require.Nil(t, err, "Adding a remote cluster should not error") + + savedRemoteClustered, err := th.App.AddRemoteCluster(otherRemoteCluster) + require.Nil(t, err, "Adding a remote cluster should not error") + + savedRemoteClustered.SiteURL = remoteCluster.SiteURL + savedRemoteClustered.RemoteTeamId = remoteCluster.RemoteTeamId + _, err = th.App.UpdateRemoteCluster(savedRemoteClustered) + require.Error(t, err, "Updating remote cluster with duplicate site url should error") + assert.Contains(t, err.Error(), "Remote cluster with the same url already exists.") + }) + + t.Run("update remote cluster with an already existing site url or team id, is allowed", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + remoteCluster := &model.RemoteCluster{ + RemoteTeamId: model.NewId(), + DisplayName: "test", + SiteURL: "http://localhost:8065", + Token: "test", + RemoteToken: "test", + Topics: "", + CreatorId: th.BasicUser.Id, + } + + otherRemoteCluster := &model.RemoteCluster{ + RemoteTeamId: model.NewId(), + DisplayName: "test", + SiteURL: "http://localhost:8066", + Token: "test", + RemoteToken: "test", + Topics: "", + CreatorId: th.BasicUser.Id, + } + + existingRemoteCluster, err := th.App.AddRemoteCluster(remoteCluster) + require.Nil(t, err, "Adding a remote cluster should not error") + + anotherExistingRemoteClustered, err := th.App.AddRemoteCluster(otherRemoteCluster) + require.Nil(t, err, "Adding a remote cluster should not error") + + // Same site url but different remote team id + anotherExistingRemoteClustered.SiteURL = existingRemoteCluster.SiteURL + anotherExistingRemoteClustered.RemoteTeamId = model.NewId() + _, err = th.App.UpdateRemoteCluster(anotherExistingRemoteClustered) + assert.Nil(t, err, "Updating remote cluster should not error") + + // Same remote team id but different site url + anotherExistingRemoteClustered.SiteURL = existingRemoteCluster.SiteURL + "/new" + anotherExistingRemoteClustered.RemoteTeamId = existingRemoteCluster.RemoteTeamId + _, err = th.App.UpdateRemoteCluster(anotherExistingRemoteClustered) + assert.Nil(t, err, "Updating remote cluster should not error") + }) +} diff --git a/app/server.go b/app/server.go index b538aea7496..c91d614778f 100644 --- a/app/server.go +++ b/app/server.go @@ -47,8 +47,10 @@ import ( "github.com/mattermost/mattermost-server/v5/services/cache" "github.com/mattermost/mattermost-server/v5/services/httpservice" "github.com/mattermost/mattermost-server/v5/services/imageproxy" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" "github.com/mattermost/mattermost-server/v5/services/searchengine" "github.com/mattermost/mattermost-server/v5/services/searchengine/bleveengine" + "github.com/mattermost/mattermost-server/v5/services/sharedchannel" "github.com/mattermost/mattermost-server/v5/services/telemetry" "github.com/mattermost/mattermost-server/v5/services/timezones" "github.com/mattermost/mattermost-server/v5/services/tracing" @@ -157,6 +159,9 @@ type Server struct { telemetryService *telemetry.TelemetryService + remoteClusterService remotecluster.RemoteClusterServiceIFace + sharedChannelService SharedChannelServiceIFace + phase2PermissionsMigrationComplete bool HTTPService httpservice.HTTPService @@ -808,6 +813,64 @@ func (s *Server) removeUnlicensedLogTargets(license *model.License) { }) } +func (s *Server) startInterClusterServices(license *model.License, app *App) error { + if license == nil { + mlog.Debug("No license provided; Remote Cluster services disabled") + return nil + } + + // Remote Cluster service + + // License check + if !*license.Features.RemoteClusterService { + mlog.Debug("License does not have Remote Cluster services enabled") + return nil + } + + // Config check + if !*s.Config().ExperimentalSettings.EnableRemoteClusterService { + mlog.Debug("Remote Cluster Service disabled via config") + return nil + } + + var err error + + s.remoteClusterService, err = remotecluster.NewRemoteClusterService(s) + if err != nil { + return err + } + + if err = s.remoteClusterService.Start(); err != nil { + s.remoteClusterService = nil + return err + } + + // Shared Channels service + + // License check + if !*license.Features.SharedChannels { + mlog.Debug("License does not have shared channels enabled") + return nil + } + + // Config check + if !*s.Config().ExperimentalSettings.EnableSharedChannels { + mlog.Debug("Shared Channels Service disabled via config") + return nil + } + + s.sharedChannelService, err = sharedchannel.NewSharedChannelService(s, app) + if err != nil { + return err + } + + if err = s.sharedChannelService.Start(); err != nil { + s.remoteClusterService = nil + return err + } + return nil +} + func (s *Server) enableLoggingMetrics() { if s.Metrics == nil { return @@ -866,6 +929,12 @@ func (s *Server) Shutdown() { mlog.Warn("Unable to cleanly shutdown telemetry client", mlog.Err(err)) } + if s.remoteClusterService != nil { + if err = s.remoteClusterService.Shutdown(); err != nil { + mlog.Error("Error shutting down intercluster services", mlog.Err(err)) + } + } + s.StopHTTPServer() s.stopLocalModeServer() // Push notification hub needs to be shutdown after HTTP server @@ -1231,6 +1300,10 @@ func (s *Server) Start() error { } } + if err := s.startInterClusterServices(s.License(), s.WebSocketRouter.app); err != nil { + mlog.Error("Error starting inter-cluster services", mlog.Err(err)) + } + return nil } @@ -1799,6 +1872,46 @@ func (s *Server) SetLog(l *mlog.Logger) { s.Log = l } +func (s *Server) GetLogger() mlog.LoggerIFace { + return s.Log +} + +// GetStore returns the server's Store. Exposing via a method +// allows interfaces to be created with subsets of server APIs. +func (s *Server) GetStore() store.Store { + return s.Store +} + +// GetRemoteClusterService returns the `RemoteClusterService` instantiated by the server. +// May be nil if the service is not enabled via license. +func (s *Server) GetRemoteClusterService() remotecluster.RemoteClusterServiceIFace { + return s.remoteClusterService +} + +// GetSharedChannelSyncService returns the `SharedChannelSyncService` instantiated by the server. +// May be nil if the service is not enabled via license. +func (s *Server) GetSharedChannelSyncService() SharedChannelServiceIFace { + return s.sharedChannelService +} + +// GetMetrics returns the server's Metrics interface. Exposing via a method +// allows interfaces to be created with subsets of server APIs. +func (s *Server) GetMetrics() einterfaces.MetricsInterface { + return s.Metrics +} + +// SetRemoteClusterService sets the `RemoteClusterService` to be used by the server. +// For testing only. +func (s *Server) SetRemoteClusterService(remoteClusterService remotecluster.RemoteClusterServiceIFace) { + s.remoteClusterService = remoteClusterService +} + +// SetSharedChannelSyncService sets the `SharedChannelSyncService` to be used by the server. +// For testing only. +func (s *Server) SetSharedChannelSyncService(sharedChannelService SharedChannelServiceIFace) { + s.sharedChannelService = sharedChannelService +} + func (a *App) GenerateSupportPacket() []model.FileData { // If any errors we come across within this function, we will log it in a warning.txt file so that we know why certain files did not get produced if any var warnings []string diff --git a/app/session.go b/app/session.go index f2f8da2a71f..ceccb4099f5 100644 --- a/app/session.go +++ b/app/session.go @@ -67,6 +67,21 @@ func (a *App) GetCloudSession(token string) (*model.Session, *model.AppError) { return nil, model.NewAppError("GetCloudSession", "api.context.invalid_token.error", map[string]interface{}{"Token": token, "Error": ""}, "The provided token is invalid", http.StatusUnauthorized) } +func (a *App) GetRemoteClusterSession(token string, remoteId string) (*model.Session, *model.AppError) { + rc, appErr := a.GetRemoteCluster(remoteId) + if appErr == nil && rc.Token == token { + // Need a bare-bones session object for later checks + session := &model.Session{ + Token: token, + IsOAuth: false, + } + + session.AddProp(model.SESSION_PROP_TYPE, model.SESSION_TYPE_REMOTECLUSTER_TOKEN) + return session, nil + } + return nil, model.NewAppError("GetRemoteClusterSession", "api.context.invalid_token.error", map[string]interface{}{"Token": token, "Error": ""}, "The provided token is invalid", http.StatusUnauthorized) +} + func (a *App) GetSession(token string) (*model.Session, *model.AppError) { metrics := a.Metrics() diff --git a/app/session_test.go b/app/session_test.go index 2cc7a3d3b5e..977e24d6df4 100644 --- a/app/session_test.go +++ b/app/session_test.go @@ -439,3 +439,39 @@ func TestGetCloudSession(t *testing.T) { require.Equal(t, "api.context.invalid_token.error", err.Id) }) } + +func TestGetRemoteClusterSession(t *testing.T) { + th := Setup(t) + token := model.NewId() + remoteId := model.NewId() + + rc := model.RemoteCluster{ + RemoteId: remoteId, + RemoteTeamId: model.NewId(), + DisplayName: "test", + Token: token, + CreatorId: model.NewId(), + } + + _, err := th.GetSqlStore().RemoteCluster().Save(&rc) + require.NoError(t, err) + + t.Run("Valid remote token should return session", func(t *testing.T) { + session, err := th.App.GetRemoteClusterSession(token, remoteId) + require.Nil(t, err) + require.NotNil(t, session) + require.Equal(t, token, session.Token) + }) + + t.Run("Invalid remote token should return error", func(t *testing.T) { + session, err := th.App.GetRemoteClusterSession(model.NewId(), remoteId) + require.Error(t, err) + require.Nil(t, session) + }) + + t.Run("Invalid remote id should return error", func(t *testing.T) { + session, err := th.App.GetRemoteClusterSession(token, model.NewId()) + require.Error(t, err) + require.Nil(t, session) + }) +} diff --git a/app/shared_channel.go b/app/shared_channel.go new file mode 100644 index 00000000000..7c5a75d3a3e --- /dev/null +++ b/app/shared_channel.go @@ -0,0 +1,149 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "errors" + "fmt" + "net/http" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/store" +) + +func (a *App) checkChannelNotShared(channelId string) error { + // check that channel exists. + if _, err := a.GetChannel(channelId); err != nil { + return fmt.Errorf("cannot share this channel: %w", err) + } + + // Check channel is not already shared. + if _, err := a.GetSharedChannel(channelId); err == nil { + var errNotFound *store.ErrNotFound + if errors.As(err, &errNotFound) { + return errors.New("channel is already shared.") + } + return fmt.Errorf("cannot find channel: %w", err) + } + return nil +} + +func (a *App) checkChannelIsShared(channelId string) error { + if _, err := a.GetSharedChannel(channelId); err != nil { + var errNotFound *store.ErrNotFound + if errors.As(err, &errNotFound) { + return errors.New("channel is not shared.") + } + return fmt.Errorf("cannot find channel: %w", err) + } + return nil +} + +func (a *App) CheckCanInviteToSharedChannel(channelId string) error { + sc, err := a.GetSharedChannel(channelId) + if err != nil { + var errNotFound *store.ErrNotFound + if errors.As(err, &errNotFound) { + return errors.New("channel is not shared.") + } + return fmt.Errorf("cannot find channel: %w", err) + } + + if !sc.Home { + return errors.New("channel is homed on a remote cluster.") + } + return nil +} + +// SharedChannels + +func (a *App) SaveSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) { + if err := a.checkChannelNotShared(sc.ChannelId); err != nil { + return nil, err + } + return a.Srv().Store.SharedChannel().Save(sc) +} + +func (a *App) GetSharedChannel(channelID string) (*model.SharedChannel, error) { + return a.Srv().Store.SharedChannel().Get(channelID) +} + +func (a *App) HasSharedChannel(channelID string) (bool, error) { + return a.Srv().Store.SharedChannel().HasChannel(channelID) +} + +func (a *App) GetSharedChannels(page int, perPage int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, *model.AppError) { + channels, err := a.Srv().Store.SharedChannel().GetAll(page*perPage, perPage, opts) + if err != nil { + return nil, model.NewAppError("GetSharedChannels", "app.channel.get_channels.not_found.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return channels, nil +} + +func (a *App) GetSharedChannelsCount(opts model.SharedChannelFilterOpts) (int64, error) { + return a.Srv().Store.SharedChannel().GetAllCount(opts) +} + +func (a *App) UpdateSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) { + return a.Srv().Store.SharedChannel().Update(sc) +} + +func (a *App) DeleteSharedChannel(channelID string) (bool, error) { + return a.Srv().Store.SharedChannel().Delete(channelID) +} + +// SharedChannelRemotes + +func (a *App) SaveSharedChannelRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + if err := a.checkChannelIsShared(remote.ChannelId); err != nil { + return nil, err + } + return a.Srv().Store.SharedChannel().SaveRemote(remote) +} + +func (a *App) GetSharedChannelRemote(id string) (*model.SharedChannelRemote, error) { + return a.Srv().Store.SharedChannel().GetRemote(id) +} + +func (a *App) GetSharedChannelRemoteByIds(channelID string, remoteID string) (*model.SharedChannelRemote, error) { + return a.Srv().Store.SharedChannel().GetRemoteByIds(channelID, remoteID) +} + +func (a *App) GetSharedChannelRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + return a.Srv().Store.SharedChannel().GetRemotes(opts) +} + +// HasRemote returns whether a given channelID is present in the channel remotes or not. +func (a *App) HasRemote(channelID string, remoteID string) (bool, error) { + return a.Srv().Store.SharedChannel().HasRemote(channelID, remoteID) +} + +func (a *App) GetRemoteClusterForUser(remoteID string, userID string) (*model.RemoteCluster, *model.AppError) { + rc, err := a.Srv().Store.SharedChannel().GetRemoteForUser(remoteID, userID) + if err != nil { + var nfErr *store.ErrNotFound + switch { + case errors.As(err, &nfErr): + return nil, model.NewAppError("GetRemoteClusterForUser", "api.context.remote_id_invalid.app_error", nil, nfErr.Error(), http.StatusNotFound) + default: + return nil, model.NewAppError("GetRemoteClusterForUser", "api.context.remote_id_invalid.app_error", nil, err.Error(), http.StatusInternalServerError) + } + } + return rc, nil +} + +func (a *App) UpdateSharedChannelRemoteNextSyncAt(id string, syncTime int64) error { + return a.Srv().Store.SharedChannel().UpdateRemoteNextSyncAt(id, syncTime) +} + +func (a *App) DeleteSharedChannelRemote(id string) (bool, error) { + return a.Srv().Store.SharedChannel().DeleteRemote(id) +} + +func (a *App) GetSharedChannelRemotesStatus(channelID string) ([]*model.SharedChannelRemoteStatus, error) { + if err := a.checkChannelIsShared(channelID); err != nil { + return nil, err + } + return a.Srv().Store.SharedChannel().GetRemotesStatus(channelID) +} diff --git a/app/shared_channel_notifier.go b/app/shared_channel_notifier.go new file mode 100644 index 00000000000..d2901d1dbe5 --- /dev/null +++ b/app/shared_channel_notifier.go @@ -0,0 +1,144 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "context" + "fmt" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/sharedchannel" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +var sharedChannelEventsForSync model.StringArray = []string{ + model.WEBSOCKET_EVENT_POSTED, + model.WEBSOCKET_EVENT_POST_EDITED, + model.WEBSOCKET_EVENT_POST_DELETED, + model.WEBSOCKET_EVENT_REACTION_ADDED, + model.WEBSOCKET_EVENT_REACTION_REMOVED, +} + +var sharedChannelEventsForInvitation model.StringArray = []string{ + model.WEBSOCKET_EVENT_DIRECT_ADDED, +} + +// SharedChannelSyncHandler is called when a websocket event is received by a cluster node. +// Only on the leader node it will notify the sync service to perform necessary updates to the remote for the given +// shared channel. +func (s *Server) SharedChannelSyncHandler(event *model.WebSocketEvent) { + syncService := s.GetSharedChannelSyncService() + if isEligibleForEvents(syncService, event, sharedChannelEventsForSync) { + err := handleContentSync(s, syncService, event) + if err != nil { + mlog.Warn( + err.Error(), + mlog.String("event", event.EventType()), + mlog.String("action", "content_sync"), + ) + } + } else if isEligibleForEvents(syncService, event, sharedChannelEventsForInvitation) { + err := handleInvitation(s, syncService, event) + if err != nil { + mlog.Warn( + err.Error(), + mlog.String("event", event.EventType()), + mlog.String("action", "invitation"), + ) + } + } +} + +func isEligibleForEvents(syncService SharedChannelServiceIFace, event *model.WebSocketEvent, events model.StringArray) bool { + return syncServiceEnabled(syncService) && + eventHasChannel(event) && + events.Contains(event.EventType()) +} + +func eventHasChannel(event *model.WebSocketEvent) bool { + return event.GetBroadcast() != nil && + event.GetBroadcast().ChannelId != "" +} + +func syncServiceEnabled(syncService SharedChannelServiceIFace) bool { + return syncService != nil && + syncService.Active() +} + +func handleContentSync(s *Server, syncService SharedChannelServiceIFace, event *model.WebSocketEvent) error { + channel, err := findChannel(s, event.GetBroadcast().ChannelId) + if err != nil { + return err + } + + if channel != nil && channel.IsShared() { + syncService.NotifyChannelChanged(channel.Id) + } + + return nil +} + +func handleInvitation(s *Server, syncService SharedChannelServiceIFace, event *model.WebSocketEvent) error { + channel, err := findChannel(s, event.GetBroadcast().ChannelId) + if err != nil { + return err + } + + if channel == nil || !channel.IsShared() { + return nil + } + + creator, err := getUserFromEvent(s, event, "creator_id") + if err != nil { + return err + } + // This is a termination condition, since on the other end when we are processing + // the invite we are re-triggering a model.WEBSOCKET_EVENT_DIRECT_ADDED, which will call this handler. + // When the creator is remote, it means that this is a DM that was not originated from the current server + // and therefore we do not need to do anything. + if creator == nil || creator.IsRemote() { + return nil + } + + participant, err := getUserFromEvent(s, event, "teammate_id") + if err != nil { + return err + } + + if participant == nil { + return nil + } + + rc, err := s.Store.RemoteCluster().Get(*participant.RemoteId) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("couldn't find remote cluster %s, for creating shared channel invitation for a DM", *participant.RemoteId)) + } + + return syncService.SendChannelInvite(channel, creator.Id, "", rc, sharedchannel.WithDirectParticipantID(creator.Id), sharedchannel.WithDirectParticipantID(participant.Id)) +} + +func getUserFromEvent(s *Server, event *model.WebSocketEvent, key string) (*model.User, error) { + userID, ok := event.GetData()[key].(string) + if !ok || userID == "" { + return nil, fmt.Errorf("received websocket message that is eligible for sending an invitation but message does not have `%s` present", key) + } + + user, err := s.Store.User().Get(context.Background(), userID) + if err != nil { + return nil, errors.Wrap(err, "couldn't find user for creating shared channel invitation for a DM") + } + + return user, nil +} + +func findChannel(server *Server, channelId string) (*model.Channel, error) { + channel, err := server.Store.Channel().Get(channelId, true) + if err != nil { + return nil, errors.Wrap(err, "received websocket message that is eligible for shared channel sync but channel does not exist") + } + + return channel, nil +} diff --git a/app/shared_channel_notifier_test.go b/app/shared_channel_notifier_test.go new file mode 100644 index 00000000000..db75af42c0e --- /dev/null +++ b/app/shared_channel_notifier_test.go @@ -0,0 +1,71 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func TestServerSyncSharedChannelHandler(t *testing.T) { + t.Run("sync service inactive, it does nothing", func(t *testing.T) { + th := SetupWithStoreMock(t) + defer th.TearDown() + + mockService := NewMockSharedChannelService(nil) + mockService.active = false + th.App.srv.sharedChannelService = mockService + + th.App.srv.SharedChannelSyncHandler(&model.WebSocketEvent{}) + assert.Empty(t, mockService.notifications) + }) + + t.Run("sync service active and broadcast envelope has ineligible event, it does nothing", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + mockService := NewMockSharedChannelService(nil) + mockService.active = true + th.App.srv.sharedChannelService = mockService + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + + websocketEvent := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_ADDED_TO_TEAM, model.NewId(), channel.Id, "", nil) + + th.App.srv.SharedChannelSyncHandler(websocketEvent) + assert.Empty(t, mockService.notifications) + }) + + t.Run("sync service active and broadcast envelope has eligible event but channel does not exist, it does nothing", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + mockService := NewMockSharedChannelService(nil) + mockService.active = true + th.App.srv.sharedChannelService = mockService + + websocketEvent := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POSTED, model.NewId(), model.NewId(), "", nil) + + th.App.srv.SharedChannelSyncHandler(websocketEvent) + assert.Empty(t, mockService.notifications) + }) + + t.Run("sync service active when received eligible event, it triggers a shared channel content sync", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + mockService := NewMockSharedChannelService(nil) + mockService.active = true + th.App.srv.sharedChannelService = mockService + + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + websocketEvent := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POSTED, model.NewId(), channel.Id, "", nil) + + th.App.srv.SharedChannelSyncHandler(websocketEvent) + assert.Len(t, mockService.notifications, 1) + assert.Equal(t, channel.Id, mockService.notifications[0]) + }) +} diff --git a/app/shared_channel_service_iface.go b/app/shared_channel_service_iface.go new file mode 100644 index 00000000000..fd2ae1e6e75 --- /dev/null +++ b/app/shared_channel_service_iface.go @@ -0,0 +1,66 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/sharedchannel" +) + +// SharedChannelServiceIFace is the interface to the shared channel service +type SharedChannelServiceIFace interface { + Shutdown() error + Start() error + NotifyChannelChanged(channelId string) + SendChannelInvite(channel *model.Channel, userId string, description string, rc *model.RemoteCluster, options ...sharedchannel.InviteOption) error + Active() bool +} + +type MockOptionSharedChannelService func(service *mockSharedChannelService) + +func MockOptionSharedChannelServiceWithActive(active bool) MockOptionSharedChannelService { + return func(mrcs *mockSharedChannelService) { + mrcs.active = active + } +} + +func NewMockSharedChannelService(service SharedChannelServiceIFace, options ...MockOptionSharedChannelService) *mockSharedChannelService { + mrcs := &mockSharedChannelService{service, true, []string{}, 0} + for _, option := range options { + option(mrcs) + } + return mrcs +} + +type mockSharedChannelService struct { + SharedChannelServiceIFace + active bool + notifications []string + numInvitations int +} + +func (mrcs *mockSharedChannelService) NotifyChannelChanged(channelId string) { + mrcs.notifications = append(mrcs.notifications, channelId) +} + +func (mrcs *mockSharedChannelService) Shutdown() error { + return nil +} + +func (mrcs *mockSharedChannelService) Start() error { + return nil +} + +func (mrcs *mockSharedChannelService) Active() bool { + return mrcs.active +} + +func (mrcs *mockSharedChannelService) SendChannelInvite(channel *model.Channel, userId string, description string, rc *model.RemoteCluster, options ...sharedchannel.InviteOption) error { + mrcs.numInvitations += 1 + return nil +} + +func (mrcs *mockSharedChannelService) NumInvitations() int { + return mrcs.numInvitations +} diff --git a/app/shared_channel_test.go b/app/shared_channel_test.go new file mode 100644 index 00000000000..907983a3b7c --- /dev/null +++ b/app/shared_channel_test.go @@ -0,0 +1,89 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func TestApp_CheckCanInviteToSharedChannel(t *testing.T) { + th := Setup(t).InitBasic() + + channel1 := th.CreateChannel(th.BasicTeam) + channel2 := th.CreateChannel(th.BasicTeam) + channel3 := th.CreateChannel(th.BasicTeam) + + data := []struct { + channelId string + home bool + name string + remoteId string + }{ + {channelId: channel1.Id, home: true, name: "test_home", remoteId: ""}, + {channelId: channel2.Id, home: false, name: "test_remote", remoteId: model.NewId()}, + } + + for _, d := range data { + sc := &model.SharedChannel{ + ChannelId: d.channelId, + TeamId: th.BasicTeam.Id, + Home: d.home, + ShareName: d.name, + CreatorId: th.BasicUser.Id, + RemoteId: d.remoteId, + } + _, err := th.App.SaveSharedChannel(sc) + require.NoError(t, err) + } + + t.Run("Test checkChannelNotShared: not yet shared channel", func(t *testing.T) { + err := th.App.checkChannelNotShared(channel3.Id) + assert.NoError(t, err, "unshared channel should not error") + }) + + t.Run("Test checkChannelNotShared: already shared channel", func(t *testing.T) { + err := th.App.checkChannelNotShared(channel1.Id) + assert.Error(t, err, "already shared channel should error") + }) + + t.Run("Test checkChannelNotShared: invalid channel", func(t *testing.T) { + err := th.App.checkChannelNotShared(model.NewId()) + assert.Error(t, err, "invalid channel should error") + }) + + t.Run("Test checkChannelIsShared: not yet shared channel", func(t *testing.T) { + err := th.App.checkChannelIsShared(channel3.Id) + assert.Error(t, err, "unshared channel should error") + }) + + t.Run("Test checkChannelIsShared: already shared channel", func(t *testing.T) { + err := th.App.checkChannelIsShared(channel1.Id) + assert.NoError(t, err, "already channel should not error") + }) + + t.Run("Test checkChannelIsShared: invalid channel", func(t *testing.T) { + err := th.App.checkChannelIsShared(model.NewId()) + assert.Error(t, err, "invalid channel should error") + }) + + t.Run("Test CheckCanInviteToSharedChannel: Home shared channel", func(t *testing.T) { + err := th.App.CheckCanInviteToSharedChannel(data[0].channelId) + assert.NoError(t, err, "home channel should allow invites") + }) + + t.Run("Test CheckCanInviteToSharedChannel: Remote shared channel", func(t *testing.T) { + err := th.App.CheckCanInviteToSharedChannel(data[1].channelId) + assert.Error(t, err, "home channel should not allow invites") + }) + + t.Run("Test CheckCanInviteToSharedChannel: Invalid shared channel", func(t *testing.T) { + err := th.App.CheckCanInviteToSharedChannel(model.NewId()) + assert.Error(t, err, "invalid channel should not allow invites") + }) +} diff --git a/app/slashcommands/command_remote.go b/app/slashcommands/command_remote.go new file mode 100644 index 00000000000..d0b0976fd74 --- /dev/null +++ b/app/slashcommands/command_remote.go @@ -0,0 +1,292 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package slashcommands + +import ( + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/mattermost/mattermost-server/v5/app" + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/i18n" +) + +const ( + AvailableRemoteActions = "invite, accept, remove, status" +) + +type RemoteProvider struct { +} + +const ( + CommandTriggerRemote = "remote" +) + +func init() { + app.RegisterCommandProvider(&RemoteProvider{}) +} + +func (rp *RemoteProvider) GetTrigger() string { + return CommandTriggerRemote +} + +func (rp *RemoteProvider) GetCommand(a *app.App, T i18n.TranslateFunc) *model.Command { + + remote := model.NewAutocompleteData(rp.GetTrigger(), "[action]", T("api.command_remote.remote_add_remove.help", map[string]interface{}{"Actions": AvailableRemoteActions})) + + invite := model.NewAutocompleteData("invite", "", T("api.command_remote.invite.help")) + invite.AddNamedTextArgument("password", T("api.command_remote.invite_password.help"), T("api.command_remote.invite_password.hint"), "", true) + invite.AddNamedTextArgument("name", T("api.command_remote.name.help"), T("api.command_remote.name.hint"), "", true) + + accept := model.NewAutocompleteData("accept", "", T("api.command_remote.accept.help")) + accept.AddNamedTextArgument("password", T("api.command_remote.invite_password.help"), T("api.command_remote.invite_password.hint"), "", true) + accept.AddNamedTextArgument("name", T("api.command_remote.name.help"), T("api.command_remote.name.hint"), "", true) + accept.AddNamedTextArgument("invite", T("api.command_remote.invitation.help"), T("api.command_remote.invitation.hint"), "", true) + + remove := model.NewAutocompleteData("remove", "", T("api.command_remote.remove.help")) + remove.AddNamedDynamicListArgument("remoteId", T("api.command_remote.remove_remote_id.help"), "builtin:remote", true) + + status := model.NewAutocompleteData("status", "", T("api.command_remote.status.help")) + + remote.AddCommand(invite) + remote.AddCommand(accept) + remote.AddCommand(remove) + remote.AddCommand(status) + + return &model.Command{ + Trigger: rp.GetTrigger(), + AutoComplete: true, + AutoCompleteDesc: T("api.command_remote.desc"), + AutoCompleteHint: T("api.command_remote.hint"), + DisplayName: T("api.command_remote.name"), + AutocompleteData: remote, + } +} + +func (rp *RemoteProvider) DoCommand(a *app.App, args *model.CommandArgs, message string) *model.CommandResponse { + if !a.HasPermissionTo(args.UserId, model.PERMISSION_MANAGE_SHARED_CHANNELS) { + return responsef(args.T("api.command_remote.permission_required", map[string]interface{}{"Permission": "manage_shared_channels"})) + } + + margs := parseNamedArgs(args.Command) + action, ok := margs[ActionKey] + if !ok { + return responsef(args.T("api.command_remote.missing_command", map[string]interface{}{"Actions": AvailableRemoteActions})) + } + + switch action { + case "invite": + return rp.doInvite(a, args, margs) + case "accept": + return rp.doAccept(a, args, margs) + case "remove": + return rp.doRemove(a, args, margs) + case "status": + return rp.doStatus(a, args, margs) + } + + return responsef(args.T("api.command_remote.unknown_action", map[string]interface{}{"Action": action})) +} + +func (rp *RemoteProvider) GetAutoCompleteListItems(a *app.App, commandArgs *model.CommandArgs, arg *model.AutocompleteArg, parsed, toBeParsed string) ([]model.AutocompleteListItem, error) { + if !a.HasPermissionTo(commandArgs.UserId, model.PERMISSION_MANAGE_SHARED_CHANNELS) { + return nil, errors.New("You require `manage_shared_channels` permission to manage remote clusters.") + } + + if arg.Name == "remoteId" && strings.Contains(parsed, " remove ") { + return getRemoteClusterAutocompleteListItems(a, true) + } + + return nil, fmt.Errorf("`%s` is not a dynamic argument", arg.Name) +} + +// doInvite creates and displays an encrypted invite that can be used by a remote site to establish a simple trust. +func (rp *RemoteProvider) doInvite(a *app.App, args *model.CommandArgs, margs map[string]string) *model.CommandResponse { + password := margs["password"] + if password == "" { + return responsef(args.T("api.command_remote.missing_empty", map[string]interface{}{"Arg": "password"})) + } + + name := margs["name"] + if name == "" { + return responsef(args.T("api.command_remote.missing_empty", map[string]interface{}{"Arg": "name"})) + } + + url := a.GetSiteURL() + if url == "" { + return responsef(args.T("api.command_remote.site_url_not_set")) + } + + rc := &model.RemoteCluster{ + DisplayName: name, + Token: model.NewId(), + CreatorId: args.UserId, + } + + rcSaved, appErr := a.AddRemoteCluster(rc) + if appErr != nil { + return responsef(args.T("api.command_remote.add_remote.error", map[string]interface{}{"Error": appErr.Error()})) + } + + // Display the encrypted invitation + invite := &model.RemoteClusterInvite{ + RemoteId: rcSaved.RemoteId, + RemoteTeamId: args.TeamId, + SiteURL: url, + Token: rcSaved.Token, + } + encrypted, err := invite.Encrypt(password) + if err != nil { + return responsef(args.T("api.command_remote.encrypt_invitation.error", map[string]interface{}{"Error": err.Error()})) + } + encoded := base64.URLEncoding.EncodeToString(encrypted) + + return responsef("##### " + args.T("api.command_remote.invitation_created") + "\n" + + args.T("api.command_remote.invite_summary", map[string]interface{}{"Command": "/remote accept", "Invitation": encoded, "SiteURL": invite.SiteURL})) +} + +// doAccept accepts an invitation generated by a remote site. +func (rp *RemoteProvider) doAccept(a *app.App, args *model.CommandArgs, margs map[string]string) *model.CommandResponse { + password := margs["password"] + if password == "" { + return responsef(args.T("api.command_remote.missing_empty", map[string]interface{}{"Arg": "password"})) + } + + name := margs["name"] + if name == "" { + return responsef(args.T("api.command_remote.missing_empty", map[string]interface{}{"Arg": "name"})) + } + + blob := margs["invite"] + if blob == "" { + return responsef(args.T("api.command_remote.missing_empty", map[string]interface{}{"Arg": "invite"})) + } + + // invite is encoded as base64 and encrypted + decoded, err := base64.URLEncoding.DecodeString(blob) + if err != nil { + return responsef(args.T("api.command_remote.decode_invitation.error", map[string]interface{}{"Error": err.Error()})) + } + invite := &model.RemoteClusterInvite{} + err = invite.Decrypt(decoded, password) + if err != nil { + return responsef(args.T("api.command_remote.incorrect_password.error", map[string]interface{}{"Error": err.Error()})) + } + + rcs, _ := a.GetRemoteClusterService() + if rcs == nil { + return responsef(args.T("api.command_remote.service_not_enabled")) + } + + url := a.GetSiteURL() + if url == "" { + return responsef(args.T("api.command_remote.site_url_not_set")) + } + + rc, err := rcs.AcceptInvitation(invite, name, args.UserId, args.TeamId, url) + if err != nil { + return responsef(args.T("api.command_remote.accept_invitation.error", map[string]interface{}{"Error": err.Error()})) + } + + return responsef("##### " + args.T("api.command_remote.accept_invitation", map[string]interface{}{"SiteURL": rc.SiteURL})) +} + +// doRemove removes a remote cluster from the database, effectively revoking the trust relationship. +func (rp *RemoteProvider) doRemove(a *app.App, args *model.CommandArgs, margs map[string]string) *model.CommandResponse { + id, ok := margs["remoteId"] + if !ok { + return responsef(args.T("api.command_remote.missing_empty", map[string]interface{}{"Arg": "remoteId"})) + } + + deleted, err := a.DeleteRemoteCluster(id) + if err != nil { + responsef(args.T("api.command_remote.remove_remote.error", map[string]interface{}{"Error": err.Error()})) + } + + result := "removed" + if !deleted { + result = "**NOT FOUND**" + } + return responsef("##### " + args.T("api.command_remote.cluster_removed", map[string]interface{}{"RemoteId": id, "Result": result})) +} + +// doStatus displays connection status for all remote clusters. +func (rp *RemoteProvider) doStatus(a *app.App, args *model.CommandArgs, _ map[string]string) *model.CommandResponse { + list, err := a.GetAllRemoteClusters(model.RemoteClusterQueryFilter{}) + if err != nil { + responsef(args.T("api.command_remote.fetch_status.error", map[string]interface{}{"Error": err.Error()})) + } + + if len(list) == 0 { + return responsef("** " + args.T("api.command_remote.remotes_not_found") + " **") + } + + var sb strings.Builder + fmt.Fprintf(&sb, args.T("api.command_remote.remote_table_header")+"| \n") + fmt.Fprintf(&sb, "| ---- | -------- | ---------- | :-------------: | :----: | ---------- |\n") + + for _, rc := range list { + accepted := ":white_check_mark:" + if rc.SiteURL == "" { + accepted = ":x:" + } + + online := ":white_check_mark:" + if !isOnline(rc.LastPingAt) { + online = ":skull_and_crossbones:" + } + + lastPing := formatTimestamp(model.GetTimeForMillis(rc.LastPingAt)) + + fmt.Fprintf(&sb, "| %s | %s | %s | %s | %s | %s |\n", rc.DisplayName, rc.SiteURL, rc.RemoteId, accepted, online, lastPing) + } + return responsef(sb.String()) +} + +func isOnline(lastPing int64) bool { + return lastPing > model.GetMillis()-model.RemoteOfflineAfterMillis +} + +func getRemoteClusterAutocompleteListItems(a *app.App, includeOffline bool) ([]model.AutocompleteListItem, error) { + filter := model.RemoteClusterQueryFilter{ + ExcludeOffline: !includeOffline, + } + clusters, err := a.GetAllRemoteClusters(filter) + if err != nil || len(clusters) == 0 { + return []model.AutocompleteListItem{}, nil + } + + list := make([]model.AutocompleteListItem, 0, len(clusters)) + + for _, rc := range clusters { + item := model.AutocompleteListItem{ + Item: rc.RemoteId, + HelpText: fmt.Sprintf("%s (%s)", rc.DisplayName, rc.SiteURL)} + list = append(list, item) + } + return list, nil +} + +func getRemoteClusterAutocompleteListItemsNotInChannel(a *app.App, channelId string, includeOffline bool) ([]model.AutocompleteListItem, error) { + filter := model.RemoteClusterQueryFilter{ + ExcludeOffline: !includeOffline, + NotInChannel: channelId, + } + all, err := a.GetAllRemoteClusters(filter) + if err != nil || len(all) == 0 { + return []model.AutocompleteListItem{}, nil + } + + list := make([]model.AutocompleteListItem, 0, len(all)) + + for _, rc := range all { + item := model.AutocompleteListItem{ + Item: rc.RemoteId, + HelpText: fmt.Sprintf("%s (%s)", rc.DisplayName, rc.SiteURL)} + list = append(list, item) + } + return list, nil +} diff --git a/app/slashcommands/command_share.go b/app/slashcommands/command_share.go new file mode 100644 index 00000000000..2bbcc3c13d4 --- /dev/null +++ b/app/slashcommands/command_share.go @@ -0,0 +1,347 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package slashcommands + +import ( + "errors" + "fmt" + "strings" + + "github.com/mattermost/mattermost-server/v5/app" + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/i18n" +) + +type ShareProvider struct { +} + +const ( + CommandTriggerShare = "share" + AvailableShareActions = "share_channel, unshare_channel, invite_remove, uninvite_remote, status" +) + +func init() { + app.RegisterCommandProvider(&ShareProvider{}) +} + +func (sp *ShareProvider) GetTrigger() string { + return CommandTriggerShare +} + +func (sp *ShareProvider) GetCommand(a *app.App, T i18n.TranslateFunc) *model.Command { + share := model.NewAutocompleteData(CommandTriggerShare, "[action]", T("api.command_share.available_actions", map[string]interface{}{"Actions": AvailableShareActions})) + + shareChannel := model.NewAutocompleteData("share_channel", "", T("api.command_share.share_current")) + shareChannel.AddNamedTextArgument("readonly", T("api.command_share.share_read_only.help"), T("api.command_share.share_read_only.hint"), "Y|N|y|n", false) + shareChannel.AddNamedTextArgument("name", T("api.command_share.channel_name.help"), T("api.command_share.channel_name.hint"), "", false) + shareChannel.AddNamedTextArgument("displayname", T("api.command_share.channel_display_name.help"), T("api.command_share.channel_display_name.hint"), "", false) + shareChannel.AddNamedTextArgument("purpose", T("api.command_share.channel_purpose.help"), T("api.command_share.channel_purpose.hint"), "", false) + shareChannel.AddNamedTextArgument("header", T("api.command_share.channel_header.help"), T("api.command_share.channel_header.hint"), "", false) + + unshareChannel := model.NewAutocompleteData("unshare_channel", "", T("api.command_share.unshare_channel.help")) + unshareChannel.AddNamedTextArgument("are_you_sure", T("api.command_share.unshare_confirmation.help"), T("api.command_share.unshare_confirmation.hint"), "Y|N|y|n", true) + + inviteRemote := model.NewAutocompleteData("invite_remote", "", T("api.command_share.invite_remote.help")) + inviteRemote.AddNamedDynamicListArgument("remoteId", T("api.command_share.remote_id.help"), "builtin:share", true) + inviteRemote.AddNamedTextArgument("description", T("api.command_share.description_invite.help"), T("api.command_share.description_invite.hint"), "", false) + + unInviteRemote := model.NewAutocompleteData("uninvite_remote", "", T("api.command_share.uninvite_remote.help")) + unInviteRemote.AddNamedDynamicListArgument("remoteId", T("api.command_share.uninvite_remote_id.help"), "builtin:share", true) + + status := model.NewAutocompleteData("status", "", T("api.command_share.channel_status.help")) + + share.AddCommand(shareChannel) + share.AddCommand(unshareChannel) + share.AddCommand(inviteRemote) + share.AddCommand(unInviteRemote) + share.AddCommand(status) + + return &model.Command{ + Trigger: CommandTriggerShare, + AutoComplete: true, + AutoCompleteDesc: T("api.command_share.desc"), + AutoCompleteHint: T("api.command_share.hint"), + DisplayName: T("api.command_share.name"), + AutocompleteData: share, + } +} + +func (sp *ShareProvider) GetAutoCompleteListItems(a *app.App, commandArgs *model.CommandArgs, arg *model.AutocompleteArg, parsed, toBeParsed string) ([]model.AutocompleteListItem, error) { + switch { + case strings.Contains(parsed, " share_channel "): + + return sp.getAutoCompleteShareChannel(a, commandArgs, arg) + + case strings.Contains(parsed, " invite_remote "): + + return sp.getAutoCompleteInviteRemote(a, commandArgs, arg) + + case strings.Contains(parsed, " uninvite_remote "): + + return sp.getAutoCompleteUnInviteRemote(a, commandArgs, arg) + + } + return nil, errors.New("invalid action") +} + +func (sp *ShareProvider) getAutoCompleteShareChannel(a *app.App, commandArgs *model.CommandArgs, arg *model.AutocompleteArg) ([]model.AutocompleteListItem, error) { + channel, err := a.GetChannel(commandArgs.ChannelId) + if err != nil { + return nil, err + } + + var item model.AutocompleteListItem + + switch arg.Name { + case "name": + item = model.AutocompleteListItem{ + Item: channel.Name, + HelpText: channel.DisplayName, + } + case "displayname": + item = model.AutocompleteListItem{ + Item: channel.DisplayName, + HelpText: channel.Name, + } + default: + return nil, fmt.Errorf("%s not a dynamic argument", arg.Name) + } + return []model.AutocompleteListItem{item}, nil +} + +func (sp *ShareProvider) getAutoCompleteInviteRemote(a *app.App, commandArgs *model.CommandArgs, arg *model.AutocompleteArg) ([]model.AutocompleteListItem, error) { + switch arg.Name { + case "remoteId": + return getRemoteClusterAutocompleteListItemsNotInChannel(a, commandArgs.ChannelId, true) + default: + return nil, fmt.Errorf("%s not a dynamic argument", arg.Name) + } +} + +func (sp *ShareProvider) getAutoCompleteUnInviteRemote(a *app.App, _ *model.CommandArgs, arg *model.AutocompleteArg) ([]model.AutocompleteListItem, error) { + switch arg.Name { + case "remoteId": + return getRemoteClusterAutocompleteListItems(a, true) + default: + return nil, fmt.Errorf("%s not a dynamic argument", arg.Name) + } +} + +func (sp *ShareProvider) DoCommand(a *app.App, args *model.CommandArgs, message string) *model.CommandResponse { + if !a.HasPermissionTo(args.UserId, model.PERMISSION_MANAGE_SHARED_CHANNELS) { + return responsef(args.T("api.command_share.permission_required", map[string]interface{}{"Permission": "manage_shared_channels"})) + } + + if a.Srv().GetSharedChannelSyncService() == nil { + return responsef(args.T("api.command_share.service_disabled")) + } + + if a.Srv().GetRemoteClusterService() == nil { + return responsef(args.T("api.command_remote.service_disabled")) + } + + margs := parseNamedArgs(args.Command) + action, ok := margs[ActionKey] + if !ok { + return responsef(args.T("api.command_share.missing_action", map[string]interface{}{"Actions": AvailableShareActions})) + } + + switch action { + case "share_channel": + return sp.doShareChannel(a, args, margs) + case "unshare_channel": + return sp.doUnshareChannel(a, args, margs) + case "invite_remote": + return sp.doInviteRemote(a, args, margs) + case "uninvite_remote": + return sp.doUninviteRemote(a, args, margs) + case "status": + return sp.doStatus(a, args, margs) + } + return responsef(args.T("api.command_share.unknown_action", map[string]interface{}{"Action": action, "Actions": AvailableShareActions})) +} + +func (sp *ShareProvider) doShareChannel(a *app.App, args *model.CommandArgs, margs map[string]string) *model.CommandResponse { + // check that channel exists. + channel, errApp := a.GetChannel(args.ChannelId) + if errApp != nil { + return responsef(args.T("api.command_share.share_channel.error", map[string]interface{}{"Error": errApp.Error()})) + } + + if name := margs["name"]; name == "" { + margs["name"] = channel.Name + } + if name := margs["displayname"]; name == "" { + margs["displayname"] = channel.DisplayName + } + if name := margs["purpose"]; name == "" { + margs["purpose"] = channel.Purpose + } + if name := margs["header"]; name == "" { + margs["header"] = channel.Header + } + if _, ok := margs["readonly"]; !ok { + margs["readonly"] = "N" + } + + readonly, err := parseBool(margs["readonly"]) + if err != nil { + return responsef(args.T("api.command_share.invalid_value.error", map[string]interface{}{"Arg": "readonly", "Error": err.Error()})) + } + + sc := &model.SharedChannel{ + ChannelId: args.ChannelId, + TeamId: args.TeamId, + Home: true, + ReadOnly: readonly, + ShareName: margs["name"], + ShareDisplayName: margs["displayname"], + SharePurpose: margs["purpose"], + ShareHeader: margs["header"], + CreatorId: args.UserId, + } + + if _, err := a.SaveSharedChannel(sc); err != nil { + return responsef(args.T("api.command_share.share_channel.error", map[string]interface{}{"Error": err.Error()})) + } + + notifyClientsForChannelUpdate(a, sc) + + return responsef("##### " + args.T("api.command_share.channel_shared")) +} + +func (sp *ShareProvider) doUnshareChannel(a *app.App, args *model.CommandArgs, margs map[string]string) *model.CommandResponse { + if _, ok := margs["are_you_sure"]; !ok { + margs["are_you_sure"] = "N" + } + + sure, err := parseBool(margs["are_you_sure"]) + if err != nil || !sure { + return responsef(args.T("api.command_share.shared_channel_not_deleted", map[string]interface{}{"Arg": "are_you_sure", "Expected": "Y"})) + } + + sc, appErr := a.GetSharedChannel(args.ChannelId) + if appErr != nil { + return responsef(args.T("api.command_share.shared_channel_unshare.error", map[string]interface{}{"Error": appErr.Error()})) + } + + deleted, err := a.DeleteSharedChannel(args.ChannelId) + if err != nil { + return responsef(args.T("api.command_share.shared_channel_unshare.error", map[string]interface{}{"Error": err.Error()})) + } + if !deleted { + return responsef(args.T("api.command_share.not_shared_channel_unshare")) + } + + notifyClientsForChannelUpdate(a, sc) + + return responsef("##### " + args.T("api.command_share.shared_channel_unavailable")) +} + +func (sp *ShareProvider) doInviteRemote(a *app.App, args *model.CommandArgs, margs map[string]string) (resp *model.CommandResponse) { + remoteId, ok := margs["remoteId"] + if !ok || remoteId == "" { + return responsef(args.T("api.command_share.must_specify_valid_remote")) + } + + hasRemote, err := a.HasRemote(args.ChannelId, remoteId) + if err != nil { + return responsef(args.T("api.command_share.fetch_remote.error", map[string]interface{}{"Error": err.Error()})) + } + if hasRemote { + return responsef(args.T("api.command_share.remote_already_invited")) + } + + // Check if channel is shared or not. + hasChan, err := a.HasSharedChannel(args.ChannelId) + if err != nil { + return responsef(args.T("api.command_share.check_channel_exist.error", map[string]interface{}{"Error": err.Error()})) + } + if !hasChan { + // If it doesn't exist, then create it. + resp2 := sp.doShareChannel(a, args, margs) + // We modify the outgoing response by prepending the text + // from the shareChannel response. + defer func() { + resp.Text = resp2.Text + "\n" + resp.Text + }() + } + + // don't allow invitation to shared channel originating from remote. + // (also blocks cyclic invitations) + if err := a.CheckCanInviteToSharedChannel(args.ChannelId); err != nil { + return responsef(args.T("api.command_share.channel_invite_not_home.error")) + } + + rc, appErr := a.GetRemoteCluster(remoteId) + if appErr != nil { + return responsef(args.T("api.command_share.remote_id_invalid.error", map[string]interface{}{"Error": appErr.Error()})) + } + + channel, errApp := a.GetChannel(args.ChannelId) + if errApp != nil { + return responsef(args.T("api.command_share.channel_invite.error", map[string]interface{}{"Name": rc.DisplayName, "Error": errApp.Error()})) + } + // send channel invite to remote cluster + if err := a.Srv().GetSharedChannelSyncService().SendChannelInvite(channel, args.UserId, margs["description"], rc); err != nil { + return responsef(args.T("api.command_share.channel_invite.error", map[string]interface{}{"Name": rc.DisplayName, "Error": err.Error()})) + } + + return responsef("##### " + args.T("api.command_share.invitation_sent", map[string]interface{}{"Name": rc.DisplayName, "SiteURL": rc.SiteURL})) +} + +func (sp *ShareProvider) doUninviteRemote(a *app.App, args *model.CommandArgs, margs map[string]string) *model.CommandResponse { + remoteId, ok := margs["remoteId"] + if !ok || remoteId == "" { + return responsef(args.T("api.command_share.remote_not_valid")) + } + + scr, err := a.GetSharedChannelRemoteByIds(args.ChannelId, remoteId) + if err != nil || scr.ChannelId != args.ChannelId { + return responsef(args.T("api.command_share.channel_remote_id_not_exists", map[string]interface{}{"RemoteId": remoteId})) + } + + deleted, err := a.DeleteSharedChannelRemote(scr.Id) + if err != nil || !deleted { + return responsef(args.T("api.command_share.could_not_uninvite.error", map[string]interface{}{"RemoteId": remoteId, "Error": err.Error()})) + } + return responsef("##### " + args.T("api.command_share.remote_uninvited", map[string]interface{}{"RemoteId": remoteId})) +} + +func (sp *ShareProvider) doStatus(a *app.App, args *model.CommandArgs, _ map[string]string) *model.CommandResponse { + statuses, err := a.GetSharedChannelRemotesStatus(args.ChannelId) + if err != nil { + return responsef(args.T("api.command_share.fetch_remote_status.error", map[string]interface{}{"Error": err.Error()})) + } + if len(statuses) == 0 { + return responsef(args.T("api.command_share.no_remote_invited")) + } + + var sb strings.Builder + + fmt.Fprintf(&sb, args.T("api.command_share.channel_status_id", map[string]interface{}{"ChannelId": statuses[0].ChannelId})+"\n\n") + + fmt.Fprintf(&sb, args.T("api.command_share.remote_table_header")+" \n") + fmt.Fprintf(&sb, "| ------ | ------- | ----------- | -------- | -------------- | ------ | --------- | \n") + + for _, status := range statuses { + online := ":white_check_mark:" + if !isOnline(status.LastPingAt) { + online = ":skull_and_crossbones:" + } + + lastSync := formatTimestamp(model.GetTimeForMillis(status.NextSyncAt)) + + fmt.Fprintf(&sb, "| %s | %s | %s | %t | %t | %s | %s |\n", + status.DisplayName, status.SiteURL, status.Description, + status.ReadOnly, status.IsInviteAccepted, online, lastSync) + } + return responsef(sb.String()) +} + +func notifyClientsForChannelUpdate(a *app.App, sharedChannel *model.SharedChannel) { + messageWs := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_CHANNEL_CONVERTED, sharedChannel.TeamId, "", "", nil) + messageWs.Add("channel_id", sharedChannel.ChannelId) + a.Publish(messageWs) +} diff --git a/app/slashcommands/command_share_test.go b/app/slashcommands/command_share_test.go new file mode 100644 index 00000000000..0517b8ca1f0 --- /dev/null +++ b/app/slashcommands/command_share_test.go @@ -0,0 +1,92 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package slashcommands + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mattermost/mattermost-server/v5/testlib" + + "github.com/mattermost/mattermost-server/v5/app" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func TestShareProviderDoCommand(t *testing.T) { + t.Run("share command sends a websocket channel converted event", func(t *testing.T) { + th := setup(t).initBasic() + defer th.tearDown() + + th.addPermissionToRole(model.PERMISSION_MANAGE_SHARED_CHANNELS.Id, th.BasicUser.Roles) + + mockSyncService := app.NewMockSharedChannelService(nil) + th.Server.SetSharedChannelSyncService(mockSyncService) + mockRemoteCluster, err := remotecluster.NewRemoteClusterService(th.Server) + require.NoError(t, err) + + th.Server.SetRemoteClusterService(mockRemoteCluster) + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + commandProvider := ShareProvider{} + channel := th.CreateChannel(th.BasicTeam, WithShared(false)) + args := &model.CommandArgs{ + T: func(s string, args ...interface{}) string { return s }, + ChannelId: channel.Id, + UserId: th.BasicUser.Id, + TeamId: th.BasicTeam.Id, + Command: "/share share_channel", + } + + response := commandProvider.DoCommand(th.App, args, "") + require.Equal(t, "##### "+args.T("api.command_share.channel_shared"), response.Text) + + channelConvertedMessages := testCluster.SelectMessages(func(msg *model.ClusterMessage) bool { + event := model.WebSocketEventFromJson(strings.NewReader(msg.Data)) + return event != nil && event.EventType() == model.WEBSOCKET_EVENT_CHANNEL_CONVERTED + }) + assert.Len(t, channelConvertedMessages, 1) + }) + + t.Run("unshare command sends a websocket channel converted event", func(t *testing.T) { + th := setup(t).initBasic() + defer th.tearDown() + + th.addPermissionToRole(model.PERMISSION_MANAGE_SHARED_CHANNELS.Id, th.BasicUser.Roles) + + mockSyncService := app.NewMockSharedChannelService(nil) + th.Server.SetSharedChannelSyncService(mockSyncService) + mockRemoteCluster, err := remotecluster.NewRemoteClusterService(th.Server) + require.NoError(t, err) + + th.Server.SetRemoteClusterService(mockRemoteCluster) + testCluster := &testlib.FakeClusterInterface{} + th.Server.Cluster = testCluster + + commandProvider := ShareProvider{} + channel := th.CreateChannel(th.BasicTeam, WithShared(true)) + args := &model.CommandArgs{ + T: func(s string, args ...interface{}) string { return s }, + ChannelId: channel.Id, + UserId: th.BasicUser.Id, + TeamId: th.BasicTeam.Id, + Command: "/share unshare_channel --are_you_sure Y", + } + + response := commandProvider.DoCommand(th.App, args, "") + require.Equal(t, "##### "+args.T("api.command_share.shared_channel_unavailable"), response.Text) + + channelConvertedMessages := testCluster.SelectMessages(func(msg *model.ClusterMessage) bool { + event := model.WebSocketEventFromJson(strings.NewReader(msg.Data)) + return event != nil && event.EventType() == model.WEBSOCKET_EVENT_CHANNEL_CONVERTED + }) + require.Len(t, channelConvertedMessages, 1) + }) +} diff --git a/app/slashcommands/helper_test.go b/app/slashcommands/helper_test.go index 58d86bbc1b8..28f2a4b3d09 100644 --- a/app/slashcommands/helper_test.go +++ b/app/slashcommands/helper_test.go @@ -226,15 +226,23 @@ func (th *TestHelper) createUserOrGuest(guest bool) *model.User { return user } -func (th *TestHelper) CreateChannel(team *model.Team) *model.Channel { - return th.createChannel(team, model.CHANNEL_OPEN) +type ChannelOption func(*model.Channel) + +func WithShared(v bool) ChannelOption { + return func(channel *model.Channel) { + channel.Shared = model.NewBool(v) + } +} + +func (th *TestHelper) CreateChannel(team *model.Team, options ...ChannelOption) *model.Channel { + return th.createChannel(team, model.CHANNEL_OPEN, options...) } func (th *TestHelper) createPrivateChannel(team *model.Team) *model.Channel { return th.createChannel(team, model.CHANNEL_PRIVATE) } -func (th *TestHelper) createChannel(team *model.Team, channelType string) *model.Channel { +func (th *TestHelper) createChannel(team *model.Team, channelType string, options ...ChannelOption) *model.Channel { id := model.NewId() channel := &model.Channel{ @@ -245,11 +253,32 @@ func (th *TestHelper) createChannel(team *model.Team, channelType string) *model CreatorId: th.BasicUser.Id, } + for _, option := range options { + option(channel) + } + utils.DisableDebugLogForTest() var err *model.AppError if channel, err = th.App.CreateChannel(channel, true); err != nil { panic(err) } + + if channel.IsShared() { + id := model.NewId() + _, err := th.App.SaveSharedChannel(&model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + Home: false, + ReadOnly: false, + ShareName: "shared-" + id, + ShareDisplayName: "shared-" + id, + CreatorId: th.BasicUser.Id, + RemoteId: model.NewId(), + }) + if err != nil { + panic(err) + } + } utils.EnableDebugLogForTest() return channel } diff --git a/app/slashcommands/util.go b/app/slashcommands/util.go new file mode 100644 index 00000000000..e7bb0462f96 --- /dev/null +++ b/app/slashcommands/util.go @@ -0,0 +1,88 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package slashcommands + +import ( + "fmt" + "strings" + "time" + + "github.com/mattermost/mattermost-server/v5/model" +) + +const ( + ActionKey = "-action" +) + +// responsef creates an ephemeral command response using printf syntax. +func responsef(format string, args ...interface{}) *model.CommandResponse { + return &model.CommandResponse{ + ResponseType: model.COMMAND_RESPONSE_TYPE_EPHEMERAL, + Text: fmt.Sprintf(format, args...), + Type: model.POST_DEFAULT, + } +} + +// parseNamedArgs parses a command string into a map of arguments. It is assumed the +// command string is of the form ` --arg1 value1 ...` Supports empty values. +// Arg names are limited to [0-9a-zA-Z_]. +func parseNamedArgs(cmd string) map[string]string { + m := make(map[string]string) + + split := strings.Fields(cmd) + + // check for optional action + if len(split) >= 2 && !strings.HasPrefix(split[1], "--") { + m[ActionKey] = split[1] // prefix with hyphen to avoid collision with arg named "action" + } + + for i := 0; i < len(split); i++ { + if !strings.HasPrefix(split[i], "--") { + continue + } + var val string + arg := trimSpaceAndQuotes(strings.Trim(split[i], "-")) + if i < len(split)-1 && !strings.HasPrefix(split[i+1], "--") { + val = trimSpaceAndQuotes(split[i+1]) + } + if arg != "" { + m[arg] = val + } + } + return m +} + +func trimSpaceAndQuotes(s string) string { + trimmed := strings.TrimSpace(s) + trimmed = strings.TrimPrefix(trimmed, "\"") + trimmed = strings.TrimPrefix(trimmed, "'") + trimmed = strings.TrimSuffix(trimmed, "\"") + trimmed = strings.TrimSuffix(trimmed, "'") + return trimmed +} + +func parseBool(s string) (bool, error) { + switch strings.ToLower(s) { + case "1", "t", "true", "yes", "y": + return true, nil + case "0", "f", "false", "no", "n": + return false, nil + } + return false, fmt.Errorf("cannot parse '%s' as a boolean", s) +} + +func formatTimestamp(ts time.Time) string { + if !isToday(ts) { + return ts.Format("Jan 2 15:04:05 MST 2006") + } + date := ts.Format("15:04:05 MST 2006") + return fmt.Sprintf("today %s", date) +} + +func isToday(ts time.Time) bool { + now := time.Now() + year, month, day := ts.Date() + nowYear, nowMonth, nowDay := now.Date() + return year == nowYear && month == nowMonth && day == nowDay +} diff --git a/app/slashcommands/util_test.go b/app/slashcommands/util_test.go new file mode 100644 index 00000000000..ad55494b99d --- /dev/null +++ b/app/slashcommands/util_test.go @@ -0,0 +1,40 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package slashcommands + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseNamedArgs(t *testing.T) { + data := []struct { + name string + s string + m map[string]string + }{ + {"empty", "", map[string]string{}}, + {"gibberish", "ifu3ue-h29f8", map[string]string{}}, + {"action only", "remote status", map[string]string{ActionKey: "status"}}, + {"no action", "remote --arg1 val1 --arg2 val2", map[string]string{"arg1": "val1", "arg2": "val2"}}, + {"command only", "remote", map[string]string{}}, + {"trailing empty arg", "remote add --arg1 val1 --arg2", map[string]string{ActionKey: "add", "arg1": "val1", "arg2": ""}}, + {"leading empty arg", "remote add --arg1 --arg2 val2", map[string]string{ActionKey: "add", "arg1": "", "arg2": "val2"}}, + {"weird", "-- -- -- --", map[string]string{}}, + {"hyphen before action", "remote -- add", map[string]string{}}, + {"trailing hyphen", "remote add -- ", map[string]string{ActionKey: "add"}}, + {"hyphen in val", "remote add --arg1 val-1 ", map[string]string{ActionKey: "add", "arg1": "val-1"}}, + {"quote prefix and suffix", "remote add --arg1 \"val-1\"", map[string]string{ActionKey: "add", "arg1": "val-1"}}, + {"quote embedded", "remote add --arg1 O'Brien", map[string]string{ActionKey: "add", "arg1": "O'Brien"}}, + {"quote prefix, suffix, and embedded", "remote add --arg1 \"O'Brien\"", map[string]string{ActionKey: "add", "arg1": "O'Brien"}}, + {"empty quotes", "remote add --arg1 \"\"", map[string]string{ActionKey: "add", "arg1": ""}}, + } + + for _, tt := range data { + m := parseNamedArgs(tt.s) + assert.NotNil(t, m) + assert.Equal(t, tt.m, m, tt.name) + } +} diff --git a/app/upload.go b/app/upload.go index 65e08a746c1..d50397d1a0a 100644 --- a/app/upload.go +++ b/app/upload.go @@ -253,6 +253,10 @@ func (a *App) UploadData(us *model.UploadSession, rd io.Reader) (*model.FileInfo info.CreatorId = us.UserId info.Path = us.Path + info.RemoteId = model.NewString(us.RemoteId) + if us.ReqFileId != "" { + info.Id = us.ReqFileId + } // run plugins upload hook if err := a.runPluginsHook(info, file); err != nil { diff --git a/app/web_hub.go b/app/web_hub.go index fbdddc3e624..b11f16cab82 100644 --- a/app/web_hub.go +++ b/app/web_hub.go @@ -180,17 +180,20 @@ func (a *App) Publish(message *model.WebSocketEvent) { a.Srv().Publish(message) } -func (s *Server) PublishSkipClusterSend(message *model.WebSocketEvent) { - if message.GetBroadcast().UserId != "" { - hub := s.GetHubForUserId(message.GetBroadcast().UserId) +func (s *Server) PublishSkipClusterSend(event *model.WebSocketEvent) { + if event.GetBroadcast().UserId != "" { + hub := s.GetHubForUserId(event.GetBroadcast().UserId) if hub != nil { - hub.Broadcast(message) + hub.Broadcast(event) } } else { for _, hub := range s.hubs { - hub.Broadcast(message) + hub.Broadcast(event) } } + + // Notify shared channel sync service + s.SharedChannelSyncHandler(event) } func (a *App) invalidateCacheForChannel(channel *model.Channel) { diff --git a/build/docker/nginx/default.conf b/build/docker/nginx/default.conf index b1ca7ac4f2b..4eda6e7fb06 100644 --- a/build/docker/nginx/default.conf +++ b/build/docker/nginx/default.conf @@ -1,6 +1,7 @@ upstream app_cluster { - server leader:8065 fail_timeout=5s max_fails=10; - server follower:8065 fail_timeout=5s max_fails=10; + server leader:8065 fail_timeout=10s max_fails=10; + server follower:8065 fail_timeout=10s max_fails=10; + server follower2:8065 fail_timeout=10s max_fails=10; } server { @@ -9,6 +10,7 @@ server { location ~ /api/v[0-9]+/(users/)?websocket$ { proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; + proxy_http_version 1.1; client_max_body_size 50M; proxy_set_header Host $http_host; proxy_set_header X-Real-IP $remote_addr; @@ -25,6 +27,7 @@ server { client_max_body_size 50M; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; + proxy_http_version 1.1; proxy_set_header Host $http_host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; diff --git a/config/client.go b/config/client.go index 0479da61549..5d0acb7d9ca 100644 --- a/config/client.go +++ b/config/client.go @@ -205,6 +205,7 @@ func GenerateClientConfig(c *model.Config, telemetryID string, license *model.Li if *license.Features.SharedChannels { props["ExperimentalSharedChannels"] = strconv.FormatBool(*c.ExperimentalSettings.EnableSharedChannels) + props["ExperimentalRemoteClusterService"] = strconv.FormatBool(c.FeatureFlags.EnableRemoteClusterService && *c.ExperimentalSettings.EnableRemoteClusterService) } } diff --git a/docker-compose.yaml b/docker-compose.yaml index 3e95605b27c..43b3f7140c3 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -106,6 +106,7 @@ services: - "MM_SQLSETTINGS_DATASOURCE=postgres://mmuser:mostest@postgres/mattermost_test?sslmode=disable\u0026connect_timeout=10" - "MM_NO_DOCKER=true" - "RUN_SERVER_IN_BACKGROUND=false" + - "MM_CLUSTERSETTINGS_ENABLE=true" networks: - mm-test depends_on: @@ -118,11 +119,18 @@ services: healthcheck: test: ["CMD", "curl", "-f", "http://leader:8065/api/v4/system/ping"] interval: 5s - timeout: 10s + timeout: 30s retries: 30 + start_period: 5m + user: ${CURRENT_UID} command: ['make', 'run-server'] expose: - "8065" + - "8064/tcp" + - "8064/udp" + - "8074/tcp" + - "8074/udp" + - "8075" follower: build: @@ -133,6 +141,7 @@ services: - "MM_SQLSETTINGS_DATASOURCE=postgres://mmuser:mostest@postgres/mattermost_test?sslmode=disable\u0026connect_timeout=10" - "MM_NO_DOCKER=true" - "RUN_SERVER_IN_BACKGROUND=false" + - "MM_CLUSTERSETTINGS_ENABLE=true" networks: - mm-test depends_on: @@ -144,12 +153,54 @@ services: healthcheck: test: ["CMD", "curl", "-f", "http://follower:8065/api/v4/system/ping"] interval: 5s - timeout: 10s + timeout: 30s retries: 30 + start_period: 5m + user: ${CURRENT_UID} command: ['make', 'run-server'] restart: on-failure expose: - "8065" + - "8064/tcp" + - "8064/udp" + - "8074/tcp" + - "8074/udp" + - "8075" + + follower2: + build: + context: . + dockerfile: ./build/Dockerfile.buildenv + working_dir: '/home/mattermost-server' + environment: + - "MM_SQLSETTINGS_DATASOURCE=postgres://mmuser:mostest@postgres/mattermost_test?sslmode=disable\u0026connect_timeout=10" + - "MM_NO_DOCKER=true" + - "RUN_SERVER_IN_BACKGROUND=false" + - "MM_CLUSTERSETTINGS_ENABLE=true" + networks: + - mm-test + depends_on: + - leader + volumes: + - './:/home/mattermost-server' + - './../mattermost-webapp:/home/mattermost-webapp' + - './../enterprise:/home/enterprise' + healthcheck: + test: ["CMD", "curl", "-f", "http://follower2:8065/api/v4/system/ping"] + interval: 5s + timeout: 30s + retries: 30 + start_period: 5m + user: ${CURRENT_UID} + command: ['make', 'run-server'] + restart: on-failure + expose: + - "8065" + - "8064/tcp" + - "8064/udp" + - "8074/tcp" + - "8074/udp" + - "8075" haproxy: image: nginx @@ -159,8 +210,12 @@ services: - ./build/docker/nginx/default.conf:/etc/nginx/conf.d/default.conf restart: on-failure depends_on: - - leader - - follower + leader: + condition: service_healthy + follower: + condition: service_healthy + follower2: + condition: service_healthy ports: - "8065:8065" @@ -171,4 +226,4 @@ networks: driver: default config: - subnet: 192.168.254.0/24 - ip_range: 192.168.254.0/24 + ip_range: 192.168.254.0/24 \ No newline at end of file diff --git a/einterfaces/metrics.go b/einterfaces/metrics.go index b1fcb19bbbf..3ba92eef43f 100644 --- a/einterfaces/metrics.go +++ b/einterfaces/metrics.go @@ -66,6 +66,13 @@ type MetricsInterface interface { ObserveEnabledUsers(users int64) GetLoggerMetricsCollector() logr.MetricsCollector + IncrementRemoteClusterMsgSentCounter(remoteID string) + IncrementRemoteClusterMsgReceivedCounter(remoteID string) + IncrementRemoteClusterMsgErrorsCounter(remoteID string, timeout bool) + ObserveRemoteClusterPingDuration(remoteID string, elapsed float64) + ObserveRemoteClusterClockSkew(remoteID string, skew float64) + IncrementRemoteClusterConnStateChangeCounter(remoteID string, online bool) + IncrementJobActive(jobType string) DecrementJobActive(jobType string) diff --git a/einterfaces/mocks/MetricsInterface.go b/einterfaces/mocks/MetricsInterface.go index ca9d0e79ecc..7a4ef274f5d 100644 --- a/einterfaces/mocks/MetricsInterface.go +++ b/einterfaces/mocks/MetricsInterface.go @@ -180,6 +180,26 @@ func (_m *MetricsInterface) IncrementPostsSearchCounter() { _m.Called() } +// IncrementRemoteClusterConnStateChangeCounter provides a mock function with given fields: remoteID, online +func (_m *MetricsInterface) IncrementRemoteClusterConnStateChangeCounter(remoteID string, online bool) { + _m.Called(remoteID, online) +} + +// IncrementRemoteClusterMsgErrorsCounter provides a mock function with given fields: remoteID, timeout +func (_m *MetricsInterface) IncrementRemoteClusterMsgErrorsCounter(remoteID string, timeout bool) { + _m.Called(remoteID, timeout) +} + +// IncrementRemoteClusterMsgReceivedCounter provides a mock function with given fields: remoteID +func (_m *MetricsInterface) IncrementRemoteClusterMsgReceivedCounter(remoteID string) { + _m.Called(remoteID) +} + +// IncrementRemoteClusterMsgSentCounter provides a mock function with given fields: remoteID +func (_m *MetricsInterface) IncrementRemoteClusterMsgSentCounter(remoteID string) { + _m.Called(remoteID) +} + // IncrementUserIndexCounter provides a mock function with given fields: func (_m *MetricsInterface) IncrementUserIndexCounter() { _m.Called() @@ -255,6 +275,16 @@ func (_m *MetricsInterface) ObservePostsSearchDuration(elapsed float64) { _m.Called(elapsed) } +// ObserveRemoteClusterClockSkew provides a mock function with given fields: remoteID, skew +func (_m *MetricsInterface) ObserveRemoteClusterClockSkew(remoteID string, skew float64) { + _m.Called(remoteID, skew) +} + +// ObserveRemoteClusterPingDuration provides a mock function with given fields: remoteID, elapsed +func (_m *MetricsInterface) ObserveRemoteClusterPingDuration(remoteID string, elapsed float64) { + _m.Called(remoteID, elapsed) +} + // ObserveStoreMethodDuration provides a mock function with given fields: method, success, elapsed func (_m *MetricsInterface) ObserveStoreMethodDuration(method string, success string, elapsed float64) { _m.Called(method, success, elapsed) diff --git a/go.tools.sum b/go.tools.sum index 147133a7bc0..e01ab912271 100644 --- a/go.tools.sum +++ b/go.tools.sum @@ -387,6 +387,8 @@ github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FI github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= +github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ= +github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= @@ -469,6 +471,8 @@ github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69 github.com/tidwall/pretty v0.0.0-20180105212114-65a9db5fad51/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tinylib/msgp v1.1.2 h1:gWmO7n0Ys2RBEb7GPYB9Ujq8Mk5p2U08lRnmMcGy6BQ= github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= +github.com/tinylib/msgp v1.1.5 h1:2gXmtWueD2HefZHQe1QOy9HVzmFrLOVvsXwXBQ0ayy0= +github.com/tinylib/msgp v1.1.5/go.mod h1:eQsjooMTnV42mHu917E26IogZ2930nFyBQdofk10Udg= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31 h1:OXcKh35JaYsGMRzpvFkLv/MEyPuL49CThT1pZ8aSml4= github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q= diff --git a/i18n/en.json b/i18n/en.json index 9ce2c51a435..4dacee5e8e7 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -1078,6 +1078,146 @@ "id": "api.command_open.name", "translation": "open" }, + { + "id": "api.command_remote.accept.help", + "translation": "Accept an invitation from a remote cluster" + }, + { + "id": "api.command_remote.accept_invitation", + "translation": "Invitation accepted and confirmed.\nSiteURL: {{.SiteURL}}" + }, + { + "id": "api.command_remote.accept_invitation.error", + "translation": "Could not accept invitation: {{.Error}}" + }, + { + "id": "api.command_remote.add_remote.error", + "translation": "Could not add remote cluster: {{.Error}}" + }, + { + "id": "api.command_remote.cluster_removed", + "translation": "Remote cluster {{.RemoteId}} {{.Result}}." + }, + { + "id": "api.command_remote.decode_invitation.error", + "translation": "Could not decode invitation: {{.Error}}" + }, + { + "id": "api.command_remote.desc", + "translation": "Invite remote Mattermost clusters for inter-cluster communication." + }, + { + "id": "api.command_remote.encrypt_invitation.error", + "translation": "Could not encrypt invitation: {{.Error}}" + }, + { + "id": "api.command_remote.fetch_status.error", + "translation": "Could not fetch remote clusters: {{.Error}}" + }, + { + "id": "api.command_remote.hint", + "translation": "[action]" + }, + { + "id": "api.command_remote.incorrect_password.error", + "translation": "Could not decrypt invitation. Incorrect password or corrupt invitation: {{.Error}}" + }, + { + "id": "api.command_remote.invitation.help", + "translation": "Invitation from remote cluster" + }, + { + "id": "api.command_remote.invitation.hint", + "translation": "The encrypted invitation from a remote cluster" + }, + { + "id": "api.command_remote.invitation_created", + "translation": "Invitation created." + }, + { + "id": "api.command_remote.invite.help", + "translation": "Invite a remote cluster" + }, + { + "id": "api.command_remote.invite_password.help", + "translation": "Invitation password" + }, + { + "id": "api.command_remote.invite_password.hint", + "translation": "Password to be used to encrypt the invitation" + }, + { + "id": "api.command_remote.invite_summary", + "translation": "Send the following encrypted (AES256 + Base64) blob to the remote site administrator along with the password. They will use the `{{.Command}}` slash command to accept the invitation.\n\n```\n{{.Invitation}}\n```\n\n**Ensure the remote site can access your cluster via** {{.SiteURL}}" + }, + { + "id": "api.command_remote.missing_command", + "translation": "Missing command. Available actions: {{.Actions}}" + }, + { + "id": "api.command_remote.missing_empty", + "translation": "Missing or empty `{{.Arg}}`" + }, + { + "id": "api.command_remote.name", + "translation": "remote" + }, + { + "id": "api.command_remote.name.help", + "translation": "Remote cluster name" + }, + { + "id": "api.command_remote.name.hint", + "translation": "A display name for the remote cluster" + }, + { + "id": "api.command_remote.permission_required", + "translation": "You require `{{.Permission}}` permission to manage remote clusters." + }, + { + "id": "api.command_remote.remote_add_remove.help", + "translation": "Add/remove remote clusters. Available actions: {{.Actions}}" + }, + { + "id": "api.command_remote.remote_table_header", + "translation": "| Name | SiteURL | RemoteId | Invite Accepted | Online | Last Ping |" + }, + { + "id": "api.command_remote.remotes_not_found", + "translation": "No remote clusters found." + }, + { + "id": "api.command_remote.remove.help", + "translation": "Removes a remote cluster" + }, + { + "id": "api.command_remote.remove_remote.error", + "translation": "Could not remove remote cluster: {{.Error}}" + }, + { + "id": "api.command_remote.remove_remote_id.help", + "translation": "Id of remote cluster remove" + }, + { + "id": "api.command_remote.service_disabled", + "translation": "Remote Cluster Service is disabled." + }, + { + "id": "api.command_remote.service_not_enabled", + "translation": "Remote Cluster service not enabled." + }, + { + "id": "api.command_remote.site_url_not_set", + "translation": "SiteURL not set. Please set this via the system console." + }, + { + "id": "api.command_remote.status.help", + "translation": "Displays status for all remote clusters" + }, + { + "id": "api.command_remote.unknown_action", + "translation": "Unknown action `{{.Action}}`" + }, { "id": "api.command_remove.desc", "translation": "Remove a member from the channel" @@ -1142,6 +1282,214 @@ "id": "api.command_settings.unsupported.app_error", "translation": "The settings command is not supported on your device." }, + { + "id": "api.command_share.available_actions", + "translation": "Available actions: {{.Actions}}" + }, + { + "id": "api.command_share.channel_display_name.help", + "translation": "Channel display name provided to remote instances" + }, + { + "id": "api.command_share.channel_display_name.hint", + "translation": "[displayname] - defaults to channel displayname" + }, + { + "id": "api.command_share.channel_header.help", + "translation": "Channel header provided to remote instances" + }, + { + "id": "api.command_share.channel_header.hint", + "translation": "[header] - defaults to channels header" + }, + { + "id": "api.command_share.channel_invite.error", + "translation": "Error inviting `{{.Name}}` to this channel: {{.Error}}" + }, + { + "id": "api.command_share.channel_invite_not_home.error", + "translation": "Cannot invite remote cluster to a shared channel originating somewhere else." + }, + { + "id": "api.command_share.channel_name.help", + "translation": "Channel name provided to remote instances" + }, + { + "id": "api.command_share.channel_name.hint", + "translation": "[name] - defaults to channel name" + }, + { + "id": "api.command_share.channel_purpose.help", + "translation": "Channel purpose provided to remote instances" + }, + { + "id": "api.command_share.channel_purpose.hint", + "translation": "[purpose] - defaults to channel purpose" + }, + { + "id": "api.command_share.channel_remote_id_not_exists", + "translation": "Shared channel remote id `{{.RemoteId}}` does not exist for this channel." + }, + { + "id": "api.command_share.channel_shared", + "translation": "This channel is now shared." + }, + { + "id": "api.command_share.channel_status.help", + "translation": "Displays status for this shared channel" + }, + { + "id": "api.command_share.channel_status_id", + "translation": "Status for channel Id `{{.ChannelId}}`" + }, + { + "id": "api.command_share.check_channel_exist.error", + "translation": "Error while checking if shared channel exists: {{.Error}}" + }, + { + "id": "api.command_share.could_not_uninvite.error", + "translation": "Could not uninvite `{{.RemoteId}}`: {{.Error}}" + }, + { + "id": "api.command_share.desc", + "translation": "Shares the current channel with a remote Mattermost instance." + }, + { + "id": "api.command_share.description_invite.help", + "translation": "Description for invite" + }, + { + "id": "api.command_share.description_invite.hint", + "translation": "[description] - optional" + }, + { + "id": "api.command_share.fetch_remote.error", + "translation": "Error fetching remote clusters: {{.Error}}" + }, + { + "id": "api.command_share.fetch_remote_status.error", + "translation": "Could not fetch status for remotes: {{.Error}}." + }, + { + "id": "api.command_share.hint", + "translation": "[action]" + }, + { + "id": "api.command_share.invalid_value.error", + "translation": "Invalid value for '{{.Arg}}': {{.Error}}" + }, + { + "id": "api.command_share.invitation_sent", + "translation": "Channel invitation has been sent to `{{.Name}} {{.SiteURL}}`." + }, + { + "id": "api.command_share.invite_remote.help", + "translation": "Invites a remote instance to the current shared channel" + }, + { + "id": "api.command_share.missing_action", + "translation": "Missing action. Available actions: {{.Actions}}" + }, + { + "id": "api.command_share.must_specify_valid_remote", + "translation": "Must specify a valid remote cluster id to invite." + }, + { + "id": "api.command_share.name", + "translation": "share" + }, + { + "id": "api.command_share.no_remote_invited", + "translation": "No remotes have been invited to this shared channel." + }, + { + "id": "api.command_share.not_shared_channel_unshare", + "translation": "Cannot unshare a channel that is not shared." + }, + { + "id": "api.command_share.permission_required", + "translation": "You require `{{.Permission}}` permission to manage shared channels." + }, + { + "id": "api.command_share.remote_already_invited", + "translation": "The remote cluster has already been invited." + }, + { + "id": "api.command_share.remote_id.help", + "translation": "Id of an existing remote instance. See `remote` command to add a remote instance." + }, + { + "id": "api.command_share.remote_id_invalid.error", + "translation": "Remote cluster id is invalid: {{.Error}}" + }, + { + "id": "api.command_share.remote_not_valid", + "translation": "Must specify a valid remote cluster to uninvite" + }, + { + "id": "api.command_share.remote_table_header", + "translation": "| Remote | SiteURL | Description | ReadOnly | InviteAccepted | Online | Last Sync |" + }, + { + "id": "api.command_share.remote_uninvited", + "translation": "Remote `{{.RemoteId}}` uninvited." + }, + { + "id": "api.command_share.service_disabled", + "translation": "Shared Channels Service is disabled.." + }, + { + "id": "api.command_share.share_channel.error", + "translation": "Cannot share this channel: {{.Error}}" + }, + { + "id": "api.command_share.share_current", + "translation": "Share the current channel" + }, + { + "id": "api.command_share.share_read_only.help", + "translation": "Channel will be shared in read-only mode" + }, + { + "id": "api.command_share.share_read_only.hint", + "translation": "[readonly] - 'Y' or 'N'. Defaults to 'N'" + }, + { + "id": "api.command_share.shared_channel_not_deleted", + "translation": "Shared channel was not deleted: `{{.Arg}}` must be `{{.Expected}}`." + }, + { + "id": "api.command_share.shared_channel_unavailable", + "translation": "This channel is no longer shared." + }, + { + "id": "api.command_share.shared_channel_unshare.error", + "translation": "Cannot unshare this channel: {{.Error}}." + }, + { + "id": "api.command_share.uninvite_remote.help", + "translation": "Uninvites a remote instance from this shared channel" + }, + { + "id": "api.command_share.uninvite_remote_id.help", + "translation": "Id of remote instance to uninvite." + }, + { + "id": "api.command_share.unknown_action", + "translation": "Unknown action `{{.Action}}`. Available actions: {{.Actions}}" + }, + { + "id": "api.command_share.unshare_channel.help", + "translation": "Unshares the current channel" + }, + { + "id": "api.command_share.unshare_confirmation.help", + "translation": "Are you sure? This channel will be unshared and all remote instances will be uninvited" + }, + { + "id": "api.command_share.unshare_confirmation.hint", + "translation": "'Y' or 'N'" + }, { "id": "api.command_shortcuts.desc", "translation": "Displays a list of keyboard shortcuts" @@ -1218,6 +1566,14 @@ "id": "api.context.invalid_url_param.app_error", "translation": "Invalid or missing {{.Name}} parameter in request URL." }, + { + "id": "api.context.invitation_expired.error", + "translation": "Invitation is expired." + }, + { + "id": "api.context.json_encoding.app_error", + "translation": "Error encoding JSON." + }, { "id": "api.context.local_origin_required.app_error", "translation": "This endpoint requires a local request origin." @@ -1230,6 +1586,18 @@ "id": "api.context.permissions.app_error", "translation": "You do not have the appropriate permissions." }, + { + "id": "api.context.remote_id_invalid.app_error", + "translation": "Unable to find remote cluster id {{.RemoteId}}." + }, + { + "id": "api.context.remote_id_mismatch.app_error", + "translation": "Remote cluster id mismatch." + }, + { + "id": "api.context.remote_id_missing.app_error", + "translation": "Remote cluster id missing." + }, { "id": "api.context.server_busy.app_error", "translation": "Server is busy, non-critical services are temporarily unavailable." @@ -2026,6 +2394,42 @@ "id": "api.reaction.town_square_read_only", "translation": "Reacting to posts is not possible in read-only channels." }, + { + "id": "api.remote_cluster.delete.app_error", + "translation": "We encountered an error deleting the remote cluster." + }, + { + "id": "api.remote_cluster.get.app_error", + "translation": "We encountered an error retrieving a remote cluster." + }, + { + "id": "api.remote_cluster.invalid_id.app_error", + "translation": "Invalid id." + }, + { + "id": "api.remote_cluster.invalid_topic.app_error", + "translation": "Invalid topic." + }, + { + "id": "api.remote_cluster.save.app_error", + "translation": "We encountered an error saving the remote cluster." + }, + { + "id": "api.remote_cluster.save_not_unique.app_error", + "translation": "Remote cluster has already been added." + }, + { + "id": "api.remote_cluster.service_not_enabled.app_error", + "translation": "The remote cluster service is not enabled." + }, + { + "id": "api.remote_cluster.update.app_error", + "translation": "We encountered an error updating the remote cluster." + }, + { + "id": "api.remote_cluster.update_not_unique.app_error", + "translation": "Remote cluster with the same url already exists." + }, { "id": "api.restricted_system_admin", "translation": "This action is forbidden to a restricted system admin." @@ -5558,6 +5962,10 @@ "id": "app.session.update_device_id.app_error", "translation": "Unable to update the device id." }, + { + "id": "app.sharedchannel.dm_channel_creation.internal_error", + "translation": "Encountered an error while creating a direct shared channel." + }, { "id": "app.status.get.app_error", "translation": "Encountered an error retrieving the status." @@ -7250,6 +7658,10 @@ "id": "model.channel.is_valid.creator_id.app_error", "translation": "Invalid creator id." }, + { + "id": "model.channel.is_valid.description.app_error", + "translation": "Invalid description." + }, { "id": "model.channel.is_valid.display_name.app_error", "translation": "Invalid display name." @@ -8566,6 +8978,14 @@ "id": "searchengine.bleve.disabled.error", "translation": "Error purging Bleve indexes: engine is disabled" }, + { + "id": "sharedchannel.cannot_deliver_post", + "translation": "One or more posts could not be delivered to remote site {{.Remote}} because it is offline. The post(s) will be delivered when the site is online." + }, + { + "id": "sharedchannel.permalink.not_found", + "translation": "This post contains permalinks to other channels which may not be visible to users in other sites." + }, { "id": "store.sql.convert_string_array", "translation": "FromDb: Unable to convert StringArray to *string" diff --git a/model/auditconv.go b/model/auditconv.go index 24a8d94df04..195ad11e491 100644 --- a/model/auditconv.go +++ b/model/auditconv.go @@ -51,6 +51,8 @@ func AuditModelTypeConv(val interface{}) (newVal interface{}, converted bool) { return newAuditIncomingWebhook(v), true case *OutgoingWebhook: return newAuditOutgoingWebhook(v), true + case *RemoteCluster: + return newRemoteCluster(v), true } return val, false } @@ -667,3 +669,42 @@ func (h auditOutgoingWebhook) MarshalJSONObject(enc *gojay.Encoder) { func (h auditOutgoingWebhook) IsNil() bool { return false } + +type auditRemoteCluster struct { + RemoteId string + RemoteTeamId string + DisplayName string + SiteURL string + CreateAt int64 + LastPingAt int64 + CreatorId string +} + +// newRemoteCluster creates a simplified representation of RemoteCluster for output to audit log. +func newRemoteCluster(r *RemoteCluster) auditRemoteCluster { + var rc auditRemoteCluster + if r != nil { + rc.RemoteId = r.RemoteId + rc.RemoteTeamId = r.RemoteTeamId + rc.DisplayName = r.DisplayName + rc.SiteURL = r.SiteURL + rc.CreateAt = r.CreateAt + rc.LastPingAt = r.LastPingAt + rc.CreatorId = r.CreatorId + } + return rc +} + +func (r auditRemoteCluster) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("remote_id", r.RemoteId) + enc.StringKey("remote_team_id", r.RemoteTeamId) + enc.StringKey("display_name", r.DisplayName) + enc.StringKey("site_url", r.SiteURL) + enc.Int64Key("create_at", r.CreateAt) + enc.Int64Key("last_ping_at", r.LastPingAt) + enc.StringKey("creator_id", r.CreatorId) +} + +func (r auditRemoteCluster) IsNil() bool { + return false +} diff --git a/model/channel.go b/model/channel.go index c7a0a1bfc70..15c080a1452 100644 --- a/model/channel.go +++ b/model/channel.go @@ -142,6 +142,14 @@ type ChannelMemberCountByGroup struct { ChannelMemberTimezonesCount int64 `db:"-" json:"channel_member_timezones_count"` } +type ChannelOption func(channel *Channel) + +func WithID(ID string) ChannelOption { + return func(channel *Channel) { + channel.Id = ID + } +} + func (o *Channel) DeepCopy() *Channel { copy := *o if copy.SchemeId != nil { diff --git a/model/client4.go b/model/client4.go index e1974ff2f46..5c4e513eb42 100644 --- a/model/client4.go +++ b/model/client4.go @@ -19,27 +19,29 @@ import ( ) const ( - HEADER_REQUEST_ID = "X-Request-ID" - HEADER_VERSION_ID = "X-Version-ID" - HEADER_CLUSTER_ID = "X-Cluster-ID" - HEADER_ETAG_SERVER = "ETag" - HEADER_ETAG_CLIENT = "If-None-Match" - HEADER_FORWARDED = "X-Forwarded-For" - HEADER_REAL_IP = "X-Real-IP" - HEADER_FORWARDED_PROTO = "X-Forwarded-Proto" - HEADER_TOKEN = "token" - HEADER_CSRF_TOKEN = "X-CSRF-Token" - HEADER_BEARER = "BEARER" - HEADER_AUTH = "Authorization" - HEADER_CLOUD_TOKEN = "X-Cloud-Token" - HEADER_REQUESTED_WITH = "X-Requested-With" - HEADER_REQUESTED_WITH_XML = "XMLHttpRequest" - HEADER_RANGE = "Range" - STATUS = "status" - STATUS_OK = "OK" - STATUS_FAIL = "FAIL" - STATUS_UNHEALTHY = "UNHEALTHY" - STATUS_REMOVE = "REMOVE" + HEADER_REQUEST_ID = "X-Request-ID" + HEADER_VERSION_ID = "X-Version-ID" + HEADER_CLUSTER_ID = "X-Cluster-ID" + HEADER_ETAG_SERVER = "ETag" + HEADER_ETAG_CLIENT = "If-None-Match" + HEADER_FORWARDED = "X-Forwarded-For" + HEADER_REAL_IP = "X-Real-IP" + HEADER_FORWARDED_PROTO = "X-Forwarded-Proto" + HEADER_TOKEN = "token" + HEADER_CSRF_TOKEN = "X-CSRF-Token" + HEADER_BEARER = "BEARER" + HEADER_AUTH = "Authorization" + HEADER_CLOUD_TOKEN = "X-Cloud-Token" + HEADER_REMOTECLUSTER_TOKEN = "X-RemoteCluster-Token" + HEADER_REMOTECLUSTER_ID = "X-RemoteCluster-Id" + HEADER_REQUESTED_WITH = "X-Requested-With" + HEADER_REQUESTED_WITH_XML = "XMLHttpRequest" + HEADER_RANGE = "Range" + STATUS = "status" + STATUS_OK = "OK" + STATUS_FAIL = "FAIL" + STATUS_UNHEALTHY = "UNHEALTHY" + STATUS_REMOVE = "REMOVE" CLIENT_DIR = "client" @@ -559,6 +561,14 @@ func (c *Client4) GetExportRoute(name string) string { return fmt.Sprintf(c.GetExportsRoute()+"/%v", name) } +func (c *Client4) GetRemoteClusterRoute() string { + return "/remotecluster" +} + +func (c *Client4) GetSharedChannelsRoute() string { + return "/sharedchannels" +} + func (c *Client4) DoApiGet(url string, etag string) (*http.Response, *AppError) { return c.DoApiRequest(http.MethodGet, c.ApiUrl+url, "", etag) } @@ -5999,3 +6009,31 @@ func (c *Client4) SendAdminUpgradeRequestEmailOnJoin() *Response { return BuildResponse(r) } + +func (c *Client4) GetAllSharedChannels(teamID string, page, perPage int) ([]*SharedChannel, *Response) { + url := fmt.Sprintf("%s/%s?page=%d&per_page=%d", c.GetSharedChannelsRoute(), teamID, page, perPage) + r, appErr := c.DoApiGet(url, "") + if appErr != nil { + return nil, BuildErrorResponse(r, appErr) + } + defer closeBody(r) + + var channels []*SharedChannel + json.NewDecoder(r.Body).Decode(&channels) + + return channels, BuildResponse(r) +} + +func (c *Client4) GetRemoteClusterInfo(remoteID string) (RemoteClusterInfo, *Response) { + url := fmt.Sprintf("%s/remote_info/%s", c.GetSharedChannelsRoute(), remoteID) + r, appErr := c.DoApiGet(url, "") + if appErr != nil { + return RemoteClusterInfo{}, BuildErrorResponse(r, appErr) + } + defer closeBody(r) + + var rci RemoteClusterInfo + json.NewDecoder(r.Body).Decode(&rci) + + return rci, BuildResponse(r) +} diff --git a/model/config.go b/model/config.go index 1a633a8f5b8..d63ccb04ecb 100644 --- a/model/config.go +++ b/model/config.go @@ -940,6 +940,7 @@ type ExperimentalSettings struct { CloudUserLimit *int64 `access:"experimental,write_restrictable"` CloudBilling *bool `access:"experimental,write_restrictable"` EnableSharedChannels *bool `access:"experimental"` + EnableRemoteClusterService *bool `access:"experimental"` } func (s *ExperimentalSettings) SetDefaults() { @@ -979,6 +980,10 @@ func (s *ExperimentalSettings) SetDefaults() { if s.EnableSharedChannels == nil { s.EnableSharedChannels = NewBool(false) } + + if s.EnableRemoteClusterService == nil { + s.EnableRemoteClusterService = NewBool(false) + } } type AnalyticsSettings struct { diff --git a/model/feature_flags.go b/model/feature_flags.go index cdcf87dab78..148e16e4b25 100644 --- a/model/feature_flags.go +++ b/model/feature_flags.go @@ -19,6 +19,12 @@ type FeatureFlags struct { // Toggle on and off support for Collapsed Threads CollapsedThreads bool + // Enable the remote cluster service for shared channels. + EnableRemoteClusterService bool + + // Toggle on and off support for Custom User Statuses + CustomUserStatuses bool + // AppsEnabled toggle the Apps framework functionalities both in server and client side AppsEnabled bool @@ -37,6 +43,7 @@ func (f *FeatureFlags) SetDefaults() { f.TestBoolFeature = false f.CloudDelinquentEmailJobsEnabled = false f.CollapsedThreads = false + f.EnableRemoteClusterService = false f.FilesSearch = false f.AppsEnabled = false diff --git a/model/file_info.go b/model/file_info.go index 0da93380086..f86c22b506c 100644 --- a/model/file_info.go +++ b/model/file_info.go @@ -61,6 +61,7 @@ type FileInfo struct { HasPreviewImage bool `json:"has_preview_image,omitempty"` MiniPreview *[]byte `json:"mini_preview"` // declared as *[]byte to avoid postgres/mysql differences in deserialization Content string `json:"-"` + RemoteId *string `json:"remote_id"` } func (fi *FileInfo) ToJson() string { @@ -105,6 +106,10 @@ func (fi *FileInfo) PreSave() { if fi.UpdateAt < fi.CreateAt { fi.UpdateAt = fi.CreateAt } + + if fi.RemoteId == nil { + fi.RemoteId = NewString("") + } } func (fi *FileInfo) IsValid() *AppError { diff --git a/model/license.go b/model/license.go index b4a294ee71a..e180a356d2a 100644 --- a/model/license.go +++ b/model/license.go @@ -240,7 +240,7 @@ func (f *Features) SetDefaults() { } if f.RemoteClusterService == nil { - f.RemoteClusterService = f.SharedChannels + f.RemoteClusterService = NewBool(*f.FutureFeatures) } } diff --git a/model/post.go b/model/post.go index ebb69767e2a..fe5df8347b3 100644 --- a/model/post.go +++ b/model/post.go @@ -96,6 +96,7 @@ type Post struct { FileIds StringArray `json:"file_ids,omitempty"` PendingPostId string `json:"pending_post_id" db:"-"` HasReactions bool `json:"has_reactions,omitempty"` + RemoteId *string `json:"remote_id,omitempty"` // Transient data populated before sending a post to the client ReplyCount int64 `json:"reply_count" db:"-"` @@ -206,6 +207,7 @@ func (o *Post) ShallowCopy(dst *Post) error { dst.Participants = o.Participants dst.LastReplyAt = o.LastReplyAt dst.Metadata = o.Metadata + dst.RemoteId = o.RemoteId return nil } @@ -235,6 +237,18 @@ type GetPostsSinceOptions struct { SkipFetchThreads bool CollapsedThreads bool CollapsedThreadsExtended bool + SortAscending bool +} + +type GetPostsSinceForSyncOptions struct { + ChannelId string + Since int64 // inclusive + Until int64 // inclusive + SortDescending bool + ExcludeRemoteId string + IncludeDeleted bool + Limit int + Offset int } type GetPostsOptions struct { @@ -452,6 +466,11 @@ func (o *Post) IsSystemMessage() bool { return len(o.Type) >= len(POST_SYSTEM_MESSAGE_PREFIX) && o.Type[:len(POST_SYSTEM_MESSAGE_PREFIX)] == POST_SYSTEM_MESSAGE_PREFIX } +// IsRemote returns true if the post originated on a remote cluster. +func (o *Post) IsRemote() bool { + return o.RemoteId != nil && *o.RemoteId != "" +} + func (o *Post) IsJoinLeaveMessage() bool { return o.Type == POST_JOIN_LEAVE || o.Type == POST_ADD_REMOVE || diff --git a/model/post_list.go b/model/post_list.go index ebf7bef7545..ba84f7493d7 100644 --- a/model/post_list.go +++ b/model/post_list.go @@ -27,6 +27,11 @@ func NewPostList() *PostList { func (o *PostList) ToSlice() []*Post { var posts []*Post + + if l := len(o.Posts); l > 0 { + posts = make([]*Post, 0, l) + } + for _, id := range o.Order { posts = append(posts, o.Posts[id]) } diff --git a/model/reaction.go b/model/reaction.go index 1c4706860b0..6d0ea68d232 100644 --- a/model/reaction.go +++ b/model/reaction.go @@ -11,12 +11,13 @@ import ( ) type Reaction struct { - UserId string `json:"user_id"` - PostId string `json:"post_id"` - EmojiName string `json:"emoji_name"` - CreateAt int64 `json:"create_at"` - UpdateAt int64 `json:"update_at"` - DeleteAt int64 `json:"delete_at"` + UserId string `json:"user_id"` + PostId string `json:"post_id"` + EmojiName string `json:"emoji_name"` + CreateAt int64 `json:"create_at"` + UpdateAt int64 `json:"update_at"` + DeleteAt int64 `json:"delete_at"` + RemoteId *string `json:"remote_id"` } func (o *Reaction) ToJson() string { @@ -94,8 +95,16 @@ func (o *Reaction) PreSave() { } o.UpdateAt = GetMillis() o.DeleteAt = 0 + + if o.RemoteId == nil { + o.RemoteId = NewString("") + } } func (o *Reaction) PreUpdate() { o.UpdateAt = GetMillis() + + if o.RemoteId == nil { + o.RemoteId = NewString("") + } } diff --git a/model/remote_cluster.go b/model/remote_cluster.go new file mode 100644 index 00000000000..32b2a462783 --- /dev/null +++ b/model/remote_cluster.go @@ -0,0 +1,295 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha512" + "encoding/json" + "io" + "net/http" + "strings" +) + +const ( + RemoteOfflineAfterMillis = 1000 * 60 * 5 // 5 minutes +) + +type RemoteCluster struct { + RemoteId string `json:"remote_id"` + RemoteTeamId string `json:"remote_team_id"` + DisplayName string `json:"display_name"` + SiteURL string `json:"site_url"` + CreateAt int64 `json:"create_at"` + LastPingAt int64 `json:"last_ping_at"` + Token string `json:"token"` + RemoteToken string `json:"remote_token"` + Topics string `json:"topics"` + CreatorId string `json:"creator_id"` +} + +func (rc *RemoteCluster) PreSave() { + if rc.RemoteId == "" { + rc.RemoteId = NewId() + } + + if rc.Token == "" { + rc.Token = NewId() + } + + if rc.CreateAt == 0 { + rc.CreateAt = GetMillis() + } + rc.fixTopics() +} + +func (rc *RemoteCluster) IsValid() *AppError { + if !IsValidId(rc.RemoteId) { + return NewAppError("RemoteCluster.IsValid", "model.cluster.is_valid.id.app_error", nil, "id="+rc.RemoteId, http.StatusBadRequest) + } + + if rc.DisplayName == "" { + return NewAppError("RemoteCluster.IsValid", "model.cluster.is_valid.name.app_error", nil, "display_name empty", http.StatusBadRequest) + } + + if rc.CreateAt == 0 { + return NewAppError("RemoteCluster.IsValid", "model.cluster.is_valid.create_at.app_error", nil, "create_at=0", http.StatusBadRequest) + } + + if !IsValidId(rc.CreatorId) { + return NewAppError("RemoteCluster.IsValid", "model.cluster.is_valid.id.app_error", nil, "creator_id="+rc.CreatorId, http.StatusBadRequest) + } + return nil +} + +func (rc *RemoteCluster) PreUpdate() { + rc.fixTopics() +} + +func (rc *RemoteCluster) IsOnline() bool { + return rc.LastPingAt > GetMillis()-RemoteOfflineAfterMillis +} + +// fixTopics ensures all topics are separated by one, and only one, space. +func (rc *RemoteCluster) fixTopics() { + trimmed := strings.TrimSpace(rc.Topics) + if trimmed == "" || trimmed == "*" { + rc.Topics = trimmed + return + } + + var sb strings.Builder + sb.WriteString(" ") + + ss := strings.Split(rc.Topics, " ") + for _, c := range ss { + cc := strings.TrimSpace(c) + if cc != "" { + sb.WriteString(cc) + sb.WriteString(" ") + } + } + rc.Topics = sb.String() +} + +func (rc *RemoteCluster) ToJSON() (string, error) { + b, err := json.Marshal(rc) + if err != nil { + return "", err + } + return string(b), nil +} + +func (rc *RemoteCluster) ToRemoteClusterInfo() RemoteClusterInfo { + return RemoteClusterInfo{ + DisplayName: rc.DisplayName, + CreateAt: rc.CreateAt, + LastPingAt: rc.LastPingAt, + } +} + +func RemoteClusterFromJSON(data io.Reader) (*RemoteCluster, *AppError) { + var rc RemoteCluster + err := json.NewDecoder(data).Decode(&rc) + if err != nil { + return nil, NewAppError("RemoteClusterFromJSON", "model.utils.decode_json.app_error", nil, err.Error(), http.StatusBadRequest) + } + return &rc, nil +} + +// RemoteClusterInfo provides a subset of RemoteCluster fields suitable for sending to clients. +type RemoteClusterInfo struct { + DisplayName string `json:"display_name"` + CreateAt int64 `json:"create_at"` + LastPingAt int64 `json:"last_ping_at"` +} + +// RemoteClusterFrame wraps a `RemoteClusterMsg` with credentials specific to a remote cluster. +type RemoteClusterFrame struct { + RemoteId string `json:"remote_id"` + Msg RemoteClusterMsg `json:"msg"` +} + +func (f *RemoteClusterFrame) IsValid() *AppError { + if !IsValidId(f.RemoteId) { + return NewAppError("RemoteClusterFrame.IsValid", "api.remote_cluster.invalid_id.app_error", nil, "RemoteId="+f.RemoteId, http.StatusBadRequest) + } + + if err := f.Msg.IsValid(); err != nil { + return err + } + + return nil +} + +func RemoteClusterFrameFromJSON(data io.Reader) (*RemoteClusterFrame, *AppError) { + var frame RemoteClusterFrame + err := json.NewDecoder(data).Decode(&frame) + if err != nil { + return nil, NewAppError("RemoteClusterFrameFromJSON", "model.utils.decode_json.app_error", nil, err.Error(), http.StatusBadRequest) + } + return &frame, nil +} + +// RemoteClusterMsg represents a message that is sent and received between clusters. +// These are processed and routed via the RemoteClusters service. +type RemoteClusterMsg struct { + Id string `json:"id"` + Topic string `json:"topic"` + CreateAt int64 `json:"create_at"` + Payload json.RawMessage `json:"payload"` +} + +func NewRemoteClusterMsg(topic string, payload json.RawMessage) RemoteClusterMsg { + return RemoteClusterMsg{ + Id: NewId(), + Topic: topic, + CreateAt: GetMillis(), + Payload: payload, + } +} + +func (m RemoteClusterMsg) IsValid() *AppError { + if !IsValidId(m.Id) { + return NewAppError("RemoteClusterMsg.IsValid", "api.remote_cluster.invalid_id.app_error", nil, "Id="+m.Id, http.StatusBadRequest) + } + + if m.Topic == "" { + return NewAppError("RemoteClusterMsg.IsValid", "api.remote_cluster.invalid_topic.app_error", nil, "Topic empty", http.StatusBadRequest) + } + + if len(m.Payload) == 0 { + return NewAppError("RemoteClusterMsg.IsValid", "api.context.invalid_body_param.app_error", map[string]interface{}{"Name": "PayLoad"}, "", http.StatusBadRequest) + } + + return nil +} + +func RemoteClusterMsgFromJSON(data io.Reader) (RemoteClusterMsg, *AppError) { + var msg RemoteClusterMsg + err := json.NewDecoder(data).Decode(&msg) + if err != nil { + return RemoteClusterMsg{}, NewAppError("RemoteClusterMsgFromJSON", "model.utils.decode_json.app_error", nil, err.Error(), http.StatusBadRequest) + } + return msg, nil +} + +// RemoteClusterPing represents a ping that is sent and received between clusters +// to indicate a connection is alive. This is the payload for a `RemoteClusterMsg`. +type RemoteClusterPing struct { + SentAt int64 `json:"sent_at"` + RecvAt int64 `json:"recv_at"` +} + +func RemoteClusterPingFromRawJSON(raw json.RawMessage) (RemoteClusterPing, *AppError) { + var ping RemoteClusterPing + err := json.Unmarshal(raw, &ping) + if err != nil { + return RemoteClusterPing{}, NewAppError("RemoteClusterPingFromRawJSON", "model.utils.decode_json.app_error", nil, err.Error(), http.StatusBadRequest) + } + return ping, nil +} + +// RemoteClusterInvite represents an invitation to establish a simple trust with a remote cluster. +type RemoteClusterInvite struct { + RemoteId string `json:"remote_id"` + RemoteTeamId string `json:"remote_team_id"` + SiteURL string `json:"site_url"` + Token string `json:"token"` +} + +func RemoteClusterInviteFromRawJSON(raw json.RawMessage) (*RemoteClusterInvite, *AppError) { + var invite RemoteClusterInvite + err := json.Unmarshal(raw, &invite) + if err != nil { + return nil, NewAppError("RemoteClusterInviteFromRawJSON", "model.utils.decode_json.app_error", nil, err.Error(), http.StatusBadRequest) + } + return &invite, nil +} + +func (rci *RemoteClusterInvite) Encrypt(password string) ([]byte, error) { + raw, err := json.Marshal(&rci) + if err != nil { + return nil, err + } + + // hash the pasword to 32 bytes for AES256 + key := sha512.Sum512_256([]byte(password)) + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + // create random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + // prefix the nonce to the cyphertext so we don't need to keep track of it. + return gcm.Seal(nonce, nonce, raw, nil), nil +} + +func (rci *RemoteClusterInvite) Decrypt(encrypted []byte, password string) error { + // hash the pasword to 32 bytes for AES256 + key := sha512.Sum512_256([]byte(password)) + block, err := aes.NewCipher(key[:]) + if err != nil { + return err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return err + } + + // nonce was prefixed to the cyphertext when encrypting so we need to extract it. + nonceSize := gcm.NonceSize() + nonce, cyphertext := encrypted[:nonceSize], encrypted[nonceSize:] + + plain, err := gcm.Open(nil, nonce, cyphertext, nil) + if err != nil { + return err + } + + // try to unmarshall the decrypted JSON to this invite struct. + return json.Unmarshal(plain, &rci) +} + +// RemoteClusterQueryFilter provides filter criteria for RemoteClusterStore.GetAll +type RemoteClusterQueryFilter struct { + ExcludeOffline bool + InChannel string + NotInChannel string + Topic string + CreatorId string + OnlyConfirmed bool +} diff --git a/model/remote_cluster_test.go b/model/remote_cluster_test.go new file mode 100644 index 00000000000..b2447452bc2 --- /dev/null +++ b/model/remote_cluster_test.go @@ -0,0 +1,158 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "crypto/rand" + "encoding/json" + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRemoteClusterJson(t *testing.T) { + o := RemoteCluster{RemoteId: NewId(), DisplayName: "test"} + + json, err := o.ToJSON() + require.NoError(t, err) + + ro, err := RemoteClusterFromJSON(strings.NewReader(json)) + require.Nil(t, err) + + require.Equal(t, o.RemoteId, ro.RemoteId) + require.Equal(t, o.DisplayName, ro.DisplayName) +} + +func TestRemoteClusterIsValid(t *testing.T) { + id := NewId() + creator := NewId() + now := GetMillis() + data := []struct { + name string + rc *RemoteCluster + valid bool + }{ + {name: "Zero value", rc: &RemoteCluster{}, valid: false}, + {name: "Missing cluster_name", rc: &RemoteCluster{RemoteId: id}, valid: false}, + {name: "Missing host_name", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster"}, valid: false}, + {name: "Missing create_at", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster", SiteURL: "example.com"}, valid: false}, + {name: "Missing last_ping_at", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster", SiteURL: "example.com", CreatorId: creator, CreateAt: now}, valid: true}, + {name: "Missing creator", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster", SiteURL: "example.com", CreateAt: now, LastPingAt: now}, valid: false}, + {name: "RemoteCluster valid", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster", SiteURL: "example.com", CreateAt: now, LastPingAt: now, CreatorId: creator}, valid: true}, + {name: "Include protocol", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster", SiteURL: "http://example.com", CreateAt: now, LastPingAt: now, CreatorId: creator}, valid: true}, + {name: "Include protocol & port", rc: &RemoteCluster{RemoteId: id, DisplayName: "test cluster", SiteURL: "http://example.com:8065", CreateAt: now, LastPingAt: now, CreatorId: creator}, valid: true}, + } + + for _, item := range data { + err := item.rc.IsValid() + if item.valid { + assert.Nil(t, err, item.name) + } else { + assert.NotNil(t, err, item.name) + } + } +} + +func TestRemoteClusterPreSave(t *testing.T) { + now := GetMillis() + + o := RemoteCluster{RemoteId: NewId(), DisplayName: "test"} + o.PreSave() + + require.GreaterOrEqual(t, o.CreateAt, now) +} + +func TestRemoteClusterMsgJson(t *testing.T) { + o := NewRemoteClusterMsg("shared_channel", []byte("{\"hello\":\"world\"}")) + + json, err := json.Marshal(o) + require.NoError(t, err) + + ro, err := RemoteClusterMsgFromJSON(strings.NewReader(string(json))) + require.Nil(t, err) + + require.Equal(t, o.Id, ro.Id) + require.Equal(t, o.CreateAt, ro.CreateAt) + require.Equal(t, o.Topic, ro.Topic) +} + +func TestRemoteClusterMsgIsValid(t *testing.T) { + id := NewId() + now := GetMillis() + data := []struct { + name string + msg *RemoteClusterMsg + valid bool + }{ + {name: "Zero value", msg: &RemoteClusterMsg{}, valid: false}, + {name: "Missing remote id", msg: &RemoteClusterMsg{Id: id}, valid: false}, + {name: "Missing Topic", msg: &RemoteClusterMsg{Id: id}, valid: false}, + {name: "Missing Payload", msg: &RemoteClusterMsg{Id: id, CreateAt: now, Topic: "shared_channel"}, valid: false}, + {name: "RemoteClusterMsg valid", msg: &RemoteClusterMsg{Id: id, CreateAt: now, Topic: "shared_channel", Payload: []byte("{\"hello\":\"world\"}")}, valid: true}, + } + + for _, item := range data { + err := item.msg.IsValid() + if item.valid { + assert.Nil(t, err, item.name) + } else { + assert.NotNil(t, err, item.name) + } + } +} + +func TestFixTopics(t *testing.T) { + testData := []struct { + topics string + expected string + }{ + {topics: "", expected: ""}, + {topics: " ", expected: ""}, + {topics: "share", expected: " share "}, + {topics: "share incident", expected: " share incident "}, + {topics: " share incident ", expected: " share incident "}, + {topics: " share incident ", expected: " share incident "}, + } + + for _, tt := range testData { + rc := &RemoteCluster{Topics: tt.topics} + rc.fixTopics() + assert.Equal(t, tt.expected, rc.Topics) + } +} + +func TestRemoteClusterInviteEncryption(t *testing.T) { + testData := []struct { + name string + badDecrypt bool + password string + invite RemoteClusterInvite + }{ + {name: "empty password", badDecrypt: false, password: "", invite: RemoteClusterInvite{RemoteId: NewId(), SiteURL: "https://example.com:8065", Token: NewId()}}, + {name: "good password", badDecrypt: false, password: "Ultra secret password!", invite: RemoteClusterInvite{RemoteId: NewId(), SiteURL: "https://example.com:8065", Token: NewId()}}, + {name: "bad decrypt", badDecrypt: true, password: "correct horse battery staple", invite: RemoteClusterInvite{RemoteId: NewId(), SiteURL: "https://example.com:8065", Token: NewId()}}, + } + + for _, tt := range testData { + encrypted, err := tt.invite.Encrypt(tt.password) + require.NoError(t, err) + + invite := RemoteClusterInvite{} + if tt.badDecrypt { + buf := make([]byte, len(encrypted)) + _, err = io.ReadFull(rand.Reader, buf) + assert.NoError(t, err) + + err = invite.Decrypt(buf, tt.password) + require.Error(t, err) + } else { + err = invite.Decrypt(encrypted, tt.password) + require.NoError(t, err) + assert.Equal(t, tt.invite, invite) + } + } +} diff --git a/model/session.go b/model/session.go index 9b0bd628ed7..334c7175013 100644 --- a/model/session.go +++ b/model/session.go @@ -26,6 +26,7 @@ const ( SESSION_PROP_IS_BOT_VALUE = "true" SESSION_TYPE_USER_ACCESS_TOKEN = "UserAccessToken" SESSION_TYPE_CLOUD_KEY = "CloudKey" + SESSION_TYPE_REMOTECLUSTER_TOKEN = "RemoteClusterToken" SESSION_PROP_IS_GUEST = "is_guest" SESSION_ACTIVITY_TIMEOUT = 1000 * 60 * 5 // 5 minutes SESSION_USER_ACCESS_TOKEN_EXPIRY = 100 * 365 // 100 years diff --git a/model/shared_channel.go b/model/shared_channel.go new file mode 100644 index 00000000000..f2278e66d00 --- /dev/null +++ b/model/shared_channel.go @@ -0,0 +1,267 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "encoding/json" + "io" + "net/http" + "unicode/utf8" +) + +// SharedChannel represents a channel that can be synchronized with a remote cluster. +// If "home" is true, then the shared channel is homed locally and "SharedChannelRemote" +// table contains the remote clusters that have been invited. +// If "home" is false, then the shared channel is homed remotely, and "RemoteId" +// field points to the remote cluster connection in "RemoteClusters" table. +type SharedChannel struct { + ChannelId string `json:"channel_id"` + TeamId string `json:"team_id"` + Home bool `json:"home"` + ReadOnly bool `json:"readonly"` + ShareName string `json:"share_name"` + ShareDisplayName string `json:"share_displayname"` + SharePurpose string `json:"share_purpose"` + ShareHeader string `json:"share_header"` + CreatorId string `json:"creator_id"` + CreateAt int64 `json:"create_at"` + UpdateAt int64 `json:"update_at"` + RemoteId string `json:"remote_id,omitempty"` // if not "home" + Type string `db:"-"` +} + +func (sc *SharedChannel) ToJson() string { + b, _ := json.Marshal(sc) + return string(b) +} + +func SharedChannelFromJson(data io.Reader) (*SharedChannel, error) { + var sc *SharedChannel + err := json.NewDecoder(data).Decode(&sc) + return sc, err +} + +func (sc *SharedChannel) IsValid() *AppError { + if !IsValidId(sc.ChannelId) { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.id.app_error", nil, "ChannelId="+sc.ChannelId, http.StatusBadRequest) + } + + if sc.Type != CHANNEL_DIRECT && !IsValidId(sc.TeamId) { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.id.app_error", nil, "TeamId="+sc.TeamId, http.StatusBadRequest) + } + + if sc.CreateAt == 0 { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.create_at.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if sc.UpdateAt == 0 { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.update_at.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if utf8.RuneCountInString(sc.ShareDisplayName) > CHANNEL_DISPLAY_NAME_MAX_RUNES { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.display_name.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if !IsValidChannelIdentifier(sc.ShareName) { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.2_or_more.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if utf8.RuneCountInString(sc.ShareHeader) > CHANNEL_HEADER_MAX_RUNES { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.header.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if utf8.RuneCountInString(sc.SharePurpose) > CHANNEL_PURPOSE_MAX_RUNES { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.purpose.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if !IsValidId(sc.CreatorId) { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.creator_id.app_error", nil, "CreatorId="+sc.CreatorId, http.StatusBadRequest) + } + + if !sc.Home { + if !IsValidId(sc.RemoteId) { + return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.id.app_error", nil, "RemoteId="+sc.RemoteId, http.StatusBadRequest) + } + } + return nil +} + +func (sc *SharedChannel) PreSave() { + sc.ShareName = SanitizeUnicode(sc.ShareName) + sc.ShareDisplayName = SanitizeUnicode(sc.ShareDisplayName) + + sc.CreateAt = GetMillis() + sc.UpdateAt = sc.CreateAt +} + +func (sc *SharedChannel) PreUpdate() { + sc.UpdateAt = GetMillis() + sc.ShareName = SanitizeUnicode(sc.ShareName) + sc.ShareDisplayName = SanitizeUnicode(sc.ShareDisplayName) +} + +// SharedChannelRemote represents a remote cluster that has been invited +// to a shared channel. +type SharedChannelRemote struct { + Id string `json:"id"` + ChannelId string `json:"channel_id"` + Description string `json:"description"` + CreatorId string `json:"creator_id"` + CreateAt int64 `json:"create_at"` + UpdateAt int64 `json:"update_at"` + IsInviteAccepted bool `json:"is_invite_accepted"` + IsInviteConfirmed bool `json:"is_invite_confirmed"` + RemoteId string `json:"remote_id"` + NextSyncAt int64 `json:"next_sync_at"` +} + +func (sc *SharedChannelRemote) ToJson() string { + b, _ := json.Marshal(sc) + return string(b) +} + +func SharedChannelRemoteFromJson(data io.Reader) (*SharedChannelRemote, error) { + var sc *SharedChannelRemote + err := json.NewDecoder(data).Decode(&sc) + return sc, err +} + +func (sc *SharedChannelRemote) IsValid() *AppError { + if !IsValidId(sc.Id) { + return NewAppError("SharedChannelRemote.IsValid", "model.channel.is_valid.id.app_error", nil, "Id="+sc.Id, http.StatusBadRequest) + } + + if !IsValidId(sc.ChannelId) { + return NewAppError("SharedChannelRemote.IsValid", "model.channel.is_valid.id.app_error", nil, "ChannelId="+sc.ChannelId, http.StatusBadRequest) + } + + if len(sc.Description) > 64 { + return NewAppError("SharedChannelRemote.IsValid", "model.channel.is_valid.description.app_error", nil, "description="+sc.Description, http.StatusBadRequest) + } + + if sc.CreateAt == 0 { + return NewAppError("SharedChannelRemote.IsValid", "model.channel.is_valid.create_at.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if sc.UpdateAt == 0 { + return NewAppError("SharedChannelRemote.IsValid", "model.channel.is_valid.update_at.app_error", nil, "id="+sc.ChannelId, http.StatusBadRequest) + } + + if !IsValidId(sc.CreatorId) { + return NewAppError("SharedChannelRemote.IsValid", "model.channel.is_valid.creator_id.app_error", nil, "id="+sc.CreatorId, http.StatusBadRequest) + } + return nil +} + +func (sc *SharedChannelRemote) PreSave() { + if sc.Id == "" { + sc.Id = NewId() + } + sc.CreateAt = GetMillis() + sc.UpdateAt = sc.CreateAt +} + +func (sc *SharedChannelRemote) PreUpdate() { + sc.UpdateAt = GetMillis() +} + +type SharedChannelRemoteStatus struct { + ChannelId string `json:"channel_id"` + DisplayName string `json:"display_name"` + SiteURL string `json:"site_url"` + LastPingAt int64 `json:"last_ping_at"` + NextSyncAt int64 `json:"next_sync_at"` + Description string `json:"description"` + ReadOnly bool `json:"readonly"` + IsInviteAccepted bool `json:"is_invite_accepted"` + Token string `json:"token"` +} + +// SharedChannelUser stores a lastSyncAt timestamp on behalf of a remote cluster for +// each user that has been synchronized. +type SharedChannelUser struct { + Id string `json:"id"` + UserId string `json:"user_id"` + RemoteId string `json:"remote_id"` + CreateAt int64 `json:"create_at"` + LastSyncAt int64 `json:"last_sync_at"` +} + +func (scu *SharedChannelUser) PreSave() { + scu.Id = NewId() + scu.CreateAt = GetMillis() +} + +func (scu *SharedChannelUser) IsValid() *AppError { + if !IsValidId(scu.Id) { + return NewAppError("SharedChannelUser.IsValid", "model.channel.is_valid.id.app_error", nil, "Id="+scu.Id, http.StatusBadRequest) + } + + if !IsValidId(scu.UserId) { + return NewAppError("SharedChannelUser.IsValid", "model.channel.is_valid.id.app_error", nil, "UserId="+scu.UserId, http.StatusBadRequest) + } + + if !IsValidId(scu.RemoteId) { + return NewAppError("SharedChannelUser.IsValid", "model.channel.is_valid.id.app_error", nil, "RemoteId="+scu.RemoteId, http.StatusBadRequest) + } + + if scu.CreateAt == 0 { + return NewAppError("SharedChannelUser.IsValid", "model.channel.is_valid.create_at.app_error", nil, "", http.StatusBadRequest) + } + return nil +} + +// SharedChannelAttachment stores a lastSyncAt timestamp on behalf of a remote cluster for +// each file attachment that has been synchronized. +type SharedChannelAttachment struct { + Id string `json:"id"` + FileId string `json:"file_id"` + RemoteId string `json:"remote_id"` + CreateAt int64 `json:"create_at"` + LastSyncAt int64 `json:"last_sync_at"` +} + +func (scf *SharedChannelAttachment) PreSave() { + if scf.Id == "" { + scf.Id = NewId() + } + if scf.CreateAt == 0 { + scf.CreateAt = GetMillis() + scf.LastSyncAt = scf.CreateAt + } else { + scf.LastSyncAt = GetMillis() + } +} + +func (scf *SharedChannelAttachment) IsValid() *AppError { + if !IsValidId(scf.Id) { + return NewAppError("SharedChannelAttachment.IsValid", "model.channel.is_valid.id.app_error", nil, "Id="+scf.Id, http.StatusBadRequest) + } + + if !IsValidId(scf.FileId) { + return NewAppError("SharedChannelAttachment.IsValid", "model.channel.is_valid.id.app_error", nil, "FileId="+scf.FileId, http.StatusBadRequest) + } + + if !IsValidId(scf.RemoteId) { + return NewAppError("SharedChannelAttachment.IsValid", "model.channel.is_valid.id.app_error", nil, "RemoteId="+scf.RemoteId, http.StatusBadRequest) + } + + if scf.CreateAt == 0 { + return NewAppError("SharedChannelAttachment.IsValid", "model.channel.is_valid.create_at.app_error", nil, "", http.StatusBadRequest) + } + return nil +} + +type SharedChannelFilterOpts struct { + TeamId string + CreatorId string + ExcludeHome bool + ExcludeRemote bool +} + +type SharedChannelRemoteFilterOpts struct { + ChannelId string + RemoteId string + InclUnconfirmed bool +} diff --git a/model/shared_channel_test.go b/model/shared_channel_test.go new file mode 100644 index 00000000000..d4664c0dda6 --- /dev/null +++ b/model/shared_channel_test.go @@ -0,0 +1,87 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSharedChannelJson(t *testing.T) { + o := SharedChannel{ChannelId: NewId(), ShareName: NewId()} + json := o.ToJson() + ro, err := SharedChannelFromJson(strings.NewReader(json)) + + require.NoError(t, err) + require.Equal(t, o.ChannelId, ro.ChannelId) + require.Equal(t, o.ShareName, ro.ShareName) +} + +func TestSharedChannelIsValid(t *testing.T) { + id := NewId() + now := GetMillis() + data := []struct { + name string + sc *SharedChannel + valid bool + }{ + {name: "Zero value", sc: &SharedChannel{}, valid: false}, + {name: "Missing team_id", sc: &SharedChannel{ChannelId: id}, valid: false}, + {name: "Missing create_at", sc: &SharedChannel{ChannelId: id, TeamId: id}, valid: false}, + {name: "Missing update_at", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now}, valid: false}, + {name: "Missing share_name", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now, UpdateAt: now}, valid: false}, + {name: "Invalid share_name", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now, UpdateAt: now, + ShareName: "@test@"}, valid: false}, + {name: "Too long share_name", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now, UpdateAt: now, + ShareName: strings.Repeat("01234567890", 100)}, valid: false}, + {name: "Missing creator_id", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now, UpdateAt: now, + ShareName: "test"}, valid: false}, + {name: "Missing remote_id", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now, UpdateAt: now, + ShareName: "test", CreatorId: id}, valid: false}, + {name: "Valid shared channel", sc: &SharedChannel{ChannelId: id, TeamId: id, CreateAt: now, UpdateAt: now, + ShareName: "test", CreatorId: id, RemoteId: id}, valid: true}, + } + + for _, item := range data { + err := item.sc.IsValid() + if item.valid { + assert.Nil(t, err, item.name) + } else { + assert.NotNil(t, err, item.name) + } + } +} + +func TestSharedChannelPreSave(t *testing.T) { + now := GetMillis() + + o := SharedChannel{ChannelId: NewId(), ShareName: "test"} + o.PreSave() + + require.GreaterOrEqual(t, o.CreateAt, now) + require.GreaterOrEqual(t, o.UpdateAt, now) +} + +func TestSharedChannelPreUpdate(t *testing.T) { + now := GetMillis() + + o := SharedChannel{ChannelId: NewId(), ShareName: "test"} + o.PreUpdate() + + require.GreaterOrEqual(t, o.UpdateAt, now) +} + +func TestSharedChannelRemoteJson(t *testing.T) { + o := SharedChannelRemote{Id: NewId(), ChannelId: NewId(), Description: "Test"} + json := o.ToJson() + ro, err := SharedChannelRemoteFromJson(strings.NewReader(json)) + + require.NoError(t, err) + require.Equal(t, o.Id, ro.Id) + require.Equal(t, o.ChannelId, ro.ChannelId) + require.Equal(t, o.Description, ro.Description) +} diff --git a/model/upload_session.go b/model/upload_session.go index 05994b9e935..c5f09083e2b 100644 --- a/model/upload_session.go +++ b/model/upload_session.go @@ -42,6 +42,10 @@ type UploadSession struct { // The amount of received data in bytes. If equal to FileSize it means the // upload has finished. FileOffset int64 `json:"file_offset"` + // Id of remote cluster if uploading for shared channel + RemoteId string `json:"remote_id"` + // Requested file id if uploading for shared channel + ReqFileId string `json:"req_file_id"` } // ToJson serializes the UploadSession into JSON and returns it as string. diff --git a/model/user.go b/model/user.go index 6d7bcf509f4..c81a4e3da15 100644 --- a/model/user.go +++ b/model/user.go @@ -90,6 +90,7 @@ type User struct { Timezone StringMap `json:"timezone"` MfaActive bool `json:"mfa_active,omitempty"` MfaSecret string `json:"mfa_secret,omitempty"` + RemoteId *string `json:"remote_id,omitempty"` LastActivityAt int64 `db:"-" json:"last_activity_at,omitempty"` IsBot bool `db:"-" json:"is_bot,omitempty"` BotDescription string `db:"-" json:"bot_description,omitempty"` @@ -124,6 +125,7 @@ type UserPatch struct { NotifyProps StringMap `json:"notify_props,omitempty"` Locale *string `json:"locale"` Timezone StringMap `json:"timezone"` + RemoteId *string `json:"remote_id"` } //msgp:ignore UserAuth @@ -512,6 +514,10 @@ func (u *User) Patch(patch *UserPatch) { if patch.Timezone != nil { u.Timezone = patch.Timezone } + + if patch.RemoteId != nil { + u.RemoteId = patch.RemoteId + } } // ToJson convert a User to a json string @@ -734,6 +740,11 @@ func (u *User) GetPreferredTimezone() string { return GetPreferredTimezone(u.Timezone) } +// IsRemote returns true if the user belongs to a remote cluster (has RemoteId). +func (u *User) IsRemote() bool { + return u.RemoteId != nil && *u.RemoteId != "" +} + func (u *User) ToPatch() *UserPatch { return &UserPatch{ Username: &u.Username, Password: &u.Password, diff --git a/model/user_serial_gen.go b/model/user_serial_gen.go index 7caa98f5f45..fb40b577b19 100644 --- a/model/user_serial_gen.go +++ b/model/user_serial_gen.go @@ -17,8 +17,8 @@ func (z *User) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err) return } - if zb0001 != 31 { - err = msgp.ArrayError{Wanted: 31, Got: zb0001} + if zb0001 != 32 { + err = msgp.ArrayError{Wanted: 32, Got: zb0001} return } z.Id, err = dc.ReadString() @@ -158,6 +158,23 @@ func (z *User) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "MfaSecret") return } + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err, "RemoteId") + return + } + z.RemoteId = nil + } else { + if z.RemoteId == nil { + z.RemoteId = new(string) + } + *z.RemoteId, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "RemoteId") + return + } + } z.LastActivityAt, err = dc.ReadInt64() if err != nil { err = msgp.WrapError(err, "LastActivityAt") @@ -193,8 +210,8 @@ func (z *User) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *User) EncodeMsg(en *msgp.Writer) (err error) { - // array header, size 31 - err = en.Append(0xdc, 0x0, 0x1f) + // array header, size 32 + err = en.Append(0xdc, 0x0, 0x20) if err != nil { return } @@ -330,6 +347,18 @@ func (z *User) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "MfaSecret") return } + if z.RemoteId == nil { + err = en.WriteNil() + if err != nil { + return + } + } else { + err = en.WriteString(*z.RemoteId) + if err != nil { + err = msgp.WrapError(err, "RemoteId") + return + } + } err = en.WriteInt64(z.LastActivityAt) if err != nil { err = msgp.WrapError(err, "LastActivityAt") @@ -366,8 +395,8 @@ func (z *User) EncodeMsg(en *msgp.Writer) (err error) { // MarshalMsg implements msgp.Marshaler func (z *User) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // array header, size 31 - o = append(o, 0xdc, 0x0, 0x1f) + // array header, size 32 + o = append(o, 0xdc, 0x0, 0x20) o = msgp.AppendString(o, z.Id) o = msgp.AppendInt64(o, z.CreateAt) o = msgp.AppendInt64(o, z.UpdateAt) @@ -409,6 +438,11 @@ func (z *User) MarshalMsg(b []byte) (o []byte, err error) { } o = msgp.AppendBool(o, z.MfaActive) o = msgp.AppendString(o, z.MfaSecret) + if z.RemoteId == nil { + o = msgp.AppendNil(o) + } else { + o = msgp.AppendString(o, *z.RemoteId) + } o = msgp.AppendInt64(o, z.LastActivityAt) o = msgp.AppendBool(o, z.IsBot) o = msgp.AppendString(o, z.BotDescription) @@ -426,8 +460,8 @@ func (z *User) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err) return } - if zb0001 != 31 { - err = msgp.ArrayError{Wanted: 31, Got: zb0001} + if zb0001 != 32 { + err = msgp.ArrayError{Wanted: 32, Got: zb0001} return } z.Id, bts, err = msgp.ReadStringBytes(bts) @@ -566,6 +600,22 @@ func (z *User) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "MfaSecret") return } + if msgp.IsNil(bts) { + bts, err = msgp.ReadNilBytes(bts) + if err != nil { + return + } + z.RemoteId = nil + } else { + if z.RemoteId == nil { + z.RemoteId = new(string) + } + *z.RemoteId, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "RemoteId") + return + } + } z.LastActivityAt, bts, err = msgp.ReadInt64Bytes(bts) if err != nil { err = msgp.WrapError(err, "LastActivityAt") @@ -608,7 +658,13 @@ func (z *User) Msgsize() (s int) { } else { s += msgp.StringPrefixSize + len(*z.AuthData) } - s += msgp.StringPrefixSize + len(z.AuthService) + msgp.StringPrefixSize + len(z.Email) + msgp.BoolSize + msgp.StringPrefixSize + len(z.Nickname) + msgp.StringPrefixSize + len(z.FirstName) + msgp.StringPrefixSize + len(z.LastName) + msgp.StringPrefixSize + len(z.Position) + msgp.StringPrefixSize + len(z.Roles) + msgp.BoolSize + z.Props.Msgsize() + z.NotifyProps.Msgsize() + msgp.Int64Size + msgp.Int64Size + msgp.IntSize + msgp.StringPrefixSize + len(z.Locale) + z.Timezone.Msgsize() + msgp.BoolSize + msgp.StringPrefixSize + len(z.MfaSecret) + msgp.Int64Size + msgp.BoolSize + msgp.StringPrefixSize + len(z.BotDescription) + msgp.Int64Size + msgp.StringPrefixSize + len(z.TermsOfServiceId) + msgp.Int64Size + s += msgp.StringPrefixSize + len(z.AuthService) + msgp.StringPrefixSize + len(z.Email) + msgp.BoolSize + msgp.StringPrefixSize + len(z.Nickname) + msgp.StringPrefixSize + len(z.FirstName) + msgp.StringPrefixSize + len(z.LastName) + msgp.StringPrefixSize + len(z.Position) + msgp.StringPrefixSize + len(z.Roles) + msgp.BoolSize + z.Props.Msgsize() + z.NotifyProps.Msgsize() + msgp.Int64Size + msgp.Int64Size + msgp.IntSize + msgp.StringPrefixSize + len(z.Locale) + z.Timezone.Msgsize() + msgp.BoolSize + msgp.StringPrefixSize + len(z.MfaSecret) + if z.RemoteId == nil { + s += msgp.NilSize + } else { + s += msgp.StringPrefixSize + len(*z.RemoteId) + } + s += msgp.Int64Size + msgp.BoolSize + msgp.StringPrefixSize + len(z.BotDescription) + msgp.Int64Size + msgp.StringPrefixSize + len(z.TermsOfServiceId) + msgp.Int64Size return } diff --git a/model/utils.go b/model/utils.go index 147e7dba561..4827fb5a9ea 100644 --- a/model/utils.go +++ b/model/utils.go @@ -195,6 +195,11 @@ func GetMillisForTime(thisTime time.Time) int64 { return thisTime.UnixNano() / int64(time.Millisecond) } +// GetTimeForMillis is a convenience method to get time.Time for milliseconds since epoch. +func GetTimeForMillis(millis int64) time.Time { + return time.Unix(0, millis*int64(time.Millisecond)) +} + // PadDateStringZeros is a convenience method to pad 2 digit date parts with zeros to meet ISO 8601 format func PadDateStringZeros(dateString string) string { parts := strings.Split(dateString, "-") diff --git a/model/utils_test.go b/model/utils_test.go index 9a7c819bf4e..1764a46a255 100644 --- a/model/utils_test.go +++ b/model/utils_test.go @@ -40,6 +40,14 @@ func TestGetMillisForTime(t *testing.T) { require.Equalf(t, thisTimeMillis, result, "millis are not the same: %d and %d", thisTimeMillis, result) } +func TestGetTimeForMillis(t *testing.T) { + thisTimeMillis := int64(1471219200000) + thisTime := time.Date(2016, time.August, 15, 0, 0, 0, 0, time.UTC) + + result := GetTimeForMillis(thisTimeMillis) + require.True(t, thisTime.Equal(result)) +} + func TestPadDateStringZeros(t *testing.T) { for _, testCase := range []struct { Name string diff --git a/services/remotecluster/error.go b/services/remotecluster/error.go new file mode 100644 index 00000000000..db693740a4a --- /dev/null +++ b/services/remotecluster/error.go @@ -0,0 +1,24 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import "fmt" + +type BufferFullError struct { + capacity int +} + +func NewBufferFullError(capacity int) BufferFullError { + return BufferFullError{ + capacity: capacity, + } +} + +func (e BufferFullError) Capacity() int { + return e.capacity +} + +func (e BufferFullError) Error() string { + return fmt.Sprintf("buffer capacity (%d) exceeded", e.capacity) +} diff --git a/services/remotecluster/invitation.go b/services/remotecluster/invitation.go new file mode 100644 index 00000000000..744928ae8ea --- /dev/null +++ b/services/remotecluster/invitation.go @@ -0,0 +1,82 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/mattermost/mattermost-server/v5/model" +) + +// AcceptInvitation is called when accepting an invitation to connect with a remote cluster. +func (rcs *Service) AcceptInvitation(invite *model.RemoteClusterInvite, name string, creatorId string, teamId string, siteURL string) (*model.RemoteCluster, error) { + rc := &model.RemoteCluster{ + RemoteId: invite.RemoteId, + RemoteTeamId: invite.RemoteTeamId, + DisplayName: name, + Token: model.NewId(), + RemoteToken: invite.Token, + SiteURL: invite.SiteURL, + CreatorId: creatorId, + } + + rcSaved, err := rcs.server.GetStore().RemoteCluster().Save(rc) + if err != nil { + return nil, err + } + + // confirm the invitation with the originating site + frame, err := makeConfirmFrame(rcSaved, teamId, siteURL) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("%s/%s", rcSaved.SiteURL, ConfirmInviteURL) + + resp, err := rcs.sendFrameToRemote(PingTimeout, rc, frame, url) + if err != nil { + rcs.server.GetStore().RemoteCluster().Delete(rcSaved.RemoteId) + return nil, err + } + + var response Response + err = json.Unmarshal(resp, &response) + if err != nil { + rcs.server.GetStore().RemoteCluster().Delete(rcSaved.RemoteId) + return nil, fmt.Errorf("invalid response from remote server: %w", err) + } + + if !response.IsSuccess() { + rcs.server.GetStore().RemoteCluster().Delete(rcSaved.RemoteId) + return nil, errors.New(response.Err) + } + + // issue the first ping right away. The goroutine will exit when ping completes or PingTimeout exceeded. + go rcs.pingRemote(rcSaved) + + return rcSaved, nil +} + +func makeConfirmFrame(rc *model.RemoteCluster, teamId string, siteURL string) (*model.RemoteClusterFrame, error) { + confirm := model.RemoteClusterInvite{ + RemoteId: rc.RemoteId, + RemoteTeamId: teamId, + SiteURL: siteURL, + Token: rc.Token, + } + confirmRaw, err := json.Marshal(confirm) + if err != nil { + return nil, err + } + + msg := model.NewRemoteClusterMsg(InvitationTopic, confirmRaw) + + frame := &model.RemoteClusterFrame{ + RemoteId: rc.RemoteId, + Msg: msg, + } + return frame, nil +} diff --git a/services/remotecluster/mocks_test.go b/services/remotecluster/mocks_test.go new file mode 100644 index 00000000000..5607828ff7f --- /dev/null +++ b/services/remotecluster/mocks_test.go @@ -0,0 +1,104 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "fmt" + "strings" + "testing" + + "go.uber.org/zap/zapcore" + + "github.com/mattermost/mattermost-server/v5/einterfaces" + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/plugin/plugintest/mock" + "github.com/mattermost/mattermost-server/v5/shared/mlog" + "github.com/mattermost/mattermost-server/v5/store" + "github.com/mattermost/mattermost-server/v5/store/storetest/mocks" +) + +type mockServer struct { + remotes []*model.RemoteCluster + logger *mockLogger +} + +func newMockServer(t *testing.T, remotes []*model.RemoteCluster) *mockServer { + return &mockServer{ + remotes: remotes, + logger: &mockLogger{t: t}, + } +} + +func (ms *mockServer) Config() *model.Config { return nil } +func (ms *mockServer) GetMetrics() einterfaces.MetricsInterface { return nil } +func (ms *mockServer) IsLeader() bool { return true } +func (ms *mockServer) AddClusterLeaderChangedListener(listener func()) string { return model.NewId() } +func (ms *mockServer) RemoveClusterLeaderChangedListener(id string) {} +func (ms *mockServer) GetLogger() mlog.LoggerIFace { + return ms.logger +} +func (ms *mockServer) GetStore() store.Store { + anyFilter := mock.MatchedBy(func(filter model.RemoteClusterQueryFilter) bool { + return true + }) + + remoteClusterStoreMock := &mocks.RemoteClusterStore{} + remoteClusterStoreMock.On("GetByTopic", "share").Return(ms.remotes, nil) + remoteClusterStoreMock.On("GetAll", anyFilter).Return(ms.remotes, nil) + + storeMock := &mocks.Store{} + storeMock.On("RemoteCluster").Return(remoteClusterStoreMock) + return storeMock +} + +type mockLogger struct { + t *testing.T +} + +func (ml *mockLogger) IsLevelEnabled(level mlog.LogLevel) bool { + return true +} +func (ml *mockLogger) Debug(s string, flds ...mlog.Field) { + ml.t.Log("debug", s, fieldsToStrings(flds)) +} +func (ml *mockLogger) Info(s string, flds ...mlog.Field) { + ml.t.Log("info", s, fieldsToStrings(flds)) +} +func (ml *mockLogger) Warn(s string, flds ...mlog.Field) { + ml.t.Log("warn", s, fieldsToStrings(flds)) +} +func (ml *mockLogger) Error(s string, flds ...mlog.Field) { + ml.t.Log("error", s, fieldsToStrings(flds)) +} +func (ml *mockLogger) Critical(s string, flds ...mlog.Field) { + ml.t.Log("crit", s, fieldsToStrings(flds)) +} +func (ml *mockLogger) Log(level mlog.LogLevel, s string, flds ...mlog.Field) { + ml.t.Log(level.Name, s, fieldsToStrings(flds)) +} +func (ml *mockLogger) LogM(levels []mlog.LogLevel, s string, flds ...mlog.Field) { + ml.t.Log(levelsToString(levels), s, fieldsToStrings(flds)) +} + +func levelsToString(levels []mlog.LogLevel) string { + sb := strings.Builder{} + for _, l := range levels { + sb.WriteString(l.Name) + sb.WriteString(",") + } + return sb.String() +} + +func fieldsToStrings(fields []mlog.Field) []string { + encoder := zapcore.NewMapObjectEncoder() + for _, zapField := range fields { + zapField.AddTo(encoder) + } + + var result []string + for k, v := range encoder.Fields { + result = append(result, fmt.Sprintf("%s:%v", k, v)) + } + return result +} diff --git a/services/remotecluster/ping.go b/services/remotecluster/ping.go new file mode 100644 index 00000000000..87d0d55d03f --- /dev/null +++ b/services/remotecluster/ping.go @@ -0,0 +1,174 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +// pingLoop periodically sends a ping to all remote clusters. +func (rcs *Service) pingLoop(done <-chan struct{}) { + pingChan := make(chan *model.RemoteCluster, MaxConcurrentSends*2) + + // create a thread pool to send pings concurrently to remotes. + for i := 0; i < MaxConcurrentSends; i++ { + go rcs.pingEmitter(pingChan, done) + } + + go rcs.pingGenerator(pingChan, done) +} + +func (rcs *Service) pingGenerator(pingChan chan *model.RemoteCluster, done <-chan struct{}) { + defer close(pingChan) + + for { + start := time.Now() + + // get all remotes, including any previously offline. + remotes, err := rcs.server.GetStore().RemoteCluster().GetAll(model.RemoteClusterQueryFilter{}) + if err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Ping remote cluster failed (could not get list of remotes)", mlog.Err(err)) + select { + case <-time.After(PingFreq): + continue + case <-done: + return + } + } + + for _, rc := range remotes { + if rc.SiteURL != "" { // filter out unconfirmed invites + pingChan <- rc + } + } + + // try to maintain frequency + elapsed := time.Since(start) + if elapsed < PingFreq { + sleep := time.Until(start.Add(PingFreq)) + select { + case <-time.After(sleep): + case <-done: + return + } + } + } +} + +// pingEmitter pulls Remotes from the ping queue (pingChan) and pings them. +// Pinging a remote cannot take longer than PingTimeoutMillis. +func (rcs *Service) pingEmitter(pingChan <-chan *model.RemoteCluster, done <-chan struct{}) { + for { + select { + case rc := <-pingChan: + if rc == nil { + return + } + + online := rc.IsOnline() + + if err := rcs.pingRemote(rc); err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceWarn, "Remote cluster ping failed", + mlog.String("remote", rc.DisplayName), + mlog.String("remoteId", rc.RemoteId), + mlog.Err(err), + ) + } + + if online != rc.IsOnline() { + if metrics := rcs.server.GetMetrics(); metrics != nil { + metrics.IncrementRemoteClusterConnStateChangeCounter(rc.RemoteId, rc.IsOnline()) + } + rcs.fireConnectionStateChgEvent(rc) + } + case <-done: + return + } + } +} + +// pingRemote make a synchronous ping to a remote cluster. Return is error if ping is +// unsuccessful and nil on success. +func (rcs *Service) pingRemote(rc *model.RemoteCluster) error { + frame, err := makePingFrame(rc) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", rc.SiteURL, PingURL) + + resp, err := rcs.sendFrameToRemote(PingTimeout, rc, frame, url) + if err != nil { + return err + } + + ping := model.RemoteClusterPing{} + err = json.Unmarshal(resp, &ping) + if err != nil { + return err + } + + if err := rcs.server.GetStore().RemoteCluster().SetLastPingAt(rc.RemoteId); err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Failed to update LastPingAt for remote cluster", + mlog.String("remote", rc.DisplayName), + mlog.String("remoteId", rc.RemoteId), + mlog.Err(err), + ) + } + rc.LastPingAt = model.GetMillis() + + if metrics := rcs.server.GetMetrics(); metrics != nil { + sentAt := time.Unix(0, ping.SentAt*int64(time.Millisecond)) + elapsed := time.Since(sentAt).Seconds() + metrics.ObserveRemoteClusterPingDuration(rc.RemoteId, elapsed) + + // we approximate clock skew between remotes. + skew := elapsed/2 - float64(ping.RecvAt-ping.SentAt)/1000 + metrics.ObserveRemoteClusterClockSkew(rc.RemoteId, skew) + } + + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceDebug, "Remote cluster ping", + mlog.String("remote", rc.DisplayName), + mlog.String("remoteId", rc.RemoteId), + mlog.Int64("SentAt", ping.SentAt), + mlog.Int64("RecvAt", ping.RecvAt), + mlog.Int64("Diff", ping.RecvAt-ping.SentAt), + ) + return nil +} + +func makePingFrame(rc *model.RemoteCluster) (*model.RemoteClusterFrame, error) { + ping := model.RemoteClusterPing{ + SentAt: model.GetMillis(), + } + pingRaw, err := json.Marshal(ping) + if err != nil { + return nil, err + } + + msg := model.NewRemoteClusterMsg(PingTopic, pingRaw) + + frame := &model.RemoteClusterFrame{ + RemoteId: rc.RemoteId, + Msg: msg, + } + return frame, nil +} + +func (rcs *Service) fireConnectionStateChgEvent(rc *model.RemoteCluster) { + rcs.mux.RLock() + listeners := make([]ConnectionStateListener, 0, len(rcs.connectionStateListeners)) + for _, l := range rcs.connectionStateListeners { + listeners = append(listeners, l) + } + rcs.mux.RUnlock() + + for _, l := range listeners { + l(rc, rc.IsOnline()) + } +} diff --git a/services/remotecluster/ping_test.go b/services/remotecluster/ping_test.go new file mode 100644 index 00000000000..e3c32291335 --- /dev/null +++ b/services/remotecluster/ping_test.go @@ -0,0 +1,133 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wiggin77/merror" + + "github.com/mattermost/mattermost-server/v5/model" +) + +const ( + Recent = 60000 +) + +func TestPing(t *testing.T) { + disablePing = false + + t.Run("No error", func(t *testing.T) { + var countWebReq int32 + merr := merror.New() + + wg := &sync.WaitGroup{} + wg.Add(NumRemotes) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() + defer w.WriteHeader(200) + atomic.AddInt32(&countWebReq, 1) + + frame, err := model.RemoteClusterFrameFromJSON(r.Body) + if err != nil { + merr.Append(err) + return + } + if len(frame.Msg.Payload) == 0 { + merr.Append(fmt.Errorf("Payload should not be empty; remote_id=%s", frame.RemoteId)) + return + } + + ping, err := model.RemoteClusterPingFromRawJSON(frame.Msg.Payload) + if err != nil { + merr.Append(err) + return + } + if !checkRecent(ping.SentAt, Recent) { + merr.Append(fmt.Errorf("timestamp out of range, got %d", ping.SentAt)) + return + } + if ping.RecvAt != 0 { + merr.Append(fmt.Errorf("timestamp should be 0, got %d", ping.RecvAt)) + return + } + })) + defer ts.Close() + + mockServer := newMockServer(t, makeRemoteClusters(NumRemotes, ts.URL)) + service, err := NewRemoteClusterService(mockServer) + require.NoError(t, err) + + err = service.Start() + require.NoError(t, err) + defer service.Shutdown() + + wg.Wait() + + assert.NoError(t, merr.ErrorOrNil()) + + assert.Equal(t, int32(NumRemotes), atomic.LoadInt32(&countWebReq)) + t.Log(fmt.Sprintf("%d web requests counted; %d expected", + atomic.LoadInt32(&countWebReq), NumRemotes)) + }) + + t.Run("HTTP errors", func(t *testing.T) { + var countWebReq int32 + merr := merror.New() + + wg := &sync.WaitGroup{} + wg.Add(NumRemotes) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() + atomic.AddInt32(&countWebReq, 1) + + frame, err := model.RemoteClusterFrameFromJSON(r.Body) + if err != nil { + merr.Append(err) + } + ping, err := model.RemoteClusterPingFromRawJSON(frame.Msg.Payload) + if err != nil { + merr.Append(err) + } + if !checkRecent(ping.SentAt, Recent) { + merr.Append(fmt.Errorf("timestamp out of range, got %d", ping.SentAt)) + } + if ping.RecvAt != 0 { + merr.Append(fmt.Errorf("timestamp should be 0, got %d", ping.RecvAt)) + } + w.WriteHeader(500) + })) + defer ts.Close() + + mockServer := newMockServer(t, makeRemoteClusters(NumRemotes, ts.URL)) + service, err := NewRemoteClusterService(mockServer) + require.NoError(t, err) + + err = service.Start() + require.NoError(t, err) + defer service.Shutdown() + + wg.Wait() + + assert.Nil(t, merr.ErrorOrNil()) + + assert.Equal(t, int32(NumRemotes), atomic.LoadInt32(&countWebReq)) + t.Log(fmt.Sprintf("%d web requests counted; %d expected", + atomic.LoadInt32(&countWebReq), NumRemotes)) + }) +} + +func checkRecent(millis int64, within int64) bool { + now := model.GetMillis() + return millis > now-within && millis < now+within +} diff --git a/services/remotecluster/recv.go b/services/remotecluster/recv.go new file mode 100644 index 00000000000..cafa7d206c7 --- /dev/null +++ b/services/remotecluster/recv.go @@ -0,0 +1,53 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "fmt" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +// ReceiveIncomingMsg is called by the Rest API layer, or websocket layer (future), when a Remote Cluster +// message is received. Here we route the message to any topic listeners. +// `rc` and `msg` cannot be nil. +func (rcs *Service) ReceiveIncomingMsg(rc *model.RemoteCluster, msg model.RemoteClusterMsg) Response { + rcs.mux.RLock() + defer rcs.mux.RUnlock() + + if metrics := rcs.server.GetMetrics(); metrics != nil { + metrics.IncrementRemoteClusterMsgReceivedCounter(rc.RemoteId) + } + + rcSanitized := *rc + rcSanitized.Token = "" + rcSanitized.RemoteToken = "" + + var response Response + response.Status = ResponseStatusOK + + listeners := rcs.getTopicListeners(msg.Topic) + + for _, l := range listeners { + if err := callback(l, msg, &rcSanitized, &response); err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Error from remote cluster message listener", + mlog.String("msgId", msg.Id), mlog.String("topic", msg.Topic), mlog.String("remote", rc.DisplayName), mlog.Err(err)) + + response.Status = ResponseStatusFail + response.Err = err.Error() + } + } + return response +} + +func callback(listener TopicListener, msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *Response) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + err = listener(msg, rc, resp) + return +} diff --git a/services/remotecluster/response.go b/services/remotecluster/response.go new file mode 100644 index 00000000000..04e97d7a601 --- /dev/null +++ b/services/remotecluster/response.go @@ -0,0 +1,30 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "encoding/json" +) + +// Response represents the bytes replied from a remote server when a message is sent. +type Response struct { + Status string `json:"status"` + Err string `json:"err"` + Payload json.RawMessage `json:"payload"` +} + +// IsSuccess returns true if the response status indicates success. +func (r *Response) IsSuccess() bool { + return r.Status == ResponseStatusOK +} + +// SetPayload serializes an arbitrary struct as a RawMessage. +func (r *Response) SetPayload(v interface{}) error { + raw, err := json.Marshal(v) + if err != nil { + return err + } + r.Payload = raw + return nil +} diff --git a/services/remotecluster/send.go b/services/remotecluster/send.go new file mode 100644 index 00000000000..df68f97b005 --- /dev/null +++ b/services/remotecluster/send.go @@ -0,0 +1,56 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "context" + "hash/fnv" +) + +// enqueueTask adds a task to one of the send channels based on remoteId. +// +// There are a number of send channels (`MaxConcurrentSends`) to allow for sending to multiple +// remotes concurrently, while preserving message order for each remote. +func (rcs *Service) enqueueTask(ctx context.Context, remoteId string, task interface{}) error { + if ctx == nil { + ctx = context.Background() + } + + h := hash(remoteId) + idx := h % uint32(len(rcs.send)) + + select { + case rcs.send[idx] <- task: + return nil + case <-ctx.Done(): + return NewBufferFullError(cap(rcs.send)) + } +} + +func hash(s string) uint32 { + h := fnv.New32a() + h.Write([]byte(s)) + return h.Sum32() +} + +// sendLoop is called by each goroutine created for the send pool and waits for sendTask's until the +// done channel is signalled. +// +// Each goroutine in the pool is assigned a specific channel, and tasks are placed in the +// channel corresponding to the remoteId. +func (rcs *Service) sendLoop(idx int, done chan struct{}) { + for { + select { + case t := <-rcs.send[idx]: + switch task := t.(type) { + case sendMsgTask: + rcs.sendMsg(task) + case sendFileTask: + rcs.sendFile(task) + } + case <-done: + return + } + } +} diff --git a/services/remotecluster/send_test.go b/services/remotecluster/send_test.go new file mode 100644 index 00000000000..e001c7615f0 --- /dev/null +++ b/services/remotecluster/send_test.go @@ -0,0 +1,200 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wiggin77/merror" + + "github.com/mattermost/mattermost-server/v5/model" +) + +const ( + TestTopics = " share incident " + TestTopic = "share" + NumRemotes = 50 + NoteContent = "Woot!!" +) + +type testPayload struct { + Note string `json:"note"` +} + +func TestBroadcastMsg(t *testing.T) { + msgId := model.NewId() + disablePing = true + + t.Run("No error", func(t *testing.T) { + var countCallbacks int32 + var countWebReq int32 + merr := merror.New() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + w.WriteHeader(200) + var resp Response + b, errMarshall := json.Marshal(&resp) + if errMarshall != nil { + merr.Append(errMarshall) + return + } + w.Write(b) + }() + + atomic.AddInt32(&countWebReq, 1) + + frame, appErr := model.RemoteClusterFrameFromJSON(r.Body) + if appErr != nil { + merr.Append(appErr) + return + } + if len(frame.Msg.Payload) == 0 { + merr.Append(fmt.Errorf("webrequest missing Msg.Payload")) + } + if msgId != frame.Msg.Id { + merr.Append(fmt.Errorf("webrequest msgId expected %s, got %s", msgId, frame.Msg.Id)) + return + } + + note := testPayload{} + err := json.Unmarshal(frame.Msg.Payload, ¬e) + if err != nil { + merr.Append(err) + return + } + if note.Note != NoteContent { + merr.Append(fmt.Errorf("webrequest payload expected %s, got %s", NoteContent, note.Note)) + return + } + })) + defer ts.Close() + + mockServer := newMockServer(t, makeRemoteClusters(NumRemotes, ts.URL)) + service, err := NewRemoteClusterService(mockServer) + require.NoError(t, err) + + err = service.Start() + require.NoError(t, err) + defer service.Shutdown() + + wg := &sync.WaitGroup{} + wg.Add(NumRemotes) + + msg := makeRemoteClusterMsg(msgId, NoteContent) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + err = service.BroadcastMsg(ctx, msg, func(msg model.RemoteClusterMsg, remote *model.RemoteCluster, resp *Response, err error) { + defer wg.Done() + atomic.AddInt32(&countCallbacks, 1) + + if err != nil { + merr.Append(err) + } + if msgId != msg.Id { + merr.Append(fmt.Errorf("result callback msgId expected %s, got %s", msgId, msg.Id)) + } + + var note testPayload + err2 := json.Unmarshal(msg.Payload, ¬e) + if err2 != nil { + merr.Append(fmt.Errorf("unmarshal payload error: %w", err2)) + return + } + if note.Note != NoteContent { + merr.Append(fmt.Errorf("compare payload failed: expected '%s', got '%s'", NoteContent, note)) + } + }) + assert.NoError(t, err) + + wg.Wait() + + assert.NoError(t, merr.ErrorOrNil()) + + assert.Equal(t, int32(NumRemotes), atomic.LoadInt32(&countCallbacks)) + assert.Equal(t, int32(NumRemotes), atomic.LoadInt32(&countWebReq)) + t.Log(fmt.Sprintf("%d callbacks counted; %d web requests counted; %d expected", + atomic.LoadInt32(&countCallbacks), atomic.LoadInt32(&countWebReq), NumRemotes)) + }) + + t.Run("HTTP error", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer ts.Close() + + mockServer := newMockServer(t, makeRemoteClusters(NumRemotes, ts.URL)) + service, err := NewRemoteClusterService(mockServer) + require.NoError(t, err) + + err = service.Start() + require.NoError(t, err) + defer service.Shutdown() + + msg := makeRemoteClusterMsg(msgId, NoteContent) + var countCallbacks int32 + var countErrors int32 + wg := &sync.WaitGroup{} + wg.Add(NumRemotes) + + err = service.BroadcastMsg(context.Background(), msg, func(msg model.RemoteClusterMsg, remote *model.RemoteCluster, resp *Response, err error) { + defer wg.Done() + atomic.AddInt32(&countCallbacks, 1) + if err != nil { + atomic.AddInt32(&countErrors, 1) + } + }) + assert.NoError(t, err) + + wg.Wait() + + assert.Equal(t, int32(NumRemotes), atomic.LoadInt32(&countCallbacks)) + assert.Equal(t, int32(NumRemotes), atomic.LoadInt32(&countErrors)) + }) +} + +func makeRemoteClusters(num int, siteURL string) []*model.RemoteCluster { + var remotes []*model.RemoteCluster + for i := 0; i < num; i++ { + rc := makeRemoteCluster(fmt.Sprintf("test cluster %d", i+1), siteURL, TestTopics) + remotes = append(remotes, rc) + } + return remotes +} + +func makeRemoteCluster(name string, siteURL string, topics string) *model.RemoteCluster { + return &model.RemoteCluster{ + RemoteId: model.NewId(), + DisplayName: name, + SiteURL: siteURL, + Token: model.NewId(), + Topics: topics, + CreateAt: model.GetMillis(), + LastPingAt: model.GetMillis(), + CreatorId: model.NewId(), + } +} + +func makeRemoteClusterMsg(id string, note string) model.RemoteClusterMsg { + payload := testPayload{Note: note} + raw, _ := json.Marshal(payload) + + return model.RemoteClusterMsg{ + Id: id, + Topic: TestTopic, + CreateAt: model.GetMillis(), + Payload: raw} +} diff --git a/services/remotecluster/sendfile.go b/services/remotecluster/sendfile.go new file mode 100644 index 00000000000..7d13e2ed7bf --- /dev/null +++ b/services/remotecluster/sendfile.go @@ -0,0 +1,147 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "path" + "time" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/filestore" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +type SendFileResultFunc func(us *model.UploadSession, rc *model.RemoteCluster, resp *Response, err error) + +type sendFileTask struct { + rc *model.RemoteCluster + us *model.UploadSession + fi *model.FileInfo + rp ReaderProvider + f SendFileResultFunc +} + +type ReaderProvider interface { + FileReader(path string) (filestore.ReadCloseSeeker, *model.AppError) +} + +// SendFile asynchronously sends a file to a remote cluster. +// +// `ctx` determines behaviour when the outbound queue is full. A timeout or deadline context will return a +// BufferFullError if the file cannot be enqueued before the timeout. A background context will block indefinitely. +// +// Nil or error return indicates success or failure of file enqueue only. +// +// An optional callback can be provided that receives the response from the remote cluster. The `err` provided to the +// callback is regarding file delivery only. The `resp` contains the decoded bytes returned from the remote. +// If a callback is provided it should return quickly. +func (rcs *Service) SendFile(ctx context.Context, us *model.UploadSession, fi *model.FileInfo, rc *model.RemoteCluster, rp ReaderProvider, f SendFileResultFunc) error { + task := sendFileTask{ + rc: rc, + us: us, + fi: fi, + rp: rp, + f: f, + } + return rcs.enqueueTask(ctx, rc.RemoteId, task) +} + +// sendFile is called when a sendFileTask is popped from the send channel. +func (rcs *Service) sendFile(task sendFileTask) { + // Ensure a panic from the callback does not exit the goroutine. + defer func() { + if r := recover(); r != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Remote Cluster sendFile panic", + mlog.String("remote", task.rc.DisplayName), + mlog.String("uploadId", task.us.Id), + mlog.Any("panic", r), + ) + } + }() + + fi, err := rcs.sendFileToRemote(SendTimeout, task) + var response Response + + if err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Remote Cluster send file failed", + mlog.String("remote", task.rc.DisplayName), + mlog.String("uploadId", task.us.Id), + mlog.Err(err), + ) + response.Status = ResponseStatusFail + response.Err = err.Error() + } else { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceDebug, "Remote Cluster file sent successfully", + mlog.String("remote", task.rc.DisplayName), + mlog.String("uploadId", task.us.Id), + ) + response.Status = ResponseStatusOK + response.SetPayload(fi) + } + + // If callback provided then call it with the results. + if task.f != nil { + task.f(task.us, task.rc, &response, err) + } +} + +func (rcs *Service) sendFileToRemote(timeout time.Duration, task sendFileTask) (*model.FileInfo, error) { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceDebug, "sending file to remote...", + mlog.String("remote", task.rc.DisplayName), + mlog.String("uploadId", task.us.Id), + mlog.String("file_path", task.us.Path), + ) + + r, appErr := task.rp.FileReader(task.fi.Path) // get Reader for the file + if appErr != nil { + return nil, fmt.Errorf("error opening file while sending file to remote %s: %w", task.rc.RemoteId, appErr) + } + defer r.Close() + + u, err := url.Parse(task.rc.SiteURL) + if err != nil { + return nil, fmt.Errorf("invalid siteURL while sending file to remote %s: %w", task.rc.RemoteId, err) + } + u.Path = path.Join(u.Path, model.API_URL_SUFFIX, "remotecluster", "upload", task.us.Id) + + req, err := http.NewRequest("POST", u.String(), r) + if err != nil { + return nil, err + } + + req.Header.Set(model.HEADER_REMOTECLUSTER_ID, task.rc.RemoteId) + req.Header.Set(model.HEADER_REMOTECLUSTER_TOKEN, task.rc.RemoteToken) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + resp, err := rcs.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response: %d - %s", resp.StatusCode, resp.Status) + } + + // body should be a FileInfo + var fi model.FileInfo + if err := json.Unmarshal(body, &fi); err != nil { + return nil, fmt.Errorf("unexpected response body: %w", err) + } + + return &fi, nil +} diff --git a/services/remotecluster/sendmsg.go b/services/remotecluster/sendmsg.go new file mode 100644 index 00000000000..7e6f5067d70 --- /dev/null +++ b/services/remotecluster/sendmsg.go @@ -0,0 +1,180 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "path" + "time" + + "github.com/wiggin77/merror" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +type SendMsgResultFunc func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *Response, err error) + +type sendMsgTask struct { + rc *model.RemoteCluster + msg model.RemoteClusterMsg + f SendMsgResultFunc +} + +// BroadcastMsg asynchronously sends a message to all remote clusters interested in the message's topic. +// +// `ctx` determines behaviour when the outbound queue is full. A timeout or deadline context will return a +// BufferFullError if the message cannot be enqueued before the timeout. A background context will block indefinitely. +// +// An optional callback can be provided that receives the success or fail result of sending to each remote cluster. +// Success or fail is regarding message delivery only. If a callback is provided it should return quickly. +func (rcs *Service) BroadcastMsg(ctx context.Context, msg model.RemoteClusterMsg, f SendMsgResultFunc) error { + // get list of interested remotes. + filter := model.RemoteClusterQueryFilter{ + Topic: msg.Topic, + } + list, err := rcs.server.GetStore().RemoteCluster().GetAll(filter) + if err != nil { + return err + } + + errs := merror.New() + + for _, rc := range list { + if err := rcs.SendMsg(ctx, msg, rc, f); err != nil { + errs.Append(err) + } + } + return errs.ErrorOrNil() +} + +// SendMsg asynchronously sends a message to a remote cluster. +// +// `ctx` determines behaviour when the outbound queue is full. A timeout or deadline context will return a +// BufferFullError if the message cannot be enqueued before the timeout. A background context will block indefinitely. +// +// Nil or error return indicates success or failure of message enqueue only. +// +// An optional callback can be provided that receives the response from the remote cluster. The `err` provided to the +// callback is regarding response decoding only. The `resp` contains the decoded bytes returned from the remote. +// If a callback is provided it should return quickly. +func (rcs *Service) SendMsg(ctx context.Context, msg model.RemoteClusterMsg, rc *model.RemoteCluster, f SendMsgResultFunc) error { + task := sendMsgTask{ + rc: rc, + msg: msg, + f: f, + } + return rcs.enqueueTask(ctx, rc.RemoteId, task) +} + +// sendMsg is called when a sendMsgTask is popped from the send channel. +func (rcs *Service) sendMsg(task sendMsgTask) { + var errResp error + var response Response + + // Ensure a panic from the callback does not exit the pool goroutine. + defer func() { + if r := recover(); r != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Remote Cluster sendMsg panic", + mlog.String("remote", task.rc.DisplayName), mlog.String("msgId", task.msg.Id), mlog.Any("panic", r)) + } + + if errResp != nil { + response.Err = errResp.Error() + } + + // If callback provided then call it with the results. + if task.f != nil { + task.f(task.msg, task.rc, &response, errResp) + } + }() + + frame := &model.RemoteClusterFrame{ + RemoteId: task.rc.RemoteId, + Msg: task.msg, + } + + u, err := url.Parse(task.rc.SiteURL) + if err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Invalid siteURL while sending message to remote", + mlog.String("remote", task.rc.DisplayName), + mlog.String("msgId", task.msg.Id), + mlog.Err(err), + ) + errResp = err + return + } + u.Path = path.Join(u.Path, SendMsgURL) + + respJSON, err := rcs.sendFrameToRemote(SendTimeout, task.rc, frame, u.String()) + + if err != nil { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceError, "Remote Cluster send message failed", + mlog.String("remote", task.rc.DisplayName), + mlog.String("msgId", task.msg.Id), + mlog.Err(err), + ) + errResp = err + } else { + rcs.server.GetLogger().Log(mlog.LvlRemoteClusterServiceDebug, "Remote Cluster message sent successfully", + mlog.String("remote", task.rc.DisplayName), + mlog.String("msgId", task.msg.Id), + ) + + if err = json.Unmarshal(respJSON, &response); err != nil { + rcs.server.GetLogger().Error("Invalid response sending message to remote cluster", + mlog.String("remote", task.rc.DisplayName), + mlog.Err(err), + ) + errResp = err + } + } +} + +func (rcs *Service) sendFrameToRemote(timeout time.Duration, rc *model.RemoteCluster, frame *model.RemoteClusterFrame, url string) ([]byte, error) { + body, err := json.Marshal(frame) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := http.NewRequest("POST", url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set(model.HEADER_REMOTECLUSTER_ID, rc.RemoteId) + req.Header.Set(model.HEADER_REMOTECLUSTER_TOKEN, rc.RemoteToken) + + resp, err := rcs.httpClient.Do(req.WithContext(ctx)) + if metrics := rcs.server.GetMetrics(); metrics != nil { + if err != nil || resp.StatusCode != http.StatusOK { + metrics.IncrementRemoteClusterMsgErrorsCounter(frame.RemoteId, os.IsTimeout(err)) + } else { + metrics.IncrementRemoteClusterMsgSentCounter(frame.RemoteId) + } + } + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return body, fmt.Errorf("unexpected response: %d - %s", resp.StatusCode, resp.Status) + } + return body, nil +} diff --git a/services/remotecluster/service.go b/services/remotecluster/service.go new file mode 100644 index 00000000000..9824344d48e --- /dev/null +++ b/services/remotecluster/service.go @@ -0,0 +1,261 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "context" + "net" + "net/http" + "sync" + "time" + + "github.com/mattermost/mattermost-server/v5/einterfaces" + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/mlog" + "github.com/mattermost/mattermost-server/v5/store" +) + +const ( + SendChanBuffer = 50 + RecvChanBuffer = 50 + ResultsChanBuffer = 50 + ResultQueueDrainTimeoutMillis = 10000 + MaxConcurrentSends = 10 + SendMsgURL = "api/v4/remotecluster/msg" + SendTimeout = time.Minute + SendFileTimeout = time.Minute * 5 + PingURL = "api/v4/remotecluster/ping" + PingFreq = time.Minute + PingTimeout = time.Second * 15 + ConfirmInviteURL = "api/v4/remotecluster/confirm_invite" + InvitationTopic = "invitation" + PingTopic = "ping" + ResponseStatusOK = model.STATUS_OK + ResponseStatusFail = model.STATUS_FAIL + InviteExpiresAfter = time.Hour * 48 +) + +var ( + disablePing bool // override for testing +) + +type ServerIface interface { + Config() *model.Config + IsLeader() bool + AddClusterLeaderChangedListener(listener func()) string + RemoveClusterLeaderChangedListener(id string) + GetStore() store.Store + GetLogger() mlog.LoggerIFace + GetMetrics() einterfaces.MetricsInterface +} + +// RemoteClusterServiceIFace is used to allow mocking where a remote cluster service is used (for testing). +// Unfortunately it lives here because the shared channel service, app layer, and server interface all need it. +// Putting it in app layer means shared channel service must import app package. +type RemoteClusterServiceIFace interface { + Shutdown() error + Start() error + Active() bool + AddTopicListener(topic string, listener TopicListener) string + RemoveTopicListener(listenerId string) + AddConnectionStateListener(listener ConnectionStateListener) string + RemoveConnectionStateListener(listenerId string) + SendMsg(ctx context.Context, msg model.RemoteClusterMsg, rc *model.RemoteCluster, f SendMsgResultFunc) error + SendFile(ctx context.Context, us *model.UploadSession, fi *model.FileInfo, rc *model.RemoteCluster, rp ReaderProvider, f SendFileResultFunc) error + AcceptInvitation(invite *model.RemoteClusterInvite, name string, creatorId string, teamId string, siteURL string) (*model.RemoteCluster, error) + ReceiveIncomingMsg(rc *model.RemoteCluster, msg model.RemoteClusterMsg) Response +} + +// TopicListener is a callback signature used to listen for incoming messages for +// a specific topic. +type TopicListener func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *Response) error + +// ConnectionStateListener is used to listen to remote cluster connection state changes. +type ConnectionStateListener func(rc *model.RemoteCluster, online bool) + +// Service provides inter-cluster communication via topic based messages. +type Service struct { + server ServerIface + httpClient *http.Client + send []chan interface{} + + // everything below guarded by `mux` + mux sync.RWMutex + active bool + leaderListenerId string + topicListeners map[string]map[string]TopicListener // maps topic id to a map of listenerid->listener + connectionStateListeners map[string]ConnectionStateListener // maps listener id to listener + done chan struct{} +} + +// NewRemoteClusterService creates a RemoteClusterService instance. +func NewRemoteClusterService(server ServerIface) (*Service, error) { + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 200, + MaxIdleConnsPerHost: 2, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableCompression: false, + } + + client := &http.Client{ + Transport: transport, + Timeout: SendTimeout, + } + + service := &Service{ + server: server, + httpClient: client, + topicListeners: make(map[string]map[string]TopicListener), + connectionStateListeners: make(map[string]ConnectionStateListener), + } + + service.send = make([]chan interface{}, MaxConcurrentSends) + for i := range service.send { + service.send[i] = make(chan interface{}, SendChanBuffer) + } + + return service, nil +} + +// Start is called by the server on server start-up. +func (rcs *Service) Start() error { + rcs.mux.Lock() + rcs.leaderListenerId = rcs.server.AddClusterLeaderChangedListener(rcs.onClusterLeaderChange) + rcs.mux.Unlock() + + rcs.onClusterLeaderChange() + + return nil +} + +// Shutdown is called by the server on server shutdown. +func (rcs *Service) Shutdown() error { + rcs.server.RemoveClusterLeaderChangedListener(rcs.leaderListenerId) + rcs.pause() + return nil +} + +// Active returns true if this instance of the remote cluster service is active. +// The active instance is responsible for pinging and sending messages to remotes. +func (rcs *Service) Active() bool { + rcs.mux.Lock() + defer rcs.mux.Unlock() + return rcs.active +} + +// AddTopicListener registers a callback +func (rcs *Service) AddTopicListener(topic string, listener TopicListener) string { + rcs.mux.Lock() + defer rcs.mux.Unlock() + + id := model.NewId() + + listeners, ok := rcs.topicListeners[topic] + if !ok || listeners == nil { + rcs.topicListeners[topic] = make(map[string]TopicListener) + } + rcs.topicListeners[topic][id] = listener + return id +} + +func (rcs *Service) RemoveTopicListener(listenerId string) { + rcs.mux.Lock() + defer rcs.mux.Unlock() + + for topic, listeners := range rcs.topicListeners { + if _, ok := listeners[listenerId]; ok { + delete(listeners, listenerId) + if len(listeners) == 0 { + delete(rcs.topicListeners, topic) + } + break + } + } +} + +func (rcs *Service) getTopicListeners(topic string) []TopicListener { + rcs.mux.RLock() + defer rcs.mux.RUnlock() + + listeners, ok := rcs.topicListeners[topic] + if !ok { + return nil + } + + listenersCopy := make([]TopicListener, 0, len(listeners)) + for _, l := range listeners { + listenersCopy = append(listenersCopy, l) + } + return listenersCopy +} + +func (rcs *Service) AddConnectionStateListener(listener ConnectionStateListener) string { + id := model.NewId() + + rcs.mux.Lock() + defer rcs.mux.Unlock() + + rcs.connectionStateListeners[id] = listener + return id +} + +func (rcs *Service) RemoveConnectionStateListener(listenerId string) { + rcs.mux.Lock() + defer rcs.mux.Unlock() + delete(rcs.connectionStateListeners, listenerId) +} + +// onClusterLeaderChange is called whenever the cluster leader may have changed. +func (rcs *Service) onClusterLeaderChange() { + if rcs.server.IsLeader() { + rcs.resume() + } else { + rcs.pause() + } +} + +func (rcs *Service) resume() { + rcs.mux.Lock() + defer rcs.mux.Unlock() + + if rcs.active { + return // already active + } + rcs.active = true + rcs.done = make(chan struct{}) + + if !disablePing { + rcs.pingLoop(rcs.done) + } + + // create thread pool for concurrent message sending. + for i := range rcs.send { + go rcs.sendLoop(i, rcs.done) + } + + rcs.server.GetLogger().Debug("Remote Cluster Service active") +} + +func (rcs *Service) pause() { + rcs.mux.Lock() + defer rcs.mux.Unlock() + + if !rcs.active { + return // already inactive + } + rcs.active = false + close(rcs.done) + rcs.done = nil + + rcs.server.GetLogger().Debug("Remote Cluster Service inactive") +} diff --git a/services/remotecluster/service_test.go b/services/remotecluster/service_test.go new file mode 100644 index 00000000000..bc8e157fc2b --- /dev/null +++ b/services/remotecluster/service_test.go @@ -0,0 +1,71 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package remotecluster + +import ( + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" +) + +func TestService_AddTopicListener(t *testing.T) { + var count int32 + + l1 := func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *Response) error { + atomic.AddInt32(&count, 1) + return nil + } + l2 := func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *Response) error { + atomic.AddInt32(&count, 1) + return nil + } + l3 := func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *Response) error { + atomic.AddInt32(&count, 1) + return nil + } + + mockServer := newMockServer(t, makeRemoteClusters(NumRemotes, "")) + service, err := NewRemoteClusterService(mockServer) + require.NoError(t, err) + + l1id := service.AddTopicListener("test", l1) + l2id := service.AddTopicListener("test", l2) + l3id := service.AddTopicListener("different", l3) + + listeners := service.getTopicListeners("test") + assert.Len(t, listeners, 2) + + rc := &model.RemoteCluster{} + msg1 := model.RemoteClusterMsg{Topic: "test"} + msg2 := model.RemoteClusterMsg{Topic: "different"} + + service.ReceiveIncomingMsg(rc, msg1) + assert.Equal(t, int32(2), atomic.LoadInt32(&count)) + + service.ReceiveIncomingMsg(rc, msg2) + assert.Equal(t, int32(3), atomic.LoadInt32(&count)) + + service.RemoveTopicListener(l1id) + service.ReceiveIncomingMsg(rc, msg1) + assert.Equal(t, int32(4), atomic.LoadInt32(&count)) + + service.RemoveTopicListener(l2id) + service.ReceiveIncomingMsg(rc, msg1) + assert.Equal(t, int32(4), atomic.LoadInt32(&count)) + + service.ReceiveIncomingMsg(rc, msg2) + assert.Equal(t, int32(5), atomic.LoadInt32(&count)) + + service.RemoveTopicListener(l3id) + service.ReceiveIncomingMsg(rc, msg1) + service.ReceiveIncomingMsg(rc, msg2) + assert.Equal(t, int32(5), atomic.LoadInt32(&count)) + + listeners = service.getTopicListeners("test") + assert.Empty(t, listeners) +} diff --git a/services/sharedchannel/attachment.go b/services/sharedchannel/attachment.go new file mode 100644 index 00000000000..4bd2a3085dc --- /dev/null +++ b/services/sharedchannel/attachment.go @@ -0,0 +1,183 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +// postToAttachments returns the file attachments for a post that need to be synchronized. +func (scs *Service) postToAttachments(post *model.Post, rc *model.RemoteCluster) ([]*model.FileInfo, error) { + infos := make([]*model.FileInfo, 0) + + fis, err := scs.server.GetStore().FileInfo().GetForPost(post.Id, false, true, true) + if err != nil { + return nil, fmt.Errorf("could not get file info for attachment: %w", err) + } + + for _, fi := range fis { + if scs.shouldSyncAttachment(fi, rc) { + infos = append(infos, fi) + } + } + return infos, nil +} + +// postsToAttachments returns the file attachments for a slice of posts that need to be synchronized. +func (scs *Service) shouldSyncAttachment(fi *model.FileInfo, rc *model.RemoteCluster) bool { + sca, err := scs.server.GetStore().SharedChannel().GetAttachment(fi.Id, rc.RemoteId) + if err != nil { + if _, ok := err.(errNotFound); !ok { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error fetching shared channel attachment", + mlog.String("file_id", fi.Id), + mlog.String("remote_id", rc.RemoteId), + mlog.Err(err), + ) + } + // no record so sync is needed + return true + } + + return sca.LastSyncAt < fi.UpdateAt +} + +// sendAttachmentForRemote asynchronously sends a file attachment to a remote cluster. +func (scs *Service) sendAttachmentForRemote(fi *model.FileInfo, post *model.Post, rc *model.RemoteCluster) error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return fmt.Errorf("cannot update remote cluster for remote id %s; Remote Cluster Service not enabled", rc.RemoteId) + } + + us := &model.UploadSession{ + Id: model.NewId(), + Type: model.UploadTypeAttachment, + UserId: post.UserId, + ChannelId: post.ChannelId, + Filename: fi.Name, + FileSize: fi.Size, + RemoteId: rc.RemoteId, + ReqFileId: fi.Id, + } + + payload, err := json.Marshal(us) + if err != nil { + return err + } + + msg := model.NewRemoteClusterMsg(TopicUploadCreate, payload) + + ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) + defer cancel() + + var usResp model.UploadSession + var respErr error + var wg sync.WaitGroup + wg.Add(1) + + // creating the upload session on the remote server needs to be done synchronously. + err = rcs.SendMsg(ctx, msg, rc, func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { + defer wg.Done() + if err != nil { + respErr = err + return + } + if !resp.IsSuccess() { + respErr = errors.New(resp.Err) + return + } + respErr = json.Unmarshal(resp.Payload, &usResp) + }) + + if err != nil { + return fmt.Errorf("error sending create upload session to remote %s for post %s: %w", rc.RemoteId, post.Id, err) + } + + wg.Wait() + + if respErr != nil { + return fmt.Errorf("invalid create upload session response for remote %s and post %s: %w", rc.RemoteId, post.Id, respErr) + } + + ctx2, cancel2 := context.WithTimeout(context.Background(), remotecluster.SendFileTimeout) + defer cancel2() + + return rcs.SendFile(ctx2, &usResp, fi, rc, scs.app, func(us *model.UploadSession, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { + if err != nil { + return // this means the response could not be parsed; already logged + } + + if !resp.IsSuccess() { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "send file failed", + mlog.String("remote", rc.DisplayName), + mlog.String("uploadId", usResp.Id), + mlog.String("err", resp.Err), + ) + return + } + + // response payload should be a model.FileInfo. + var fi model.FileInfo + if err2 := json.Unmarshal(resp.Payload, &fi); err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "invalid file info response after send file", + mlog.String("remote", rc.DisplayName), + mlog.String("uploadId", usResp.Id), + mlog.Err(err2), + ) + return + } + + // save file attachment record in SharedChannelAttachments table + sca := &model.SharedChannelAttachment{ + FileId: fi.Id, + RemoteId: rc.RemoteId, + } + if _, err2 := scs.server.GetStore().SharedChannel().UpsertAttachment(sca); err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error saving SharedChannelAttachment", + mlog.String("remote", rc.DisplayName), + mlog.String("uploadId", usResp.Id), + mlog.Err(err2), + ) + return + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "send file successful", + mlog.String("remote", rc.DisplayName), + mlog.String("uploadId", usResp.Id), + ) + }) +} + +// onReceiveUploadCreate is called when a message requesting to create an upload session is received. An upload session is +// created and the id returned in the response. +func (scs *Service) onReceiveUploadCreate(msg model.RemoteClusterMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { + var us model.UploadSession + + if err := json.Unmarshal(msg.Payload, &us); err != nil { + return fmt.Errorf("invalid upload session request: %w", err) + } + + // make sure channel is shared for the remote sender + if _, err := scs.server.GetStore().SharedChannel().GetRemoteByIds(us.ChannelId, rc.RemoteId); err != nil { + return fmt.Errorf("could not validate upload session for remote: %w", err) + } + + us.RemoteId = rc.RemoteId // don't let remotes try to impersonate each other + + // create upload session. + usSaved, appErr := scs.app.CreateUploadSession(&us) + if appErr != nil { + return appErr + } + + response.SetPayload(usSaved) + return nil +} diff --git a/services/sharedchannel/channelinvite.go b/services/sharedchannel/channelinvite.go new file mode 100644 index 00000000000..1ec48d03455 --- /dev/null +++ b/services/sharedchannel/channelinvite.go @@ -0,0 +1,220 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +// channelInviteMsg represents an invitation for a remote cluster to start sharing a channel. +type channelInviteMsg struct { + ChannelId string `json:"channel_id"` + TeamId string `json:"team_id"` + ReadOnly bool `json:"read_only"` + Name string `json:"name"` + DisplayName string `json:"display_name"` + Header string `json:"header"` + Purpose string `json:"purpose"` + Type string `json:"type"` + DirectParticipantIDs []string `json:"direct_participant_ids"` +} + +type InviteOption func(msg *channelInviteMsg) + +func WithDirectParticipantID(participantID string) InviteOption { + return func(msg *channelInviteMsg) { + msg.DirectParticipantIDs = append(msg.DirectParticipantIDs, participantID) + } +} + +// SendChannelInvite asynchronously sends a channel invite to a remote cluster. The remote cluster is +// expected to create a new channel with the same channel id, and respond with status OK. +// If an error occurs on the remote cluster then an ephemeral message is posted to in the channel for userId. +func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, description string, rc *model.RemoteCluster, options ...InviteOption) error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return fmt.Errorf("cannot invite remote cluster for channel id %s; Remote Cluster Service not enabled", channel.Id) + } + + sc, err := scs.server.GetStore().SharedChannel().Get(channel.Id) + if err != nil { + return err + } + + invite := channelInviteMsg{ + ChannelId: channel.Id, + TeamId: rc.RemoteTeamId, + ReadOnly: sc.ReadOnly, + Name: sc.ShareName, + DisplayName: sc.ShareDisplayName, + Header: sc.ShareHeader, + Purpose: sc.SharePurpose, + Type: channel.Type, + } + + for _, option := range options { + option(&invite) + } + + json, err := json.Marshal(invite) + if err != nil { + return err + } + + msg := model.NewRemoteClusterMsg(TopicChannelInvite, json) + + ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) + defer cancel() + + return rcs.SendMsg(ctx, msg, rc, func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { + if err != nil || !resp.IsSuccess() { + scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error sending channel invite for %s: %s", rc.DisplayName, combineErrors(err, resp.Err))) + return + } + + scr := &model.SharedChannelRemote{ + ChannelId: sc.ChannelId, + Description: description, + CreatorId: userId, + RemoteId: rc.RemoteId, + IsInviteAccepted: true, + IsInviteConfirmed: true, + } + if _, err = scs.server.GetStore().SharedChannel().SaveRemote(scr); err != nil { + scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error confirming channel invite for %s: %v", rc.DisplayName, err)) + return + } + scs.NotifyChannelChanged(sc.ChannelId) + scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("`%s` has been added to channel.", rc.DisplayName)) + }) +} + +func combineErrors(err error, serror string) string { + var sb strings.Builder + if err != nil { + sb.WriteString(err.Error()) + } + if serror != "" { + if sb.Len() > 0 { + sb.WriteString("; ") + } + sb.WriteString(serror) + } + return sb.String() +} + +func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model.RemoteCluster, _ *remotecluster.Response) error { + if len(msg.Payload) == 0 { + return nil + } + + var invite channelInviteMsg + + if err := json.Unmarshal(msg.Payload, &invite); err != nil { + return fmt.Errorf("invalid channel invite: %w", err) + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Channel invite received", + mlog.String("remote", rc.DisplayName), + mlog.String("channel_id", invite.ChannelId), + mlog.String("channel_name", invite.Name), + mlog.String("team_id", invite.TeamId), + ) + + // create channel if it doesn't exist; the channel may already exist, such as if it was shared then unshared at some point. + channel, err := scs.server.GetStore().Channel().Get(invite.ChannelId, true) + if err != nil { + if channel, err = scs.handleChannelCreation(invite, rc); err != nil { + return err + } + } + + if invite.ReadOnly { + if err := scs.makeChannelReadOnly(channel); err != nil { + return fmt.Errorf("cannot make channel readonly `%s`: %w", invite.ChannelId, err) + } + } + + sharedChannel := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + Home: false, + ReadOnly: invite.ReadOnly, + ShareName: channel.Name, + ShareDisplayName: channel.DisplayName, + SharePurpose: channel.Purpose, + ShareHeader: channel.Header, + CreatorId: rc.CreatorId, + RemoteId: rc.RemoteId, + Type: channel.Type, + } + + if _, err := scs.server.GetStore().SharedChannel().Save(sharedChannel); err != nil { + scs.app.PermanentDeleteChannel(channel) + return fmt.Errorf("cannot create shared channel (channel_id=%s): %w", invite.ChannelId, err) + } + + sharedChannelRemote := &model.SharedChannelRemote{ + Id: model.NewId(), + ChannelId: channel.Id, + Description: invite.DisplayName, + CreatorId: channel.CreatorId, + IsInviteAccepted: true, + IsInviteConfirmed: true, + RemoteId: rc.RemoteId, + } + + if _, err := scs.server.GetStore().SharedChannel().SaveRemote(sharedChannelRemote); err != nil { + scs.app.PermanentDeleteChannel(channel) + scs.server.GetStore().SharedChannel().Delete(sharedChannel.ChannelId) + return fmt.Errorf("cannot create shared channel remote (channel_id=%s): %w", invite.ChannelId, err) + } + return nil +} + +func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) { + if invite.Type == model.CHANNEL_DIRECT { + return scs.createDirectChannel(invite) + } + + channelNew := &model.Channel{ + Id: invite.ChannelId, + TeamId: invite.TeamId, + Type: invite.Type, + DisplayName: invite.DisplayName, + Name: invite.Name, + Header: invite.Header, + Purpose: invite.Purpose, + CreatorId: rc.CreatorId, + Shared: model.NewBool(true), + } + + // check user perms? + channel, appErr := scs.app.CreateChannelWithUser(channelNew, rc.CreatorId) + if appErr != nil { + return nil, fmt.Errorf("cannot create channel `%s`: %w", invite.ChannelId, appErr) + } + + return channel, nil +} + +func (scs *Service) createDirectChannel(invite channelInviteMsg) (*model.Channel, error) { + if len(invite.DirectParticipantIDs) != 2 { + return nil, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs)) + } + + channel, err := scs.app.GetOrCreateDirectChannel(invite.DirectParticipantIDs[0], invite.DirectParticipantIDs[1], model.WithID(invite.ChannelId)) + if err != nil { + return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, err) + } + + return channel, nil +} diff --git a/services/sharedchannel/channelinvite_test.go b/services/sharedchannel/channelinvite_test.go new file mode 100644 index 00000000000..a2929744e98 --- /dev/null +++ b/services/sharedchannel/channelinvite_test.go @@ -0,0 +1,197 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/plugin/plugintest/mock" + "github.com/mattermost/mattermost-server/v5/shared/mlog" + "github.com/mattermost/mattermost-server/v5/store/storetest/mocks" +) + +type mockLogger struct { + mlog.LoggerIFace +} + +func (ml *mockLogger) Log(level mlog.LogLevel, s string, flds ...mlog.Field) {} + +func TestOnReceiveChannelInvite(t *testing.T) { + t.Run("when msg payload is empty, it does nothing", func(t *testing.T) { + mockServer := &MockServerIface{} + mockLogger := &mockLogger{} + mockServer.On("GetLogger").Return(mockLogger) + mockApp := &MockAppIface{} + scs := &Service{ + server: mockServer, + app: mockApp, + } + + mockStore := &mocks.Store{} + mockServer = scs.server.(*MockServerIface) + mockServer.On("GetStore").Return(mockStore) + + remoteCluster := &model.RemoteCluster{} + msg := model.RemoteClusterMsg{} + + err := scs.onReceiveChannelInvite(msg, remoteCluster, nil) + require.NoError(t, err) + mockStore.AssertNotCalled(t, "Channel") + }) + + t.Run("when invitation prescribes a readonly channel, it does create a readonly channel", func(t *testing.T) { + mockServer := &MockServerIface{} + mockLogger := &mockLogger{} + mockServer.On("GetLogger").Return(mockLogger) + mockApp := &MockAppIface{} + scs := &Service{ + server: mockServer, + app: mockApp, + } + + mockStore := &mocks.Store{} + remoteCluster := &model.RemoteCluster{DisplayName: "test"} + invitation := channelInviteMsg{ + ChannelId: model.NewId(), + TeamId: model.NewId(), + ReadOnly: true, + Type: "0", + } + payload, err := json.Marshal(invitation) + require.NoError(t, err) + + msg := model.RemoteClusterMsg{ + Payload: payload, + } + mockChannelStore := mocks.ChannelStore{} + mockSharedChannelStore := mocks.SharedChannelStore{} + channel := &model.Channel{} + + mockChannelStore.On("Get", invitation.ChannelId, true).Return(channel, nil) + mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil) + mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil) + mockStore.On("Channel").Return(&mockChannelStore) + mockStore.On("SharedChannel").Return(&mockSharedChannelStore) + + mockServer = scs.server.(*MockServerIface) + mockServer.On("GetStore").Return(mockStore) + createPostPermission := model.ChannelModeratedPermissionsMap[model.PERMISSION_CREATE_POST.Id] + createReactionPermission := model.ChannelModeratedPermissionsMap[model.PERMISSION_ADD_REACTION.Id] + updateMap := model.ChannelModeratedRolesPatch{ + Guests: model.NewBool(false), + Members: model.NewBool(false), + } + + readonlyChannelModerations := []*model.ChannelModerationPatch{ + { + Name: &createPostPermission, + Roles: &updateMap, + }, + { + Name: &createReactionPermission, + Roles: &updateMap, + }, + } + mockApp.On("PatchChannelModerationsForChannel", channel, readonlyChannelModerations).Return(nil, nil) + defer mockApp.AssertExpectations(t) + + err = scs.onReceiveChannelInvite(msg, remoteCluster, nil) + require.NoError(t, err) + }) + + t.Run("when invitation prescribes a readonly channel and readonly update fails, it returns an error", func(t *testing.T) { + mockServer := &MockServerIface{} + mockLogger := &mockLogger{} + mockServer.On("GetLogger").Return(mockLogger) + mockApp := &MockAppIface{} + scs := &Service{ + server: mockServer, + app: mockApp, + } + + mockStore := &mocks.Store{} + remoteCluster := &model.RemoteCluster{DisplayName: "test"} + invitation := channelInviteMsg{ + ChannelId: model.NewId(), + TeamId: model.NewId(), + ReadOnly: true, + Type: "0", + } + payload, err := json.Marshal(invitation) + require.NoError(t, err) + + msg := model.RemoteClusterMsg{ + Payload: payload, + } + mockChannelStore := mocks.ChannelStore{} + channel := &model.Channel{} + + mockChannelStore.On("Get", invitation.ChannelId, true).Return(channel, nil) + mockStore.On("Channel").Return(&mockChannelStore) + + mockServer = scs.server.(*MockServerIface) + mockServer.On("GetStore").Return(mockStore) + appErr := model.NewAppError("foo", "bar", nil, "boom", http.StatusBadRequest) + + mockApp.On("PatchChannelModerationsForChannel", channel, mock.Anything).Return(nil, appErr) + defer mockApp.AssertExpectations(t) + + err = scs.onReceiveChannelInvite(msg, remoteCluster, nil) + require.Error(t, err) + assert.Equal(t, fmt.Sprintf("cannot make channel readonly `%s`: foo: bar, boom", invitation.ChannelId), err.Error()) + }) + + t.Run("when invitation prescribes a direct channel, it does create a direct channel", func(t *testing.T) { + mockServer := &MockServerIface{} + mockLogger := &mockLogger{} + mockServer.On("GetLogger").Return(mockLogger) + mockApp := &MockAppIface{} + scs := &Service{ + server: mockServer, + app: mockApp, + } + + mockStore := &mocks.Store{} + remoteCluster := &model.RemoteCluster{DisplayName: "test", CreatorId: model.NewId()} + invitation := channelInviteMsg{ + ChannelId: model.NewId(), + TeamId: model.NewId(), + ReadOnly: false, + Type: model.CHANNEL_DIRECT, + DirectParticipantIDs: []string{model.NewId(), model.NewId()}, + } + payload, err := json.Marshal(invitation) + require.NoError(t, err) + + msg := model.RemoteClusterMsg{ + Payload: payload, + } + mockChannelStore := mocks.ChannelStore{} + mockSharedChannelStore := mocks.SharedChannelStore{} + channel := &model.Channel{} + + mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom")) + mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil) + mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil) + mockStore.On("Channel").Return(&mockChannelStore) + mockStore.On("SharedChannel").Return(&mockSharedChannelStore) + + mockServer = scs.server.(*MockServerIface) + mockServer.On("GetStore").Return(mockStore) + + mockApp.On("GetOrCreateDirectChannel", invitation.DirectParticipantIDs[0], invitation.DirectParticipantIDs[1], mock.AnythingOfType("model.ChannelOption")).Return(channel, nil) + defer mockApp.AssertExpectations(t) + + err = scs.onReceiveChannelInvite(msg, remoteCluster, nil) + require.NoError(t, err) + }) +} diff --git a/services/sharedchannel/getpostssince.go b/services/sharedchannel/getpostssince.go new file mode 100644 index 00000000000..26d36977fd3 --- /dev/null +++ b/services/sharedchannel/getpostssince.go @@ -0,0 +1,83 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +type sinceResult struct { + posts []*model.Post + hasMore bool + nextSince int64 +} + +// getPostsSince fetches posts that need to be synchronized with a remote cluster. +// There is a soft cap on the number of posts that will be synchronized in a single pass (MaxPostsPerSync). +// +// There is a special case where multiple posts have the same UpdateAt value. It is vital that this method +// include all posts within that millisecond so that subsequent calls can use an incremented `since`. If this +// method were to be called repeatedly with the same `since` value the same records would be returned each time +// and the sync would never move forward. +// +// A boolean is also returned to indicate if there are more posts to be synchronized (true) or not (false). +func (scs *Service) getPostsSince(channelId string, rc *model.RemoteCluster, since int64) (sinceResult, error) { + opts := model.GetPostsSinceForSyncOptions{ + ChannelId: channelId, + Since: since, + IncludeDeleted: true, + Limit: MaxPostsPerSync + 1, // ask for 1 more than needed to peek at first post in next batch + } + posts, err := scs.server.GetStore().Post().GetPostsSinceForSync(opts, true) + if err != nil { + return sinceResult{}, err + } + + if len(posts) == 0 { + return sinceResult{nextSince: since}, nil + } + + var hasMore bool + if len(posts) > MaxPostsPerSync { + hasMore = true + peekUpdateAt := posts[len(posts)-1].UpdateAt + posts = posts[:MaxPostsPerSync] // trim the peeked at record + + // If the last post to be synchronized has the same Update value as the first post in the next batch + // then we need to grab the rest of the posts for that millisecond to ensure the next call can have an + // incremented `since`. + if peekUpdateAt == posts[len(posts)-1].UpdateAt { + opts.Since = peekUpdateAt + opts.Until = opts.Since + opts.Limit = 1000 + opts.Offset = countPostsAtMillisecond(posts, peekUpdateAt) + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "getPostsSince handling updateAt collision", + mlog.String("remote", rc.DisplayName), + mlog.Int64("update_at", peekUpdateAt), + mlog.Int("offset", opts.Offset), + ) + + morePosts, err := scs.server.GetStore().Post().GetPostsSinceForSync(opts, true) + if err != nil { + return sinceResult{}, err + } + posts = append(posts, morePosts...) + } + } + return sinceResult{posts: posts, hasMore: hasMore, nextSince: posts[len(posts)-1].UpdateAt + 1}, nil +} + +func countPostsAtMillisecond(posts []*model.Post, milli int64) int { + // walk backward through the slice until we find a post with UpdateAt that differs from milli. + var count int + for i := len(posts) - 1; i >= 0; i-- { + if posts[i].UpdateAt != milli { + return count + } + count++ + } + return count +} diff --git a/services/sharedchannel/mock_AppIface_test.go b/services/sharedchannel/mock_AppIface_test.go new file mode 100644 index 00000000000..645a96e2d5f --- /dev/null +++ b/services/sharedchannel/mock_AppIface_test.go @@ -0,0 +1,338 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +// Regenerate this file using `make sharedchannel-mocks`. + +package sharedchannel + +import ( + filestore "github.com/mattermost/mattermost-server/v5/shared/filestore" + mock "github.com/stretchr/testify/mock" + + model "github.com/mattermost/mattermost-server/v5/model" +) + +// MockAppIface is an autogenerated mock type for the AppIface type +type MockAppIface struct { + mock.Mock +} + +// AddUserToChannel provides a mock function with given fields: user, channel +func (_m *MockAppIface) AddUserToChannel(user *model.User, channel *model.Channel) (*model.ChannelMember, *model.AppError) { + ret := _m.Called(user, channel) + + var r0 *model.ChannelMember + if rf, ok := ret.Get(0).(func(*model.User, *model.Channel) *model.ChannelMember); ok { + r0 = rf(user, channel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelMember) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.User, *model.Channel) *model.AppError); ok { + r1 = rf(user, channel) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// AddUserToTeamByTeamId provides a mock function with given fields: teamId, user +func (_m *MockAppIface) AddUserToTeamByTeamId(teamId string, user *model.User) *model.AppError { + ret := _m.Called(teamId, user) + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(string, *model.User) *model.AppError); ok { + r0 = rf(teamId, user) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + +// CreateChannelWithUser provides a mock function with given fields: channel, userId +func (_m *MockAppIface) CreateChannelWithUser(channel *model.Channel, userId string) (*model.Channel, *model.AppError) { + ret := _m.Called(channel, userId) + + var r0 *model.Channel + if rf, ok := ret.Get(0).(func(*model.Channel, string) *model.Channel); ok { + r0 = rf(channel, userId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Channel) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Channel, string) *model.AppError); ok { + r1 = rf(channel, userId) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// CreatePost provides a mock function with given fields: post, channel, triggerWebhooks, setOnline +func (_m *MockAppIface) CreatePost(post *model.Post, channel *model.Channel, triggerWebhooks bool, setOnline bool) (*model.Post, *model.AppError) { + ret := _m.Called(post, channel, triggerWebhooks, setOnline) + + var r0 *model.Post + if rf, ok := ret.Get(0).(func(*model.Post, *model.Channel, bool, bool) *model.Post); ok { + r0 = rf(post, channel, triggerWebhooks, setOnline) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Post) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Post, *model.Channel, bool, bool) *model.AppError); ok { + r1 = rf(post, channel, triggerWebhooks, setOnline) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// CreateUploadSession provides a mock function with given fields: us +func (_m *MockAppIface) CreateUploadSession(us *model.UploadSession) (*model.UploadSession, *model.AppError) { + ret := _m.Called(us) + + var r0 *model.UploadSession + if rf, ok := ret.Get(0).(func(*model.UploadSession) *model.UploadSession); ok { + r0 = rf(us) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.UploadSession) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.UploadSession) *model.AppError); ok { + r1 = rf(us) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// DeletePost provides a mock function with given fields: postID, deleteByID +func (_m *MockAppIface) DeletePost(postID string, deleteByID string) (*model.Post, *model.AppError) { + ret := _m.Called(postID, deleteByID) + + var r0 *model.Post + if rf, ok := ret.Get(0).(func(string, string) *model.Post); ok { + r0 = rf(postID, deleteByID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Post) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string, string) *model.AppError); ok { + r1 = rf(postID, deleteByID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// DeleteReactionForPost provides a mock function with given fields: reaction +func (_m *MockAppIface) DeleteReactionForPost(reaction *model.Reaction) *model.AppError { + ret := _m.Called(reaction) + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(*model.Reaction) *model.AppError); ok { + r0 = rf(reaction) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + +// FileReader provides a mock function with given fields: path +func (_m *MockAppIface) FileReader(path string) (filestore.ReadCloseSeeker, *model.AppError) { + ret := _m.Called(path) + + var r0 filestore.ReadCloseSeeker + if rf, ok := ret.Get(0).(func(string) filestore.ReadCloseSeeker); ok { + r0 = rf(path) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(filestore.ReadCloseSeeker) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string) *model.AppError); ok { + r1 = rf(path) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// GetOrCreateDirectChannel provides a mock function with given fields: userId, otherUserId, channelOptions +func (_m *MockAppIface) GetOrCreateDirectChannel(userId string, otherUserId string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) { + _va := make([]interface{}, len(channelOptions)) + for _i := range channelOptions { + _va[_i] = channelOptions[_i] + } + var _ca []interface{} + _ca = append(_ca, userId, otherUserId) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *model.Channel + if rf, ok := ret.Get(0).(func(string, string, ...model.ChannelOption) *model.Channel); ok { + r0 = rf(userId, otherUserId, channelOptions...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Channel) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string, string, ...model.ChannelOption) *model.AppError); ok { + r1 = rf(userId, otherUserId, channelOptions...) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// PatchChannelModerationsForChannel provides a mock function with given fields: channel, channelModerationsPatch +func (_m *MockAppIface) PatchChannelModerationsForChannel(channel *model.Channel, channelModerationsPatch []*model.ChannelModerationPatch) ([]*model.ChannelModeration, *model.AppError) { + ret := _m.Called(channel, channelModerationsPatch) + + var r0 []*model.ChannelModeration + if rf, ok := ret.Get(0).(func(*model.Channel, []*model.ChannelModerationPatch) []*model.ChannelModeration); ok { + r0 = rf(channel, channelModerationsPatch) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelModeration) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Channel, []*model.ChannelModerationPatch) *model.AppError); ok { + r1 = rf(channel, channelModerationsPatch) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// PermanentDeleteChannel provides a mock function with given fields: channel +func (_m *MockAppIface) PermanentDeleteChannel(channel *model.Channel) *model.AppError { + ret := _m.Called(channel) + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(*model.Channel) *model.AppError); ok { + r0 = rf(channel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + +// SaveReactionForPost provides a mock function with given fields: reaction +func (_m *MockAppIface) SaveReactionForPost(reaction *model.Reaction) (*model.Reaction, *model.AppError) { + ret := _m.Called(reaction) + + var r0 *model.Reaction + if rf, ok := ret.Get(0).(func(*model.Reaction) *model.Reaction); ok { + r0 = rf(reaction) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Reaction) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Reaction) *model.AppError); ok { + r1 = rf(reaction) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + +// SendEphemeralPost provides a mock function with given fields: userId, post +func (_m *MockAppIface) SendEphemeralPost(userId string, post *model.Post) *model.Post { + ret := _m.Called(userId, post) + + var r0 *model.Post + if rf, ok := ret.Get(0).(func(string, *model.Post) *model.Post); ok { + r0 = rf(userId, post) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Post) + } + } + + return r0 +} + +// UpdatePost provides a mock function with given fields: post, safeUpdate +func (_m *MockAppIface) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model.AppError) { + ret := _m.Called(post, safeUpdate) + + var r0 *model.Post + if rf, ok := ret.Get(0).(func(*model.Post, bool) *model.Post); ok { + r0 = rf(post, safeUpdate) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Post) + } + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Post, bool) *model.AppError); ok { + r1 = rf(post, safeUpdate) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} diff --git a/services/sharedchannel/mock_ServerIface_test.go b/services/sharedchannel/mock_ServerIface_test.go new file mode 100644 index 00000000000..a33159c116d --- /dev/null +++ b/services/sharedchannel/mock_ServerIface_test.go @@ -0,0 +1,118 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +// Regenerate this file using `make sharedchannel-mocks`. + +package sharedchannel + +import ( + mlog "github.com/mattermost/mattermost-server/v5/shared/mlog" + mock "github.com/stretchr/testify/mock" + + model "github.com/mattermost/mattermost-server/v5/model" + + remotecluster "github.com/mattermost/mattermost-server/v5/services/remotecluster" + + store "github.com/mattermost/mattermost-server/v5/store" +) + +// MockServerIface is an autogenerated mock type for the ServerIface type +type MockServerIface struct { + mock.Mock +} + +// AddClusterLeaderChangedListener provides a mock function with given fields: listener +func (_m *MockServerIface) AddClusterLeaderChangedListener(listener func()) string { + ret := _m.Called(listener) + + var r0 string + if rf, ok := ret.Get(0).(func(func()) string); ok { + r0 = rf(listener) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Config provides a mock function with given fields: +func (_m *MockServerIface) Config() *model.Config { + ret := _m.Called() + + var r0 *model.Config + if rf, ok := ret.Get(0).(func() *model.Config); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Config) + } + } + + return r0 +} + +// GetLogger provides a mock function with given fields: +func (_m *MockServerIface) GetLogger() mlog.LoggerIFace { + ret := _m.Called() + + var r0 mlog.LoggerIFace + if rf, ok := ret.Get(0).(func() mlog.LoggerIFace); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(mlog.LoggerIFace) + } + } + + return r0 +} + +// GetRemoteClusterService provides a mock function with given fields: +func (_m *MockServerIface) GetRemoteClusterService() remotecluster.RemoteClusterServiceIFace { + ret := _m.Called() + + var r0 remotecluster.RemoteClusterServiceIFace + if rf, ok := ret.Get(0).(func() remotecluster.RemoteClusterServiceIFace); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(remotecluster.RemoteClusterServiceIFace) + } + } + + return r0 +} + +// GetStore provides a mock function with given fields: +func (_m *MockServerIface) GetStore() store.Store { + ret := _m.Called() + + var r0 store.Store + if rf, ok := ret.Get(0).(func() store.Store); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.Store) + } + } + + return r0 +} + +// IsLeader provides a mock function with given fields: +func (_m *MockServerIface) IsLeader() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// RemoveClusterLeaderChangedListener provides a mock function with given fields: id +func (_m *MockServerIface) RemoveClusterLeaderChangedListener(id string) { + _m.Called(id) +} diff --git a/services/sharedchannel/msg.go b/services/sharedchannel/msg.go new file mode 100644 index 00000000000..09917d4e3b1 --- /dev/null +++ b/services/sharedchannel/msg.go @@ -0,0 +1,216 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "encoding/json" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +// syncMsg represents a change in content (post add/edit/delete, reaction add/remove, users). +// It is sent to remote clusters as the payload of a `RemoteClusterMsg`. +type syncMsg struct { + ChannelId string `json:"channel_id"` + PostId string `json:"post_id"` + Post *model.Post `json:"post"` + Users []*model.User `json:"users"` + Reactions []*model.Reaction `json:"reactions"` + Attachments []*model.FileInfo `json:"-"` +} + +func (sm syncMsg) ToJSON() ([]byte, error) { + b, err := json.Marshal(sm) + if err != nil { + return nil, err + } + return b, nil +} + +func (sm syncMsg) String() string { + json, err := sm.ToJSON() + if err != nil { + return "" + } + return string(json) +} + +type userCache map[string]struct{} + +func (u userCache) Has(id string) bool { + _, ok := u[id] + return ok +} + +func (u userCache) Add(id string) { + u[id] = struct{}{} +} + +// postsToSyncMessages takes a slice of posts and converts to a `RemoteClusterMsg` which can be +// sent to a remote cluster. +func (scs *Service) postsToSyncMessages(posts []*model.Post, rc *model.RemoteCluster, nextSyncAt int64) ([]syncMsg, error) { + syncMessages := make([]syncMsg, 0, len(posts)) + + uCache := make(userCache) + + for _, p := range posts { + if p.IsSystemMessage() { // don't sync system messages + continue + } + + // any reactions originating from the remote cluster are filtered out + reactions, err := scs.server.GetStore().Reaction().GetForPostSince(p.Id, nextSyncAt, rc.RemoteId, true) + if err != nil { + return nil, err + } + + postSync := p + + // Don't resend an existing post where only the reactions changed. + // Posts we must send: + // - new posts (EditAt == 0) + // - edited posts (EditAt >= nextSyncAt) + // - deleted posts (DeleteAt > 0) + if p.EditAt > 0 && p.EditAt < nextSyncAt && p.DeleteAt == 0 { + postSync = nil + } + + // Don't send a deleted post if it is just the original copy from an edit. + if p.DeleteAt > 0 && p.OriginalId != "" { + postSync = nil + } + + // don't sync a post back to the remote it came from. + if p.RemoteId != nil && *p.RemoteId == rc.RemoteId { + postSync = nil + } + + var attachments []*model.FileInfo + if postSync != nil { + // parse out all permalinks in the message. + postSync.Message = scs.processPermalinkToRemote(postSync) + + // get any file attachments + attachments, err = scs.postToAttachments(postSync, rc) + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Could not fetch attachments for post", + mlog.String("post_id", postSync.Id), + mlog.Err(err), + ) + } + } + + // any users originating from the remote cluster are filtered out + users := scs.usersForPost(postSync, reactions, rc, uCache) + + // if everything was filtered out then don't send an empty message. + if postSync == nil && len(reactions) == 0 && len(users) == 0 { + continue + } + + sm := syncMsg{ + ChannelId: p.ChannelId, + PostId: p.Id, + Post: postSync, + Users: users, + Reactions: reactions, + Attachments: attachments, + } + syncMessages = append(syncMessages, sm) + } + return syncMessages, nil +} + +// usersForPost provides a list of Users associated with the post that need to be synchronized. +// The user cache ensures the same user is not synchronized redundantly if they appear in multiple +// posts for this sync batch. +func (scs *Service) usersForPost(post *model.Post, reactions []*model.Reaction, rc *model.RemoteCluster, uCache userCache) []*model.User { + userIds := make([]string, 0) + + if post != nil && !uCache.Has(post.UserId) { + userIds = append(userIds, post.UserId) + uCache.Add(post.UserId) + } + + for _, r := range reactions { + if !uCache.Has(r.UserId) { + userIds = append(userIds, r.UserId) + uCache.Add(r.UserId) + } + } + + // TODO: extract @mentions to local users and sync those as well? + + users := make([]*model.User, 0) + + for _, id := range userIds { + user, err := scs.server.GetStore().User().Get(context.Background(), id) + if err == nil { + if sync, err2 := scs.shouldUserSync(user, rc); err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Could not find user for post", + mlog.String("user_id", id), + mlog.Err(err2)) + continue + } else if sync { + users = append(users, sanitizeUserForSync(user)) + } + } else { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error checking if user should sync", + mlog.String("user_id", id), + mlog.Err(err)) + } + } + return users +} + +func sanitizeUserForSync(user *model.User) *model.User { + user.Password = model.NewId() + user.AuthData = nil + user.AuthService = "" + user.Roles = "system_user" + user.AllowMarketing = false + user.Props = model.StringMap{} + user.NotifyProps = model.StringMap{} + user.LastPasswordUpdate = 0 + user.LastPictureUpdate = 0 + user.FailedAttempts = 0 + user.MfaActive = false + user.MfaSecret = "" + + return user +} + +// shouldUserSync determines if a user needs to be synchronized. +// User should be synchronized if it has no entry in the SharedChannelUsers table, +// or there is an entry but the LastSyncAt is less than user.UpdateAt +func (scs *Service) shouldUserSync(user *model.User, rc *model.RemoteCluster) (bool, error) { + // don't sync users with the remote they originated from. + if user.RemoteId != nil && *user.RemoteId == rc.RemoteId { + return false, nil + } + + scu, err := scs.server.GetStore().SharedChannel().GetUser(user.Id, rc.RemoteId) + if err != nil { + if _, ok := err.(errNotFound); !ok { + return false, err + } + + // user not in the SharedChannelUsers table, so we must add them. + scu = &model.SharedChannelUser{ + UserId: user.Id, + RemoteId: rc.RemoteId, + } + if _, err = scs.server.GetStore().SharedChannel().SaveUser(scu); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error adding user to shared channel users", + mlog.String("remote_id", rc.RemoteId), + mlog.String("user_id", user.Id), + ) + } + } else if scu.LastSyncAt >= user.UpdateAt { + return false, nil + } + return true, nil +} diff --git a/services/sharedchannel/permalink.go b/services/sharedchannel/permalink.go new file mode 100644 index 00000000000..163f077fcc2 --- /dev/null +++ b/services/sharedchannel/permalink.go @@ -0,0 +1,81 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "net/url" + "regexp" + "strings" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/shared/i18n" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +var ( + // Team name regex taken from model.IsValidTeamName + permaLinkRegex = regexp.MustCompile(`https?://[0-9.\-A-Za-z]+/[a-z0-9]+([a-z\-0-9]+|(__)?)[a-z0-9]+/pl/([a-zA-Z0-9]+)`) + permaLinkSharedRegex = regexp.MustCompile(`https?://[0-9.\-A-Za-z]+/[a-z0-9]+([a-z\-0-9]+|(__)?)[a-z0-9]+/plshared/([a-zA-Z0-9]+)`) +) + +const ( + permalinkMarker = "plshared" +) + +// processPermalinkToRemote processes all permalinks going towards a remote site. +func (scs *Service) processPermalinkToRemote(p *model.Post) string { + var sent bool + return permaLinkRegex.ReplaceAllStringFunc(p.Message, func(msg string) string { + // Extract the postID (This is simple enough not to warrant full-blown URL parsing.) + lastSlash := strings.LastIndexByte(msg, '/') + postID := msg[lastSlash+1:] + postList, err := scs.server.GetStore().Post().Get(context.Background(), postID, true, false, false, "") + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceWarn, "Unable to get post during replacing permalinks", mlog.Err(err)) + return msg + } + if len(postList.Order) == 0 { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceWarn, "No post found for permalink", mlog.String("postID", postID)) + return msg + } + + // If postID is for a different channel + if postList.Posts[postList.Order[0]].ChannelId != p.ChannelId { + // Send ephemeral message to OP (only once per message). + if !sent { + scs.sendEphemeralPost(p.ChannelId, p.UserId, i18n.T("sharedchannel.permalink.not_found")) + sent = true + } + // But don't modify msg + return msg + } + + // Otherwise, modify pl to plshared as a marker to be replaced by remote sites + return strings.Replace(msg, "/pl/", "/"+permalinkMarker+"/", 1) + }) +} + +// processPermalinkFromRemote processes all permalinks coming from a remote site. +func (scs *Service) processPermalinkFromRemote(p *model.Post, team *model.Team) string { + return permaLinkSharedRegex.ReplaceAllStringFunc(p.Message, func(remoteLink string) string { + // Extract host name + parsed, err := url.Parse(remoteLink) + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceWarn, "Unable to parse the remote link during replacing permalinks", mlog.Err(err)) + return remoteLink + } + + // Replace with local SiteURL + parsed.Scheme = scs.siteURL.Scheme + parsed.Host = scs.siteURL.Host + + // Replace team name with local team + teamEnd := strings.Index(parsed.Path, "/"+permalinkMarker) + parsed.Path = "/" + team.Name + parsed.Path[teamEnd:] + + // Replace plshared with pl + return strings.Replace(parsed.String(), "/"+permalinkMarker+"/", "/pl/", 1) + }) +} diff --git a/services/sharedchannel/permalink_test.go b/services/sharedchannel/permalink_test.go new file mode 100644 index 00000000000..95e3d90af5f --- /dev/null +++ b/services/sharedchannel/permalink_test.go @@ -0,0 +1,110 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/plugin/plugintest/mock" + "github.com/mattermost/mattermost-server/v5/store/storetest/mocks" + "github.com/mattermost/mattermost-server/v5/utils" +) + +func TestProcessPermalinkToRemote(t *testing.T) { + scs := &Service{ + server: &MockServerIface{}, + app: &MockAppIface{}, + } + + mockStore := &mocks.Store{} + mockPostStore := mocks.PostStore{} + utils.TranslationsPreInit() + + pl := &model.PostList{} + mockPostStore.On("Get", context.Background(), "postID", true, false, false, "").Return(pl, nil) + + mockStore.On("Post").Return(&mockPostStore) + + mockServer := scs.server.(*MockServerIface) + mockServer.On("GetStore").Return(mockStore) + + mockApp := scs.app.(*MockAppIface) + mockApp.On("SendEphemeralPost", "user", mock.AnythingOfType("*model.Post")).Return(&model.Post{}).Times(1) + defer mockApp.AssertExpectations(t) + + t.Run("same channel", func(t *testing.T) { + post := &model.Post{ + Message: "hello world https://comm.matt.com/team/pl/postID link", + ChannelId: "sourceChan", + UserId: "user", + } + + *pl = model.PostList{ + Order: []string{"1"}, + Posts: map[string]*model.Post{ + "1": { + ChannelId: "sourceChan", + UserId: "user", + }, + }, + } + + out := scs.processPermalinkToRemote(post) + assert.Equal(t, "hello world https://comm.matt.com/team/plshared/postID link", out) + }) + + t.Run("different channel", func(t *testing.T) { + post := &model.Post{ + Message: "hello world https://comm.matt.com/team/pl/postID link https://comm.matt.com/team/pl/postID ", + ChannelId: "sourceChan", + UserId: "user", + } + + *pl = model.PostList{ + Order: []string{"1"}, + Posts: map[string]*model.Post{ + "1": { + ChannelId: "otherChan", + }, + }, + } + out := scs.processPermalinkToRemote(post) + assert.Equal(t, "hello world https://comm.matt.com/team/pl/postID link https://comm.matt.com/team/pl/postID ", out) + }) +} + +func TestProcessPermalinkFromRemote(t *testing.T) { + t.Run("has match", func(t *testing.T) { + parsed, _ := url.Parse("http://mysite.com") + scs := &Service{ + server: &MockServerIface{}, + siteURL: parsed, + } + + out := scs.processPermalinkFromRemote(&model.Post{Message: "hello world https://comm.matt.com/team/plshared/postID link"}, + &model.Team{Name: "myteam"}) + assert.Equal(t, + "hello world http://mysite.com/myteam/pl/postID link", + out) + }) + + t.Run("does not match", func(t *testing.T) { + parsed, _ := url.Parse("http://mysite.com") + scs := &Service{ + server: &MockServerIface{}, + siteURL: parsed, + } + + out := scs.processPermalinkFromRemote(&model.Post{Message: "hello world https://comm.matt.com/team/pl/postID link"}, + &model.Team{Name: "myteam"}) + assert.Equal(t, + "hello world https://comm.matt.com/team/pl/postID link", + out) + }) +} diff --git a/services/sharedchannel/response.go b/services/sharedchannel/response.go new file mode 100644 index 00000000000..fc85ae2cfa2 --- /dev/null +++ b/services/sharedchannel/response.go @@ -0,0 +1,10 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +type SyncResponse struct { + LastSyncAt int64 `json:"last_sync_at"` + PostErrors []string `json:"post_errors"` + UsersSyncd []string `json:"users_syncd"` +} diff --git a/services/sharedchannel/service.go b/services/sharedchannel/service.go new file mode 100644 index 00000000000..9bb02edddb8 --- /dev/null +++ b/services/sharedchannel/service.go @@ -0,0 +1,239 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "errors" + "fmt" + "net/url" + "sync" + "time" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/shared/filestore" + "github.com/mattermost/mattermost-server/v5/shared/mlog" + "github.com/mattermost/mattermost-server/v5/store" +) + +const ( + TopicSync = "sharedchannel_sync" + TopicChannelInvite = "sharedchannel_invite" + TopicUploadCreate = "sharedchannel_upload" + MaxRetries = 3 + MaxPostsPerSync = 12 // a bit more than one typical screenfull of posts + NotifyRemoteOfflineThreshold = time.Second * 10 + NotifyMinimumDelay = time.Second * 2 +) + +// Mocks can be re-generated with `make sharedchannel-mocks`. +type ServerIface interface { + Config() *model.Config + IsLeader() bool + AddClusterLeaderChangedListener(listener func()) string + RemoveClusterLeaderChangedListener(id string) + GetStore() store.Store + GetLogger() mlog.LoggerIFace + GetRemoteClusterService() remotecluster.RemoteClusterServiceIFace +} + +type AppIface interface { + SendEphemeralPost(userId string, post *model.Post) *model.Post + CreateChannelWithUser(channel *model.Channel, userId string) (*model.Channel, *model.AppError) + GetOrCreateDirectChannel(userId, otherUserId string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError) + AddUserToChannel(user *model.User, channel *model.Channel) (*model.ChannelMember, *model.AppError) + AddUserToTeamByTeamId(teamId string, user *model.User) *model.AppError + PermanentDeleteChannel(channel *model.Channel) *model.AppError + CreatePost(post *model.Post, channel *model.Channel, triggerWebhooks bool, setOnline bool) (savedPost *model.Post, err *model.AppError) + UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model.AppError) + DeletePost(postID, deleteByID string) (*model.Post, *model.AppError) + SaveReactionForPost(reaction *model.Reaction) (*model.Reaction, *model.AppError) + DeleteReactionForPost(reaction *model.Reaction) *model.AppError + PatchChannelModerationsForChannel(channel *model.Channel, channelModerationsPatch []*model.ChannelModerationPatch) ([]*model.ChannelModeration, *model.AppError) + CreateUploadSession(us *model.UploadSession) (*model.UploadSession, *model.AppError) + FileReader(path string) (filestore.ReadCloseSeeker, *model.AppError) +} + +// errNotFound allows checking against Store.ErrNotFound errors without making Store a dependency. +type errNotFound interface { + IsErrNotFound() bool +} + +// errInvalidInput allows checking against Store.ErrInvalidInput errors without making Store a dependency. +type errInvalidInput interface { + InvalidInputInfo() (entity string, field string, value interface{}) +} + +// Service provides shared channel synchronization. +type Service struct { + server ServerIface + app AppIface + changeSignal chan struct{} + + // everything below guarded by `mux` + mux sync.RWMutex + active bool + leaderListenerId string + connectionStateListenerId string + done chan struct{} + tasks map[string]syncTask + syncTopicListenerId string + inviteTopicListenerId string + uploadTopicListenerId string + siteURL *url.URL +} + +// NewSharedChannelService creates a RemoteClusterService instance. +func NewSharedChannelService(server ServerIface, app AppIface) (*Service, error) { + service := &Service{ + server: server, + app: app, + changeSignal: make(chan struct{}, 1), + tasks: make(map[string]syncTask), + } + parsed, err := url.Parse(*server.Config().ServiceSettings.SiteURL) + if err != nil { + return nil, fmt.Errorf("unable to parse SiteURL: %w", err) + } + service.siteURL = parsed + return service, nil +} + +// Start is called by the server on server start-up. +func (scs *Service) Start() error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return errors.New("Shared Channel Service cannot activate: requires Remote Cluster Service") + } + + scs.mux.Lock() + scs.leaderListenerId = scs.server.AddClusterLeaderChangedListener(scs.onClusterLeaderChange) + scs.syncTopicListenerId = rcs.AddTopicListener(TopicSync, scs.onReceiveSyncMessage) + scs.inviteTopicListenerId = rcs.AddTopicListener(TopicChannelInvite, scs.onReceiveChannelInvite) + scs.uploadTopicListenerId = rcs.AddTopicListener(TopicUploadCreate, scs.onReceiveUploadCreate) + scs.connectionStateListenerId = rcs.AddConnectionStateListener(scs.onConnectionStateChange) + scs.mux.Unlock() + + scs.onClusterLeaderChange() + + return nil +} + +// Shutdown is called by the server on server shutdown. +func (scs *Service) Shutdown() error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return errors.New("Shared Channel Service cannot shutdown: requires Remote Cluster Service") + } + + scs.mux.Lock() + id := scs.leaderListenerId + rcs.RemoveTopicListener(scs.syncTopicListenerId) + scs.syncTopicListenerId = "" + rcs.RemoveTopicListener(scs.inviteTopicListenerId) + scs.inviteTopicListenerId = "" + rcs.RemoveConnectionStateListener(scs.connectionStateListenerId) + scs.connectionStateListenerId = "" + scs.mux.Unlock() + + scs.server.RemoveClusterLeaderChangedListener(id) + scs.pause() + return nil +} + +// Active determines whether the service is active on the node or not. +func (scs *Service) Active() bool { + scs.mux.Lock() + defer scs.mux.Unlock() + + return scs.active +} + +func (scs *Service) sendEphemeralPost(channelId string, userId string, text string) { + ephemeral := &model.Post{ + ChannelId: channelId, + Message: text, + CreateAt: model.GetMillis(), + } + scs.app.SendEphemeralPost(userId, ephemeral) +} + +// onClusterLeaderChange is called whenever the cluster leader may have changed. +func (scs *Service) onClusterLeaderChange() { + if scs.server.IsLeader() { + scs.resume() + } else { + scs.pause() + } +} + +func (scs *Service) resume() { + scs.mux.Lock() + defer scs.mux.Unlock() + + if scs.active { + return // already active + } + + scs.active = true + scs.done = make(chan struct{}) + + go scs.syncLoop(scs.done) + + scs.server.GetLogger().Debug("Shared Channel Service active") +} + +func (scs *Service) pause() { + scs.mux.Lock() + defer scs.mux.Unlock() + + if !scs.active { + return // already inactive + } + + scs.active = false + close(scs.done) + scs.done = nil + + scs.server.GetLogger().Debug("Shared Channel Service inactive") +} + +// Makes the remote channel to be read-only(announcement mode, only admins can create posts and reactions). +func (scs *Service) makeChannelReadOnly(channel *model.Channel) *model.AppError { + createPostPermission := model.ChannelModeratedPermissionsMap[model.PERMISSION_CREATE_POST.Id] + createReactionPermission := model.ChannelModeratedPermissionsMap[model.PERMISSION_ADD_REACTION.Id] + updateMap := model.ChannelModeratedRolesPatch{ + Guests: model.NewBool(false), + Members: model.NewBool(false), + } + + readonlyChannelModerations := []*model.ChannelModerationPatch{ + { + Name: &createPostPermission, + Roles: &updateMap, + }, + { + Name: &createReactionPermission, + Roles: &updateMap, + }, + } + + _, err := scs.app.PatchChannelModerationsForChannel(channel, readonlyChannelModerations) + return err +} + +// onConnectionStateChange is called whenever the connection state of a remote cluster changes, +// for example when one comes back online. +func (scs *Service) onConnectionStateChange(rc *model.RemoteCluster, online bool) { + if online { + // when a previously offline remote comes back online force a sync. + scs.ForceSyncForRemote(rc) + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Remote cluster connection status changed", + mlog.String("remote", rc.DisplayName), + mlog.String("remoteId", rc.RemoteId), + mlog.Bool("online", online), + ) +} diff --git a/services/sharedchannel/sync_recv.go b/services/sharedchannel/sync_recv.go new file mode 100644 index 00000000000..8e837849be8 --- /dev/null +++ b/services/sharedchannel/sync_recv.go @@ -0,0 +1,296 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +func (scs *Service) onReceiveSyncMessage(msg model.RemoteClusterMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { + if msg.Topic != TopicSync { + return fmt.Errorf("wrong topic, expected `%s`, got `%s`", TopicSync, msg.Topic) + } + + if len(msg.Payload) == 0 { + return errors.New("empty sync message") + } + + if scs.server.GetLogger().IsLevelEnabled(mlog.LvlSharedChannelServiceMessagesInbound) { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceMessagesInbound, "inbound message", + mlog.String("remote", rc.DisplayName), + mlog.String("msg", string(msg.Payload)), + ) + } + + var syncMessages []syncMsg + + if err := json.Unmarshal(msg.Payload, &syncMessages); err != nil { + return fmt.Errorf("invalid sync message: %w", err) + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Batch of sync messages received", + mlog.String("remote", rc.DisplayName), + mlog.Int("sync_msg_count", len(syncMessages)), + ) + + return scs.processSyncMessages(syncMessages, rc, response) +} + +func (scs *Service) processSyncMessages(syncMessages []syncMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { + var channel *model.Channel + var team *model.Team + + postErrors := make([]string, 0) + usersSyncd := make([]string, 0) + var lastSyncAt int64 + var err error + + for _, sm := range syncMessages { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Sync msg received", + mlog.String("post_id", sm.PostId), + mlog.String("channel_id", sm.ChannelId), + mlog.Int("reaction_count", len(sm.Reactions)), + mlog.Int("user_count", len(sm.Users)), + mlog.Bool("has_post", sm.Post != nil), + ) + + if channel == nil { + if channel, err = scs.server.GetStore().Channel().Get(sm.ChannelId, true); err != nil { + // if the channel doesn't exist then none of these sync messages are going to work. + return fmt.Errorf("channel not found processing sync messages: %w", err) + } + } + + // add/update users before posts + for _, user := range sm.Users { + if userSaved, err := scs.upsertSyncUser(user, channel, rc); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync user", + mlog.String("post_id", sm.PostId), + mlog.String("channel_id", sm.ChannelId), + mlog.String("user_id", user.Id), + mlog.Err(err)) + } else { + usersSyncd = append(usersSyncd, userSaved.Id) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "User upserted via sync", + mlog.String("post_id", sm.PostId), + mlog.String("channel_id", sm.ChannelId), + mlog.String("user_id", user.Id), + ) + } + } + + if sm.Post != nil { + if sm.ChannelId != sm.Post.ChannelId { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "ChannelId mismatch", + mlog.String("sm.ChannelId", sm.ChannelId), + mlog.String("sm.Post.ChannelId", sm.Post.ChannelId), + mlog.String("PostId", sm.Post.Id), + ) + postErrors = append(postErrors, sm.Post.Id) + continue + } + + if channel.Type != model.CHANNEL_DIRECT && team == nil { + var err2 error + team, err2 = scs.server.GetStore().Channel().GetTeamForChannel(sm.ChannelId) + if err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error getting Team for Channel", + mlog.String("ChannelId", sm.Post.ChannelId), + mlog.String("PostId", sm.Post.Id), + mlog.Err(err2), + ) + postErrors = append(postErrors, sm.Post.Id) + continue + } + } + + // process perma-links for remote + if team != nil { + sm.Post.Message = scs.processPermalinkFromRemote(sm.Post, team) + } + + // add/update post (may be nil if only reactions changed) + rpost, err := scs.upsertSyncPost(sm.Post, channel, rc) + if err != nil { + postErrors = append(postErrors, sm.Post.Id) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync post", + mlog.String("post_id", sm.Post.Id), + mlog.String("channel_id", sm.Post.ChannelId), + mlog.Err(err), + ) + } else if lastSyncAt < rpost.UpdateAt { + lastSyncAt = rpost.UpdateAt + } + } + + // add/remove reactions + for _, reaction := range sm.Reactions { + if _, err := scs.upsertSyncReaction(reaction, rc); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync reaction", + mlog.String("user_id", reaction.UserId), + mlog.String("post_id", reaction.PostId), + mlog.String("emoji", reaction.EmojiName), + mlog.Int64("delete_at", reaction.DeleteAt), + mlog.Err(err), + ) + } else { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Reaction upserted via sync", + mlog.String("user_id", reaction.UserId), + mlog.String("post_id", reaction.PostId), + mlog.String("emoji", reaction.EmojiName), + mlog.Int64("delete_at", reaction.DeleteAt), + ) + + if lastSyncAt < reaction.UpdateAt { + lastSyncAt = reaction.UpdateAt + } + } + } + } + + syncResp := SyncResponse{ + LastSyncAt: lastSyncAt, // might be zero + PostErrors: postErrors, // might be empty + UsersSyncd: usersSyncd, // might be empty + } + + response.SetPayload(syncResp) + + return nil +} + +func (scs *Service) upsertSyncUser(user *model.User, channel *model.Channel, rc *model.RemoteCluster) (*model.User, error) { + var err error + var userSaved *model.User + + user.RemoteId = model.NewString(rc.RemoteId) + + // does the user already exist? + euser, err := scs.server.GetStore().User().Get(context.Background(), user.Id) + if err != nil { + if _, ok := err.(errNotFound); !ok { + return nil, fmt.Errorf("error checking sync user: %w", err) + } + } + + if euser == nil { + if userSaved, err = scs.server.GetStore().User().Save(user); err != nil { + if e, ok := err.(errInvalidInput); ok { + _, field, value := e.InvalidInputInfo() + if field == "email" || field == "username" { + // username or email collision + // TODO: handle collision by modifying username/email (MM-32133) + return nil, fmt.Errorf("collision inserting sync user (%s=%s): %w", field, value, err) + } + } + return nil, fmt.Errorf("error inserting sync user: %w", err) + } + } else { + patch := &model.UserPatch{ + Nickname: &user.Nickname, + FirstName: &user.FirstName, + LastName: &user.LastName, + Position: &user.Position, + Locale: &user.Locale, + Timezone: user.Timezone, + RemoteId: user.RemoteId, + } + euser.Patch(patch) + userUpdated, err := scs.server.GetStore().User().Update(euser, false) + if err != nil { + return nil, fmt.Errorf("error updating sync user: %w", err) + } + userSaved = userUpdated.New + } + + // add user to team. We do this here regardless of whether the user was + // just created or patched since there are three steps to adding a user + // (insert rec, add to team, add to channel) and any one could fail. + // Instead of undoing what succeeded on any failure we simply do all steps each + // time. AddUserToChannel & AddUserToTeamByTeamId do not error if user already + // added and exit quickly. + if err := scs.app.AddUserToTeamByTeamId(channel.TeamId, userSaved); err != nil { + return nil, fmt.Errorf("error adding sync user to Team: %w", err) + } + + // add user to channel + if _, err := scs.app.AddUserToChannel(userSaved, channel); err != nil { + return nil, fmt.Errorf("error adding sync user to ChannelMembers: %w", err) + } + return userSaved, nil +} + +func (scs *Service) upsertSyncPost(post *model.Post, channel *model.Channel, rc *model.RemoteCluster) (*model.Post, error) { + var appErr *model.AppError + + post.RemoteId = model.NewString(rc.RemoteId) + + rpost, err := scs.server.GetStore().Post().GetSingle(post.Id, true) + if err != nil { + if _, ok := err.(errNotFound); !ok { + return nil, fmt.Errorf("error checking sync post: %w", err) + } + } + + if rpost == nil { + // post doesn't exist; create new one + rpost, appErr = scs.app.CreatePost(post, channel, true, true) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Created sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } else if post.DeleteAt > 0 { + // delete post + rpost, appErr = scs.app.DeletePost(post.Id, post.UserId) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Deleted sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } else if post.EditAt > rpost.EditAt || post.Message != rpost.Message { + // update post + rpost, appErr = scs.app.UpdatePost(post, false) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Updated sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } else { + // nothing to update + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Update to sync post ignored", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } + + var rerr error + if appErr != nil { + rerr = errors.New(appErr.Error()) + } + return rpost, rerr +} + +func (scs *Service) upsertSyncReaction(reaction *model.Reaction, rc *model.RemoteCluster) (*model.Reaction, error) { + savedReaction := reaction + var appErr *model.AppError + + reaction.RemoteId = model.NewString(rc.RemoteId) + + if reaction.DeleteAt == 0 { + savedReaction, appErr = scs.app.SaveReactionForPost(reaction) + } else { + appErr = scs.app.DeleteReactionForPost(reaction) + } + + var err error + if appErr != nil { + err = errors.New(appErr.Error()) + } + return savedReaction, err +} diff --git a/services/sharedchannel/sync_send.go b/services/sharedchannel/sync_send.go new file mode 100644 index 00000000000..251d9f8d89b --- /dev/null +++ b/services/sharedchannel/sync_send.go @@ -0,0 +1,512 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/shared/i18n" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +type syncTask struct { + id string + channelId string + remoteId string + AddedAt time.Time + retryCount int + retryPost *model.Post + schedule time.Time +} + +func newSyncTask(channelId string, remoteId string, retryPost *model.Post) syncTask { + var postId string + if retryPost != nil { + postId = retryPost.Id + } + + return syncTask{ + id: channelId + remoteId + postId, // combination of ids to avoid duplicates + channelId: channelId, + remoteId: remoteId, // empty means update all remote clusters + retryPost: retryPost, + schedule: time.Now(), + } +} + +// incRetry increments the retry counter and returns true if MaxRetries not exceeded. +func (st *syncTask) incRetry() bool { + st.retryCount++ + return st.retryCount <= MaxRetries +} + +// NotifyChannelChanged is called to indicate that a shared channel has been modified, +// thus triggering an update to all remote clusters. +func (scs *Service) NotifyChannelChanged(channelId string) { + if rcs := scs.server.GetRemoteClusterService(); rcs == nil { + return + } + + task := newSyncTask(channelId, "", nil) + task.schedule = time.Now().Add(NotifyMinimumDelay) + scs.addTask(task) +} + +// ForceSyncForRemote causes all channels shared with the remote to be synchronized. +func (scs *Service) ForceSyncForRemote(rc *model.RemoteCluster) { + if rcs := scs.server.GetRemoteClusterService(); rcs == nil { + return + } + + // fetch all channels shared with this remote. + opts := model.SharedChannelRemoteFilterOpts{ + RemoteId: rc.RemoteId, + } + scrs, err := scs.server.GetStore().SharedChannel().GetRemotes(opts) + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Failed to fetch shared channel remotes", + mlog.String("remote", rc.DisplayName), + mlog.String("remoteId", rc.RemoteId), + mlog.Err(err), + ) + return + } + + for _, scr := range scrs { + task := newSyncTask(scr.ChannelId, rc.RemoteId, nil) + task.schedule = time.Now().Add(NotifyMinimumDelay) + scs.addTask(task) + } +} + +// addTask adds or re-adds a task to the queue. +func (scs *Service) addTask(task syncTask) { + task.AddedAt = time.Now() + scs.mux.Lock() + if _, ok := scs.tasks[task.id]; !ok { + scs.tasks[task.id] = task + } + scs.mux.Unlock() + + // wake up the sync goroutine + select { + case scs.changeSignal <- struct{}{}: + default: + // that's ok, the sync routine is already busy + } +} + +// syncLoop is called via a dedicated goroutine to wait for notifications of channel changes and +// updates each remote based on those changes. +func (scs *Service) syncLoop(done chan struct{}) { + // create a timer to periodically check the task queue, but only if there is + // a delayed task in the queue. + delay := time.NewTimer(NotifyMinimumDelay) + defer stopTimer(delay) + + // wait for channel changed signal and update for oldest task. + for { + select { + case <-scs.changeSignal: + if wait := scs.doSync(); wait > 0 { + stopTimer(delay) + delay.Reset(wait) + } + case <-delay.C: + if wait := scs.doSync(); wait > 0 { + delay.Reset(wait) + } + case <-done: + return + } + } +} + +func stopTimer(timer *time.Timer) { + timer.Stop() + select { + case <-timer.C: + default: + } +} + +// doSync checks the task queue for any tasks to be processed and processes all that are ready. +// If any delayed tasks remain in queue then the duration until the next scheduled task is returned. +func (scs *Service) doSync() time.Duration { + var task syncTask + var ok bool + var shortestWait time.Duration + + for { + task, ok, shortestWait = scs.removeOldestTask() + if !ok { + break + } + if err := scs.processTask(task); err != nil { + // put task back into map so it will update again + if task.incRetry() { + scs.addTask(task) + } else { + scs.server.GetLogger().Error("Failed to synchronize shared channel", + mlog.String("channelId", task.channelId), + mlog.String("remoteId", task.remoteId), + mlog.Err(err), + ) + } + } + } + return shortestWait +} + +// removeOldestTask removes and returns the oldest task in the task map. +// A task coming in via NotifyChannelChanged must stay in queue for at least +// `NotifyMinimumDelay` to ensure we don't go nuts trying to sync during a bulk update. +// If no tasks are available then false is returned. +func (scs *Service) removeOldestTask() (syncTask, bool, time.Duration) { + scs.mux.Lock() + defer scs.mux.Unlock() + + var oldestTask syncTask + var oldestKey string + var shortestWait time.Duration + + for key, task := range scs.tasks { + // check if task is ready + if wait := time.Until(task.schedule); wait > 0 { + if wait < shortestWait || shortestWait == 0 { + shortestWait = wait + } + continue + } + // task is ready; check if it's the oldest ready task + if task.AddedAt.Before(oldestTask.AddedAt) || oldestTask.AddedAt.IsZero() { + oldestKey = key + oldestTask = task + } + } + + if oldestKey != "" { + delete(scs.tasks, oldestKey) + return oldestTask, true, shortestWait + } + return oldestTask, false, shortestWait +} + +// processTask updates one or more remote clusters with any new channel content. +func (scs *Service) processTask(task syncTask) error { + var err error + var remotes []*model.RemoteCluster + + if task.remoteId == "" { + filter := model.RemoteClusterQueryFilter{ + InChannel: task.channelId, + OnlyConfirmed: true, + } + remotes, err = scs.server.GetStore().RemoteCluster().GetAll(filter) + if err != nil { + return err + } + } else { + rc, err := scs.server.GetStore().RemoteCluster().Get(task.remoteId) + if err != nil { + return err + } + if !rc.IsOnline() { + return fmt.Errorf("Failed updating shared channel '%s' for offline remote cluster '%s'", task.channelId, rc.DisplayName) + } + remotes = []*model.RemoteCluster{rc} + } + + for _, rc := range remotes { + rtask := task + rtask.remoteId = rc.RemoteId + if err := scs.updateForRemote(rtask, rc); err != nil { + // retry... + if rtask.incRetry() { + scs.addTask(rtask) + } else { + scs.server.GetLogger().Error("Failed to synchronize shared channel for remote cluster", + mlog.String("channelId", rtask.channelId), + mlog.String("remote", rc.DisplayName), + mlog.String("remoteId", rtask.remoteId), + mlog.Err(err), + ) + } + } + } + return nil +} + +// updateForRemote updates a remote cluster with any new posts/reactions for a specific +// channel. If many changes are found, only the oldest X changes are sent and the channel +// is re-added to the task map. This ensures no channels are starved for updates even if some +// channels are very active. +func (scs *Service) updateForRemote(task syncTask, rc *model.RemoteCluster) error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return fmt.Errorf("cannot update remote cluster for channel id %s; Remote Cluster Service not enabled", task.channelId) + } + + scr, err := scs.server.GetStore().SharedChannel().GetRemoteByIds(task.channelId, rc.RemoteId) + if err != nil { + return err + } + + var posts []*model.Post + var repeat bool + nextSince := scr.NextSyncAt + + if task.retryPost != nil { + posts = []*model.Post{task.retryPost} + } else { + result, err2 := scs.getPostsSince(task.channelId, rc, scr.NextSyncAt) + if err2 != nil { + return err2 + } + posts = result.posts + repeat = result.hasMore + nextSince = result.nextSince + } + + if len(posts) == 0 { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "sync task found zero posts; skipping sync", + mlog.String("remote", rc.DisplayName), + mlog.String("channel_id", task.channelId), + mlog.Int64("lastSyncAt", scr.NextSyncAt), + mlog.Int64("nextSince", nextSince), + mlog.Bool("repeat", repeat), + ) + return nil + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "sync task found posts to sync", + mlog.String("remote", rc.DisplayName), + mlog.String("channel_id", task.channelId), + mlog.Int64("lastSyncAt", scr.NextSyncAt), + mlog.Int64("nextSince", nextSince), + mlog.Int("count", len(posts)), + mlog.Bool("repeat", repeat), + ) + + if !rc.IsOnline() { + scs.notifyRemoteOffline(posts, rc) + return nil + } + + syncMessages, err := scs.postsToSyncMessages(posts, rc, scr.NextSyncAt) + if err != nil { + return err + } + + if len(syncMessages) == 0 { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "sync task, all messages filtered out; skipping sync", + mlog.String("remote", rc.DisplayName), + mlog.String("channel_id", task.channelId), + mlog.Bool("repeat", repeat), + ) + + // All posts were filtered out, meaning no need to send them. Fast forward SharedChannelRemote's NextSyncAt. + scs.updateNextSyncForRemote(scr.Id, rc, nextSince) + + // everything was filtered out, nothing to send. + if repeat { + scs.addTask(newSyncTask(task.channelId, task.remoteId, nil)) + } + return nil + } + + scs.sendAttachments(syncMessages, rc) + + b, err := json.Marshal(syncMessages) + if err != nil { + return err + } + msg := model.NewRemoteClusterMsg(TopicSync, b) + + if scs.server.GetLogger().IsLevelEnabled(mlog.LvlSharedChannelServiceMessagesOutbound) { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceMessagesOutbound, "outbound message", + mlog.String("remote", rc.DisplayName), + mlog.Int64("NextSyncAt", scr.NextSyncAt), + mlog.String("msg", string(b)), + ) + } + + ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + + err = rcs.SendMsg(ctx, msg, rc, func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { + defer wg.Done() + if err != nil { + return // this means the response could not be parsed; already logged + } + + var syncResp SyncResponse + if err2 := json.Unmarshal(resp.Payload, &syncResp); err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "invalid sync response after update shared channel", + mlog.String("remote", rc.DisplayName), + mlog.Err(err2), + ) + } + + // Any Post(s) that failed to save on remote side are included in an array of post ids in the Response payload. + // Handle each error by retrying the post a fixed number of times before giving up. + for _, p := range syncResp.PostErrors { + scs.handlePostError(p, task, rc) + } + + // update NextSyncAt for all the users that were synchronized + scs.updateSyncUsers(syncResp.UsersSyncd, rc, nextSince) + }) + + wg.Wait() + + if err == nil { + // Optimistically update SharedChannelRemote's NextSyncAt; if any posts failed they will be retried. + scs.updateNextSyncForRemote(scr.Id, rc, nextSince) + } + + if repeat { + scs.addTask(newSyncTask(task.channelId, task.remoteId, nil)) + } + return err +} + +func (scs *Service) sendAttachments(syncMessages []syncMsg, rc *model.RemoteCluster) { + for _, sm := range syncMessages { + for _, fi := range sm.Attachments { + if err := scs.sendAttachmentForRemote(fi, sm.Post, rc); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error syncing attachment for post", + mlog.String("remote", rc.DisplayName), + mlog.String("post_id", sm.Post.Id), + mlog.String("file_id", fi.Id), + mlog.Err(err), + ) + } + } + } +} + +func (scs *Service) handlePostError(postId string, task syncTask, rc *model.RemoteCluster) { + if task.retryPost != nil && task.retryPost.Id == postId { + // this was a retry for specific post that failed previously. Try again if within MaxRetries. + if task.incRetry() { + scs.addTask(task) + } else { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error syncing post", + mlog.String("remote", rc.DisplayName), + mlog.String("post_id", postId), + ) + } + return + } + + // this post failed as part of a group of posts. Retry as an individual post. + post, err := scs.server.GetStore().Post().GetSingle(postId, true) + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error fetching post for sync retry", + mlog.String("remote", rc.DisplayName), + mlog.String("post_id", postId), + ) + return + } + scs.addTask(newSyncTask(task.channelId, task.remoteId, post)) +} + +// notifyRemoteOffline creates an ephemeral post to the author for any posts created recently to remotes +// that are offline. +func (scs *Service) notifyRemoteOffline(posts []*model.Post, rc *model.RemoteCluster) { + // only send one ephemeral post per author. + notified := make(map[string]bool) + + // range the slice in reverse so the newest posts are visited first; this ensures an ephemeral + // get added where it is mostly likely to be seen. + for i := len(posts) - 1; i >= 0; i-- { + post := posts[i] + if didNotify := notified[post.UserId]; didNotify { + continue + } + + postCreateAt := model.GetTimeForMillis(post.CreateAt) + + if post.DeleteAt == 0 && post.UserId != "" && time.Since(postCreateAt) < NotifyRemoteOfflineThreshold { + T := scs.getUserTranslations(post.UserId) + ephemeral := &model.Post{ + ChannelId: post.ChannelId, + Message: T("sharedchannel.cannot_deliver_post", map[string]interface{}{"Remote": rc.DisplayName}), + CreateAt: post.CreateAt + 1, + } + scs.app.SendEphemeralPost(post.UserId, ephemeral) + + notified[post.UserId] = true + } + } +} + +func (scs *Service) updateNextSyncForRemote(scrId string, rc *model.RemoteCluster, nextSyncAt int64) { + if nextSyncAt == 0 { + return + } + if err := scs.server.GetStore().SharedChannel().UpdateRemoteNextSyncAt(scrId, nextSyncAt); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error updating NextSyncAt for shared channel remote", + mlog.String("remote", rc.DisplayName), + mlog.Err(err), + ) + return + } + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "updated NextSyncAt for remote", + mlog.String("remote_id", rc.RemoteId), + mlog.String("remote", rc.DisplayName), + mlog.Int64("next_update_at", nextSyncAt), + ) +} + +func (scs *Service) updateSyncUsers(userIds []string, rc *model.RemoteCluster, lastSyncAt int64) { + for _, uid := range userIds { + scu, err := scs.server.GetStore().SharedChannel().GetUser(uid, rc.RemoteId) + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error getting user for lastSyncAt update", + mlog.String("remote", rc.DisplayName), + mlog.String("user_id", uid), + mlog.Err(err), + ) + continue + } + + if err := scs.server.GetStore().SharedChannel().UpdateUserLastSyncAt(scu.Id, lastSyncAt); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error updating lastSyncAt for user", + mlog.String("remote", rc.DisplayName), + mlog.String("user_id", uid), + mlog.Err(err), + ) + } else { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "updated lastSyncAt for user", + mlog.String("remote", rc.DisplayName), + mlog.String("user_id", scu.UserId), + mlog.Int64("last_update_at", lastSyncAt), + ) + } + } +} + +func (scs *Service) getUserTranslations(userId string) i18n.TranslateFunc { + var locale string + user, err := scs.server.GetStore().User().Get(context.Background(), userId) + if err == nil { + locale = user.Locale + } + + if locale == "" { + locale = model.DEFAULT_LOCALE + } + return i18n.GetUserTranslations(locale) +} diff --git a/services/slackimport/slackimport.go b/services/slackimport/slackimport.go index 50ce3f683e7..9f44c50f4b4 100644 --- a/services/slackimport/slackimport.go +++ b/services/slackimport/slackimport.go @@ -85,7 +85,7 @@ type Actions struct { UpdateActive func(*model.User, bool) (*model.User, *model.AppError) AddUserToChannel func(*model.User, *model.Channel) (*model.ChannelMember, *model.AppError) JoinUserToTeam func(*model.Team, *model.User, string) *model.AppError - CreateDirectChannel func(string, string) (*model.Channel, *model.AppError) + CreateDirectChannel func(string, string, ...model.ChannelOption) (*model.Channel, *model.AppError) CreateGroupChannel func([]string) (*model.Channel, *model.AppError) CreateChannel func(*model.Channel, bool) (*model.Channel, *model.AppError) DoUploadFile func(time.Time, string, string, string, string, []byte) (*model.FileInfo, *model.AppError) diff --git a/services/telemetry/telemetry.go b/services/telemetry/telemetry.go index ab563702130..56a8466adb5 100644 --- a/services/telemetry/telemetry.go +++ b/services/telemetry/telemetry.go @@ -734,6 +734,7 @@ func (ts *TelemetryService) trackConfig() { "cloud_billing": *cfg.ExperimentalSettings.CloudBilling, "cloud_user_limit": *cfg.ExperimentalSettings.CloudUserLimit, "enable_shared_channels": *cfg.ExperimentalSettings.EnableSharedChannels, + "enable_remote_cluster_service": *cfg.ExperimentalSettings.EnableRemoteClusterService && cfg.FeatureFlags.EnableRemoteClusterService, }) ts.sendTelemetry(TrackConfigAnalytics, map[string]interface{}{ diff --git a/shared/mlog/default.go b/shared/mlog/default.go index 1e409b192c4..e7faa8c4c12 100644 --- a/shared/mlog/default.go +++ b/shared/mlog/default.go @@ -33,6 +33,10 @@ func defaultLog(level, msg string, fields ...Field) { } } +func defaultIsLevelEnabled(level LogLevel) bool { + return true +} + func defaultDebugLog(msg string, fields ...Field) { defaultLog("debug", msg, fields...) } diff --git a/shared/mlog/global.go b/shared/mlog/global.go index 2986d92d297..aba06646724 100644 --- a/shared/mlog/global.go +++ b/shared/mlog/global.go @@ -23,6 +23,7 @@ func InitGlobalLogger(logger *Logger) { glob := *logger glob.zap = glob.zap.WithOptions(zap.AddCallerSkip(1)) globalLogger = &glob + IsLevelEnabled = globalLogger.IsLevelEnabled Debug = globalLogger.Debug Info = globalLogger.Info Warn = globalLogger.Warn @@ -59,6 +60,7 @@ func RedirectStdLog(logger *Logger) { log.SetOutput(logWriterFunc(writer)) } +type IsLevelEnabledFunc func(LogLevel) bool type LogFunc func(string, ...Field) type LogFuncCustom func(LogLevel, string, ...Field) type LogFuncCustomMulti func([]LogLevel, string, ...Field) @@ -79,6 +81,7 @@ func GloballyEnableDebugLogForTest() { globalLogger.consoleLevel.SetLevel(zapcore.DebugLevel) } +var IsLevelEnabled IsLevelEnabledFunc = defaultIsLevelEnabled var Debug LogFunc = defaultDebugLog var Info LogFunc = defaultInfoLog var Warn LogFunc = defaultWarnLog diff --git a/shared/mlog/levels.go b/shared/mlog/levels.go index 54bd25496e1..24d29e0bee1 100644 --- a/shared/mlog/levels.go +++ b/shared/mlog/levels.go @@ -30,6 +30,18 @@ var ( // used by the TCP log target LvlTcpLogTarget = LogLevel{ID: 120, Name: "TcpLogTarget"} + // used by Remote Cluster Service + LvlRemoteClusterServiceDebug = LogLevel{ID: 130, Name: "RemoteClusterServiceDebug"} + LvlRemoteClusterServiceError = LogLevel{ID: 131, Name: "RemoteClusterServiceError"} + LvlRemoteClusterServiceWarn = LogLevel{ID: 132, Name: "RemoteClusterServiceWarn"} + + // used by Shared Channel Sync Service + LvlSharedChannelServiceDebug = LogLevel{ID: 200, Name: "SharedChannelServiceDebug"} + LvlSharedChannelServiceError = LogLevel{ID: 201, Name: "SharedChannelServiceError"} + LvlSharedChannelServiceWarn = LogLevel{ID: 202, Name: "SharedChannelServiceWarn"} + LvlSharedChannelServiceMessagesInbound = LogLevel{ID: 203, Name: "SharedChannelServiceMsgInbound"} + LvlSharedChannelServiceMessagesOutbound = LogLevel{ID: 204, Name: "SharedChannelServiceMsgOutbound"} + // add more here ... ) diff --git a/shared/mlog/log.go b/shared/mlog/log.go index 6395240a1d4..ade786879c4 100644 --- a/shared/mlog/log.go +++ b/shared/mlog/log.go @@ -57,6 +57,17 @@ var NamedErr = zap.NamedError var Bool = zap.Bool var Duration = zap.Duration +type LoggerIFace interface { + IsLevelEnabled(LogLevel) bool + Debug(string, ...Field) + Info(string, ...Field) + Warn(string, ...Field) + Error(string, ...Field) + Critical(string, ...Field) + Log(LogLevel, string, ...Field) + LogM([]LogLevel, string, ...Field) +} + type TargetInfo logr.TargetInfo type LoggerConfiguration struct { @@ -207,6 +218,10 @@ func (l *Logger) Sugar() *SugarLogger { } } +func (l *Logger) IsLevelEnabled(level LogLevel) bool { + return isLevelEnabled(l.getLogger(), logr.Level(level)) +} + func (l *Logger) Debug(message string, fields ...Field) { l.zap.Debug(message, fields...) if isLevelEnabled(l.getLogger(), logr.Debug) { diff --git a/store/errors.go b/store/errors.go index fac07ee0e05..ee9e896713f 100644 --- a/store/errors.go +++ b/store/errors.go @@ -26,6 +26,13 @@ func (e *ErrInvalidInput) Error() string { return fmt.Sprintf("invalid input: entity: %s field: %s value: %s", e.Entity, e.Field, e.Value) } +func (e *ErrInvalidInput) InvalidInputInfo() (entity string, field string, value interface{}) { + entity = e.Entity + field = e.Field + value = e.Value + return +} + // ErrLimitExceeded indicates an error that has occurred because some value exceeded a limit. type ErrLimitExceeded struct { What string // What was the object that exceeded. @@ -72,6 +79,11 @@ func (e *ErrConflict) Unwrap() error { return e.err } +// IsErrConflict allows easy type assertion without adding store as a dependency. +func (e *ErrConflict) IsErrConflict() bool { + return true +} + // ErrNotFound indicates that a resource was not found type ErrNotFound struct { resource string @@ -89,6 +101,11 @@ func (e *ErrNotFound) Error() string { return "resource: " + e.resource + " id: " + e.ID } +// IsErrNotFound allows easy type assertion without adding store as a dependency. +func (e *ErrNotFound) IsErrNotFound() bool { + return true +} + // ErrOutOfBounds indicates that the requested total numbers of rows // was greater than the allowed limit. type ErrOutOfBounds struct { diff --git a/store/layer_generators/main.go b/store/layer_generators/main.go index 940bdfe3e9b..7bc920064ce 100644 --- a/store/layer_generators/main.go +++ b/store/layer_generators/main.go @@ -270,18 +270,23 @@ func generateLayer(name, templateFile string) ([]byte, error) { "joinParams": func(params []methodParam) string { paramsNames := make([]string, 0, len(params)) for _, param := range params { - paramsNames = append(paramsNames, param.Name) + tParams := "" + if strings.HasPrefix(param.Type, "...") { + tParams = "..." + } + paramsNames = append(paramsNames, param.Name+tParams) } return strings.Join(paramsNames, ", ") }, "joinParamsWithType": func(params []methodParam) string { paramsWithType := []string{} for _, param := range params { - if param.Type == "ChannelSearchOpts" || param.Type == "UserGetByIdsOpts" { + switch param.Type { + case "ChannelSearchOpts", "UserGetByIdsOpts": paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type)) - } else if param.Type == "*UserGetByIdsOpts" { + case "*UserGetByIdsOpts": paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.UserGetByIdsOpts", param.Name)) - } else { + default: paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type)) } } @@ -290,11 +295,12 @@ func generateLayer(name, templateFile string) ([]byte, error) { "joinParamsWithTypeOutsideStore": func(params []methodParam) string { paramsWithType := []string{} for _, param := range params { - if param.Type == "ChannelSearchOpts" || param.Type == "UserGetByIdsOpts" { + switch param.Type { + case "ChannelSearchOpts", "UserGetByIdsOpts": paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type)) - } else if param.Type == "*UserGetByIdsOpts" { + case "*UserGetByIdsOpts": paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.UserGetByIdsOpts", param.Name)) - } else { + default: paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type)) } } diff --git a/store/opentracinglayer/opentracinglayer.go b/store/opentracinglayer/opentracinglayer.go index 5d578fbbc6f..e13949ff3ef 100644 --- a/store/opentracinglayer/opentracinglayer.go +++ b/store/opentracinglayer/opentracinglayer.go @@ -38,9 +38,11 @@ type OpenTracingLayer struct { PreferenceStore store.PreferenceStore ProductNoticesStore store.ProductNoticesStore ReactionStore store.ReactionStore + RemoteClusterStore store.RemoteClusterStore RoleStore store.RoleStore SchemeStore store.SchemeStore SessionStore store.SessionStore + SharedChannelStore store.SharedChannelStore StatusStore store.StatusStore SystemStore store.SystemStore TeamStore store.TeamStore @@ -134,6 +136,10 @@ func (s *OpenTracingLayer) Reaction() store.ReactionStore { return s.ReactionStore } +func (s *OpenTracingLayer) RemoteCluster() store.RemoteClusterStore { + return s.RemoteClusterStore +} + func (s *OpenTracingLayer) Role() store.RoleStore { return s.RoleStore } @@ -146,6 +152,10 @@ func (s *OpenTracingLayer) Session() store.SessionStore { return s.SessionStore } +func (s *OpenTracingLayer) SharedChannel() store.SharedChannelStore { + return s.SharedChannelStore +} + func (s *OpenTracingLayer) Status() store.StatusStore { return s.StatusStore } @@ -290,6 +300,11 @@ type OpenTracingLayerReactionStore struct { Root *OpenTracingLayer } +type OpenTracingLayerRemoteClusterStore struct { + store.RemoteClusterStore + Root *OpenTracingLayer +} + type OpenTracingLayerRoleStore struct { store.RoleStore Root *OpenTracingLayer @@ -305,6 +320,11 @@ type OpenTracingLayerSessionStore struct { Root *OpenTracingLayer } +type OpenTracingLayerSharedChannelStore struct { + store.SharedChannelStore + Root *OpenTracingLayer +} + type OpenTracingLayerStatusStore struct { store.StatusStore Root *OpenTracingLayer @@ -643,7 +663,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta return result, resultVar1, err } -func (s *OpenTracingLayerChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User) (*model.Channel, error) { +func (s *OpenTracingLayerChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CreateDirectChannel") s.Root.Store.SetContext(newCtx) @@ -652,7 +672,7 @@ func (s *OpenTracingLayerChannelStore) CreateDirectChannel(userId *model.User, o }() defer span.Finish() - result, err := s.ChannelStore.CreateDirectChannel(userId, otherUserId) + result, err := s.ChannelStore.CreateDirectChannel(userId, otherUserId, channelOptions...) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -1538,6 +1558,24 @@ func (s *OpenTracingLayerChannelStore) GetTeamChannels(teamID string) (*model.Ch return result, err } +func (s *OpenTracingLayerChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.GetTeamForChannel") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.ChannelStore.GetTeamForChannel(channelID) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerChannelStore) GroupSyncedChannelCount() (int64, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.GroupSyncedChannelCount") @@ -2074,6 +2112,24 @@ func (s *OpenTracingLayerChannelStore) SetDeleteAt(channelID string, deleteAt in return err } +func (s *OpenTracingLayerChannelStore) SetShared(channelId string, shared bool) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.SetShared") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.ChannelStore.SetShared(channelId, shared) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + func (s *OpenTracingLayerChannelStore) Update(channel *model.Channel) (*model.Channel, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.Update") @@ -5292,6 +5348,24 @@ func (s *OpenTracingLayerPostStore) GetPostsSince(options model.GetPostsSinceOpt return result, err } +func (s *OpenTracingLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.GetPostsSinceForSync") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.PostStore.GetPostsSinceForSync(options, allowFromCache) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerPostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExport, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.GetRepliesForExport") @@ -5310,7 +5384,7 @@ func (s *OpenTracingLayerPostStore) GetRepliesForExport(parentID string) ([]*mod return result, err } -func (s *OpenTracingLayerPostStore) GetSingle(id string) (*model.Post, error) { +func (s *OpenTracingLayerPostStore) GetSingle(id string, inclDeleted bool) (*model.Post, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.GetSingle") s.Root.Store.SetContext(newCtx) @@ -5319,7 +5393,7 @@ func (s *OpenTracingLayerPostStore) GetSingle(id string) (*model.Post, error) { }() defer span.Finish() - result, err := s.PostStore.GetSingle(id) + result, err := s.PostStore.GetSingle(id, inclDeleted) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -5827,6 +5901,24 @@ func (s *OpenTracingLayerReactionStore) GetForPost(postID string, allowFromCache return result, err } +func (s *OpenTracingLayerReactionStore) GetForPostSince(postId string, since int64, excludeRemoteId string, inclDeleted bool) ([]*model.Reaction, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ReactionStore.GetForPostSince") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.ReactionStore.GetForPostSince(postId, since, excludeRemoteId, inclDeleted) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerReactionStore) PermanentDeleteBatch(endTime int64, limit int64) (int64, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ReactionStore.PermanentDeleteBatch") @@ -5863,6 +5955,132 @@ func (s *OpenTracingLayerReactionStore) Save(reaction *model.Reaction) (*model.R return result, err } +func (s *OpenTracingLayerRemoteClusterStore) Delete(remoteClusterId string) (bool, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.Delete") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.RemoteClusterStore.Delete(remoteClusterId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerRemoteClusterStore) Get(remoteClusterId string) (*model.RemoteCluster, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.Get") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.RemoteClusterStore.Get(remoteClusterId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerRemoteClusterStore) GetAll(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.GetAll") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.RemoteClusterStore.GetAll(filter) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerRemoteClusterStore) Save(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.Save") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.RemoteClusterStore.Save(rc) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerRemoteClusterStore) SetLastPingAt(remoteClusterId string) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.SetLastPingAt") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.RemoteClusterStore.SetLastPingAt(remoteClusterId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + +func (s *OpenTracingLayerRemoteClusterStore) Update(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.Update") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.RemoteClusterStore.Update(rc) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerRemoteClusterStore) UpdateTopics(remoteClusterId string, topics string) (*model.RemoteCluster, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RemoteClusterStore.UpdateTopics") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.RemoteClusterStore.UpdateTopics(remoteClusterId, topics) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerRoleStore) AllChannelSchemeRoles() ([]*model.Role, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "RoleStore.AllChannelSchemeRoles") @@ -6470,6 +6688,438 @@ func (s *OpenTracingLayerSessionStore) UpdateRoles(userId string, roles string) return result, err } +func (s *OpenTracingLayerSharedChannelStore) Delete(channelId string) (bool, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.Delete") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.Delete(channelId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) DeleteRemote(remoteId string) (bool, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.DeleteRemote") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.DeleteRemote(remoteId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) Get(channelId string) (*model.SharedChannel, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.Get") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.Get(channelId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetAll(offset int, limit int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetAll") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetAll(offset, limit, opts) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetAllCount(opts model.SharedChannelFilterOpts) (int64, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetAllCount") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetAllCount(opts) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetAttachment(fileId string, remoteId string) (*model.SharedChannelAttachment, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetAttachment") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetAttachment(fileId, remoteId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetRemote(id string) (*model.SharedChannelRemote, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetRemote") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetRemote(id) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetRemoteByIds") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetRemoteByIds(channelId, remoteId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetRemoteForUser") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetRemoteForUser(remoteId, userId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetRemotes") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetRemotes(opts) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetRemotesStatus") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetRemotesStatus(channelId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetUser(userId string, remoteId string) (*model.SharedChannelUser, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetUser") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetUser(userId, remoteId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) HasChannel(channelID string) (bool, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.HasChannel") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.HasChannel(channelID) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) HasRemote(channelID string, remoteId string) (bool, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.HasRemote") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.HasRemote(channelID, remoteId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) Save(sc *model.SharedChannel) (*model.SharedChannel, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.Save") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.Save(sc) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.SaveAttachment") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.SaveAttachment(remote) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) SaveRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.SaveRemote") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.SaveRemote(remote) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) SaveUser(remote *model.SharedChannelUser) (*model.SharedChannelUser, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.SaveUser") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.SaveUser(remote) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) Update(sc *model.SharedChannel) (*model.SharedChannel, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.Update") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.Update(sc) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) UpdateAttachmentLastSyncAt(id string, syncTime int64) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateAttachmentLastSyncAt") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.SharedChannelStore.UpdateAttachmentLastSyncAt(id, syncTime) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + +func (s *OpenTracingLayerSharedChannelStore) UpdateRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateRemote") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.UpdateRemote(remote) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateRemoteNextSyncAt") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.SharedChannelStore.UpdateRemoteNextSyncAt(id, syncTime) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + +func (s *OpenTracingLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateUserLastSyncAt") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.SharedChannelStore.UpdateUserLastSyncAt(id, syncTime) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + +func (s *OpenTracingLayerSharedChannelStore) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpsertAttachment") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.UpsertAttachment(remote) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerStatusStore) Get(userId string) (*model.Status, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "StatusStore.Get") @@ -10201,9 +10851,11 @@ func New(childStore store.Store, ctx context.Context) *OpenTracingLayer { newStore.PreferenceStore = &OpenTracingLayerPreferenceStore{PreferenceStore: childStore.Preference(), Root: &newStore} newStore.ProductNoticesStore = &OpenTracingLayerProductNoticesStore{ProductNoticesStore: childStore.ProductNotices(), Root: &newStore} newStore.ReactionStore = &OpenTracingLayerReactionStore{ReactionStore: childStore.Reaction(), Root: &newStore} + newStore.RemoteClusterStore = &OpenTracingLayerRemoteClusterStore{RemoteClusterStore: childStore.RemoteCluster(), Root: &newStore} newStore.RoleStore = &OpenTracingLayerRoleStore{RoleStore: childStore.Role(), Root: &newStore} newStore.SchemeStore = &OpenTracingLayerSchemeStore{SchemeStore: childStore.Scheme(), Root: &newStore} newStore.SessionStore = &OpenTracingLayerSessionStore{SessionStore: childStore.Session(), Root: &newStore} + newStore.SharedChannelStore = &OpenTracingLayerSharedChannelStore{SharedChannelStore: childStore.SharedChannel(), Root: &newStore} newStore.StatusStore = &OpenTracingLayerStatusStore{StatusStore: childStore.Status(), Root: &newStore} newStore.SystemStore = &OpenTracingLayerSystemStore{SystemStore: childStore.System(), Root: &newStore} newStore.TeamStore = &OpenTracingLayerTeamStore{TeamStore: childStore.Team(), Root: &newStore} diff --git a/store/retrylayer/retrylayer.go b/store/retrylayer/retrylayer.go index 7a358fd5b3e..6e5b7b5cb19 100644 --- a/store/retrylayer/retrylayer.go +++ b/store/retrylayer/retrylayer.go @@ -40,9 +40,11 @@ type RetryLayer struct { PreferenceStore store.PreferenceStore ProductNoticesStore store.ProductNoticesStore ReactionStore store.ReactionStore + RemoteClusterStore store.RemoteClusterStore RoleStore store.RoleStore SchemeStore store.SchemeStore SessionStore store.SessionStore + SharedChannelStore store.SharedChannelStore StatusStore store.StatusStore SystemStore store.SystemStore TeamStore store.TeamStore @@ -136,6 +138,10 @@ func (s *RetryLayer) Reaction() store.ReactionStore { return s.ReactionStore } +func (s *RetryLayer) RemoteCluster() store.RemoteClusterStore { + return s.RemoteClusterStore +} + func (s *RetryLayer) Role() store.RoleStore { return s.RoleStore } @@ -148,6 +154,10 @@ func (s *RetryLayer) Session() store.SessionStore { return s.SessionStore } +func (s *RetryLayer) SharedChannel() store.SharedChannelStore { + return s.SharedChannelStore +} + func (s *RetryLayer) Status() store.StatusStore { return s.StatusStore } @@ -292,6 +302,11 @@ type RetryLayerReactionStore struct { Root *RetryLayer } +type RetryLayerRemoteClusterStore struct { + store.RemoteClusterStore + Root *RetryLayer +} + type RetryLayerRoleStore struct { store.RoleStore Root *RetryLayer @@ -307,6 +322,11 @@ type RetryLayerSessionStore struct { Root *RetryLayer } +type RetryLayerSharedChannelStore struct { + store.SharedChannelStore + Root *RetryLayer +} + type RetryLayerStatusStore struct { store.StatusStore Root *RetryLayer @@ -684,11 +704,11 @@ func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int } -func (s *RetryLayerChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User) (*model.Channel, error) { +func (s *RetryLayerChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) { tries := 0 for { - result, err := s.ChannelStore.CreateDirectChannel(userId, otherUserId) + result, err := s.ChannelStore.CreateDirectChannel(userId, otherUserId, channelOptions...) if err == nil { return result, nil } @@ -1670,6 +1690,26 @@ func (s *RetryLayerChannelStore) GetTeamChannels(teamID string) (*model.ChannelL } +func (s *RetryLayerChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) { + + tries := 0 + for { + result, err := s.ChannelStore.GetTeamForChannel(channelID) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + func (s *RetryLayerChannelStore) GroupSyncedChannelCount() (int64, error) { tries := 0 @@ -2198,6 +2238,26 @@ func (s *RetryLayerChannelStore) SetDeleteAt(channelID string, deleteAt int64, u } +func (s *RetryLayerChannelStore) SetShared(channelId string, shared bool) error { + + tries := 0 + for { + err := s.ChannelStore.SetShared(channelId, shared) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + } + +} + func (s *RetryLayerChannelStore) Update(channel *model.Channel) (*model.Channel, error) { tries := 0 @@ -5714,6 +5774,26 @@ func (s *RetryLayerPostStore) GetPostsSince(options model.GetPostsSinceOptions, } +func (s *RetryLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { + + tries := 0 + for { + result, err := s.PostStore.GetPostsSinceForSync(options, allowFromCache) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + func (s *RetryLayerPostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExport, error) { tries := 0 @@ -5734,11 +5814,11 @@ func (s *RetryLayerPostStore) GetRepliesForExport(parentID string) ([]*model.Rep } -func (s *RetryLayerPostStore) GetSingle(id string) (*model.Post, error) { +func (s *RetryLayerPostStore) GetSingle(id string, inclDeleted bool) (*model.Post, error) { tries := 0 for { - result, err := s.PostStore.GetSingle(id) + result, err := s.PostStore.GetSingle(id, inclDeleted) if err == nil { return result, nil } @@ -6300,6 +6380,26 @@ func (s *RetryLayerReactionStore) GetForPost(postID string, allowFromCache bool) } +func (s *RetryLayerReactionStore) GetForPostSince(postId string, since int64, excludeRemoteId string, inclDeleted bool) ([]*model.Reaction, error) { + + tries := 0 + for { + result, err := s.ReactionStore.GetForPostSince(postId, since, excludeRemoteId, inclDeleted) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + func (s *RetryLayerReactionStore) PermanentDeleteBatch(endTime int64, limit int64) (int64, error) { tries := 0 @@ -6340,6 +6440,146 @@ func (s *RetryLayerReactionStore) Save(reaction *model.Reaction) (*model.Reactio } +func (s *RetryLayerRemoteClusterStore) Delete(remoteClusterId string) (bool, error) { + + tries := 0 + for { + result, err := s.RemoteClusterStore.Delete(remoteClusterId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerRemoteClusterStore) Get(remoteClusterId string) (*model.RemoteCluster, error) { + + tries := 0 + for { + result, err := s.RemoteClusterStore.Get(remoteClusterId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerRemoteClusterStore) GetAll(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, error) { + + tries := 0 + for { + result, err := s.RemoteClusterStore.GetAll(filter) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerRemoteClusterStore) Save(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + + tries := 0 + for { + result, err := s.RemoteClusterStore.Save(rc) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerRemoteClusterStore) SetLastPingAt(remoteClusterId string) error { + + tries := 0 + for { + err := s.RemoteClusterStore.SetLastPingAt(remoteClusterId) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + } + +} + +func (s *RetryLayerRemoteClusterStore) Update(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + + tries := 0 + for { + result, err := s.RemoteClusterStore.Update(rc) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerRemoteClusterStore) UpdateTopics(remoteClusterId string, topics string) (*model.RemoteCluster, error) { + + tries := 0 + for { + result, err := s.RemoteClusterStore.UpdateTopics(remoteClusterId, topics) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + func (s *RetryLayerRoleStore) AllChannelSchemeRoles() ([]*model.Role, error) { tries := 0 @@ -7006,6 +7246,486 @@ func (s *RetryLayerSessionStore) UpdateRoles(userId string, roles string) (strin } +func (s *RetryLayerSharedChannelStore) Delete(channelId string) (bool, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.Delete(channelId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) DeleteRemote(remoteId string) (bool, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.DeleteRemote(remoteId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) Get(channelId string) (*model.SharedChannel, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.Get(channelId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetAll(offset int, limit int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetAll(offset, limit, opts) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetAllCount(opts model.SharedChannelFilterOpts) (int64, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetAllCount(opts) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetAttachment(fileId string, remoteId string) (*model.SharedChannelAttachment, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetAttachment(fileId, remoteId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetRemote(id string) (*model.SharedChannelRemote, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetRemote(id) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetRemoteByIds(channelId, remoteId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetRemoteForUser(remoteId, userId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetRemotes(opts) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetRemotesStatus(channelId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetUser(userId string, remoteId string) (*model.SharedChannelUser, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetUser(userId, remoteId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) HasChannel(channelID string) (bool, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.HasChannel(channelID) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) HasRemote(channelID string, remoteId string) (bool, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.HasRemote(channelID, remoteId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) Save(sc *model.SharedChannel) (*model.SharedChannel, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.Save(sc) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.SaveAttachment(remote) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) SaveRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.SaveRemote(remote) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) SaveUser(remote *model.SharedChannelUser) (*model.SharedChannelUser, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.SaveUser(remote) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) Update(sc *model.SharedChannel) (*model.SharedChannel, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.Update(sc) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) UpdateAttachmentLastSyncAt(id string, syncTime int64) error { + + tries := 0 + for { + err := s.SharedChannelStore.UpdateAttachmentLastSyncAt(id, syncTime) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + } + +} + +func (s *RetryLayerSharedChannelStore) UpdateRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.UpdateRemote(remote) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { + + tries := 0 + for { + err := s.SharedChannelStore.UpdateRemoteNextSyncAt(id, syncTime) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + } + +} + +func (s *RetryLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { + + tries := 0 + for { + err := s.SharedChannelStore.UpdateUserLastSyncAt(id, syncTime) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + } + +} + +func (s *RetryLayerSharedChannelStore) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.UpsertAttachment(remote) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + func (s *RetryLayerStatusStore) Get(userId string) (*model.Status, error) { tries := 0 @@ -11043,9 +11763,11 @@ func New(childStore store.Store) *RetryLayer { newStore.PreferenceStore = &RetryLayerPreferenceStore{PreferenceStore: childStore.Preference(), Root: &newStore} newStore.ProductNoticesStore = &RetryLayerProductNoticesStore{ProductNoticesStore: childStore.ProductNotices(), Root: &newStore} newStore.ReactionStore = &RetryLayerReactionStore{ReactionStore: childStore.Reaction(), Root: &newStore} + newStore.RemoteClusterStore = &RetryLayerRemoteClusterStore{RemoteClusterStore: childStore.RemoteCluster(), Root: &newStore} newStore.RoleStore = &RetryLayerRoleStore{RoleStore: childStore.Role(), Root: &newStore} newStore.SchemeStore = &RetryLayerSchemeStore{SchemeStore: childStore.Scheme(), Root: &newStore} newStore.SessionStore = &RetryLayerSessionStore{SessionStore: childStore.Session(), Root: &newStore} + newStore.SharedChannelStore = &RetryLayerSharedChannelStore{SharedChannelStore: childStore.SharedChannel(), Root: &newStore} newStore.StatusStore = &RetryLayerStatusStore{StatusStore: childStore.Status(), Root: &newStore} newStore.SystemStore = &RetryLayerSystemStore{SystemStore: childStore.System(), Root: &newStore} newStore.TeamStore = &RetryLayerTeamStore{TeamStore: childStore.Team(), Root: &newStore} diff --git a/store/retrylayer/retrylayer_test.go b/store/retrylayer/retrylayer_test.go index f62f2c1a288..77081f32cf8 100644 --- a/store/retrylayer/retrylayer_test.go +++ b/store/retrylayer/retrylayer_test.go @@ -21,6 +21,7 @@ func genStore() *mocks.Store { mock.On("Channel").Return(&mocks.ChannelStore{}) mock.On("ChannelMemberHistory").Return(&mocks.ChannelMemberHistoryStore{}) mock.On("ClusterDiscovery").Return(&mocks.ClusterDiscoveryStore{}) + mock.On("RemoteCluster").Return(&mocks.RemoteClusterStore{}) mock.On("Command").Return(&mocks.CommandStore{}) mock.On("CommandWebhook").Return(&mocks.CommandWebhookStore{}) mock.On("Compliance").Return(&mocks.ComplianceStore{}) @@ -31,6 +32,7 @@ func genStore() *mocks.Store { mock.On("Job").Return(&mocks.JobStore{}) mock.On("License").Return(&mocks.LicenseStore{}) mock.On("LinkMetadata").Return(&mocks.LinkMetadataStore{}) + mock.On("SharedChannel").Return(&mocks.SharedChannelStore{}) mock.On("OAuth").Return(&mocks.OAuthStore{}) mock.On("Plugin").Return(&mocks.PluginStore{}) mock.On("Post").Return(&mocks.PostStore{}) diff --git a/store/searchlayer/channel_layer.go b/store/searchlayer/channel_layer.go index 85130b6f92e..e5a2339eef0 100644 --- a/store/searchlayer/channel_layer.go +++ b/store/searchlayer/channel_layer.go @@ -114,8 +114,8 @@ func (c *SearchChannelStore) RemoveMembers(channelId string, userIds []string) e return nil } -func (c *SearchChannelStore) CreateDirectChannel(user *model.User, otherUser *model.User) (*model.Channel, error) { - channel, err := c.ChannelStore.CreateDirectChannel(user, otherUser) +func (c *SearchChannelStore) CreateDirectChannel(user *model.User, otherUser *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) { + channel, err := c.ChannelStore.CreateDirectChannel(user, otherUser, channelOptions...) if err == nil { c.rootStore.indexUserFromID(user.Id) c.rootStore.indexUserFromID(otherUser.Id) diff --git a/store/searchlayer/file_info_layer.go b/store/searchlayer/file_info_layer.go index 6b18fb193ac..eb3c31bcc42 100644 --- a/store/searchlayer/file_info_layer.go +++ b/store/searchlayer/file_info_layer.go @@ -22,7 +22,7 @@ func (s SearchFileInfoStore) indexFile(file *model.FileInfo) { if file.PostId == "" { return } - post, postErr := s.rootStore.Post().GetSingle(file.PostId) + post, postErr := s.rootStore.Post().GetSingle(file.PostId, false) if postErr != nil { mlog.Error("Couldn't get post for file for SearchEngine indexing.", mlog.String("post_id", file.PostId), mlog.String("search_engine", engineCopy.GetName()), mlog.String("file_info_id", file.Id), mlog.Err(postErr)) return diff --git a/store/sqlstore/channel_store.go b/store/sqlstore/channel_store.go index 550ba66c24a..c1ba3855151 100644 --- a/store/sqlstore/channel_store.go +++ b/store/sqlstore/channel_store.go @@ -567,14 +567,20 @@ func (s SqlChannelStore) Save(channel *model.Channel, maxChannelsPerTeam int64) return newChannel, err } -func (s SqlChannelStore) CreateDirectChannel(user *model.User, otherUser *model.User) (*model.Channel, error) { +func (s SqlChannelStore) CreateDirectChannel(user *model.User, otherUser *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) { channel := new(model.Channel) + for _, option := range channelOptions { + option(channel) + } + channel.DisplayName = "" channel.Name = model.GetDMNameFromIds(otherUser.Id, user.Id) channel.Header = "" channel.Type = model.CHANNEL_DIRECT + channel.Shared = model.NewBool(user.IsRemote() || otherUser.IsRemote()) + channel.CreatorId = user.Id cm1 := &model.ChannelMember{ UserId: user.Id, @@ -592,13 +598,13 @@ func (s SqlChannelStore) CreateDirectChannel(user *model.User, otherUser *model. return s.SaveDirectChannel(channel, cm1, cm2) } -func (s SqlChannelStore) SaveDirectChannel(directchannel *model.Channel, member1 *model.ChannelMember, member2 *model.ChannelMember) (*model.Channel, error) { - if directchannel.DeleteAt != 0 { - return nil, store.NewErrInvalidInput("Channel", "DeleteAt", directchannel.DeleteAt) +func (s SqlChannelStore) SaveDirectChannel(directChannel *model.Channel, member1 *model.ChannelMember, member2 *model.ChannelMember) (*model.Channel, error) { + if directChannel.DeleteAt != 0 { + return nil, store.NewErrInvalidInput("Channel", "DeleteAt", directChannel.DeleteAt) } - if directchannel.Type != model.CHANNEL_DIRECT { - return nil, store.NewErrInvalidInput("Channel", "Type", directchannel.Type) + if directChannel.Type != model.CHANNEL_DIRECT { + return nil, store.NewErrInvalidInput("Channel", "Type", directChannel.Type) } transaction, err := s.GetMaster().Begin() @@ -607,8 +613,8 @@ func (s SqlChannelStore) SaveDirectChannel(directchannel *model.Channel, member1 } defer finalizeTransaction(transaction) - directchannel.TeamId = "" - newChannel, err := s.saveChannelT(transaction, directchannel, 0) + directChannel.TeamId = "" + newChannel, err := s.saveChannelT(transaction, directChannel, 0) if err != nil { return newChannel, err } @@ -635,7 +641,7 @@ func (s SqlChannelStore) SaveDirectChannel(directchannel *model.Channel, member1 } func (s SqlChannelStore) saveChannelT(transaction *gorp.Transaction, channel *model.Channel, maxChannelsPerTeam int64) (*model.Channel, error) { - if channel.Id != "" { + if channel.Id != "" && !channel.IsShared() { return nil, store.NewErrInvalidInput("Channel", "Id", channel.Id) } @@ -3363,3 +3369,53 @@ func (s SqlChannelStore) GroupSyncedChannelCount() (int64, error) { return count, nil } + +// SetShared sets the Shared flag true/false +func (s SqlChannelStore) SetShared(channelId string, shared bool) error { + squery, args, err := s.getQueryBuilder(). + Update("Channels"). + Set("Shared", shared). + Where(sq.Eq{"Id": channelId}). + ToSql() + if err != nil { + return errors.Wrap(err, "channel_set_shared_tosql") + } + + result, err := s.GetMaster().Exec(squery, args...) + if err != nil { + return errors.Wrap(err, "failed to update `Shared` for Channels") + } + + count, err := result.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to determine rows affected") + } + if count == 0 { + return fmt.Errorf("id not found: %s", channelId) + } + return nil +} + +// GetTeamForChannel returns the team for a given channelID. +func (s SqlChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) { + nestedQ, nestedArgs, err := s.getQueryBuilder().Select("TeamId").From("Channels").Where(sq.Eq{"Id": channelID}).ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_team_for_channel_nested_tosql") + } + query, args, err := s.getQueryBuilder(). + Select("*"). + From("Teams").Where(sq.Expr("Id = ("+nestedQ+")", nestedArgs...)).ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_team_for_channel_tosql") + } + + team := model.Team{} + err = s.GetReplica().SelectOne(&team, query, args...) + if err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("Team", fmt.Sprintf("channel_id=%s", channelID)) + } + return nil, errors.Wrapf(err, "failed to find team with channel_id=%s", channelID) + } + return &team, nil +} diff --git a/store/sqlstore/file_info_store.go b/store/sqlstore/file_info_store.go index 811b3373470..14c66427aaf 100644 --- a/store/sqlstore/file_info_store.go +++ b/store/sqlstore/file_info_store.go @@ -53,6 +53,7 @@ func newSqlFileInfoStore(sqlStore *SqlStore, metrics einterfaces.MetricsInterfac "FileInfo.HasPreviewImage", "FileInfo.MiniPreview", "Coalesce(FileInfo.Content, '') AS Content", + "Coalesce(FileInfo.RemoteId, '') AS RemoteId", } for _, db := range sqlStore.GetAllConns() { @@ -67,6 +68,7 @@ func newSqlFileInfoStore(sqlStore *SqlStore, metrics einterfaces.MetricsInterfac table.ColMap("Content").SetMaxSize(0) table.ColMap("Extension").SetMaxSize(64) table.ColMap("MimeType").SetMaxSize(256) + table.ColMap("RemoteId").SetMaxSize(26) } return s diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index d105f70fd47..0f9855b1308 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -42,7 +42,7 @@ func (s *SqlPostStore) ClearCaches() { } func postSliceColumns() []string { - return []string{"Id", "CreateAt", "UpdateAt", "EditAt", "DeleteAt", "IsPinned", "UserId", "ChannelId", "RootId", "ParentId", "OriginalId", "Message", "Type", "Props", "Hashtags", "Filenames", "FileIds", "HasReactions"} + return []string{"Id", "CreateAt", "UpdateAt", "EditAt", "DeleteAt", "IsPinned", "UserId", "ChannelId", "RootId", "ParentId", "OriginalId", "Message", "Type", "Props", "Hashtags", "Filenames", "FileIds", "HasReactions", "RemoteId"} } func postToSlice(post *model.Post) []interface{} { @@ -65,6 +65,7 @@ func postToSlice(post *model.Post) []interface{} { model.ArrayToJson(post.Filenames), model.ArrayToJson(post.FileIds), post.HasReactions, + post.RemoteId, } } @@ -89,6 +90,7 @@ func newSqlPostStore(sqlStore *SqlStore, metrics einterfaces.MetricsInterface) s table.ColMap("Props").SetMaxSize(8000) table.ColMap("Filenames").SetMaxSize(model.POST_FILENAMES_MAX_RUNES) table.ColMap("FileIds").SetMaxSize(300) + table.ColMap("RemoteId").SetMaxSize(26) } return s @@ -117,7 +119,7 @@ func (s *SqlPostStore) SaveMultiple(posts []*model.Post) ([]*model.Post, int, er rootIds := make(map[string]int) maxDateRootIds := make(map[string]int64) for idx, post := range posts { - if post.Id != "" { + if post.Id != "" && !post.IsRemote() { return nil, idx, store.NewErrInvalidInput("Post", "id", post.Id) } post.PreSave() @@ -211,7 +213,7 @@ func (s *SqlPostStore) SaveMultiple(posts []*model.Post) ([]*model.Post, int, er } } - unknownRepliesPosts := []*model.Post{} + var unknownRepliesPosts []*model.Post for _, post := range posts { if post.RootId == "" { count, ok := rootIds[post.Id] @@ -521,9 +523,23 @@ func (s *SqlPostStore) Get(ctx context.Context, id string, skipFetchThreads, col return pl, nil } -func (s *SqlPostStore) GetSingle(id string) (*model.Post, error) { +func (s *SqlPostStore) GetSingle(id string, inclDeleted bool) (*model.Post, error) { + query := s.getQueryBuilder(). + Select("*"). + From("Posts"). + Where(sq.Eq{"Id": id}) + + if !inclDeleted { + query = query.Where(sq.Eq{"DeleteAt": 0}) + } + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "getsingleincldeleted_tosql") + } + var post model.Post - err := s.GetReplica().SelectOne(&post, "SELECT * FROM Posts WHERE Id = :Id AND DeleteAt = 0", map[string]interface{}{"Id": id}) + err = s.GetReplica().SelectOne(&post, queryString, args...) if err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Post", id) @@ -869,6 +885,11 @@ func (s *SqlPostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFr var posts []*model.Post + order := "DESC" + if options.SortAscending { + order = "ASC" + } + replyCountQuery1 := "" replyCountQuery2 := "" if options.SkipFetchThreads { @@ -905,7 +926,7 @@ func (s *SqlPostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFr AND ChannelId = :ChannelId LIMIT 1000) temp_tab)) ) j ON p1.Id = j.Id - ORDER BY CreateAt DESC` + ORDER BY CreateAt ` + order } else if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { query = `WITH cte AS (SELECT * @@ -917,7 +938,7 @@ func (s *SqlPostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFr (SELECT *` + replyCountQuery2 + ` FROM cte) UNION (SELECT *` + replyCountQuery1 + ` FROM Posts p1 WHERE id in (SELECT rootid FROM cte)) - ORDER BY CreateAt DESC` + ORDER BY CreateAt ` + order } _, err := s.GetReplica().Select(&posts, query, map[string]interface{}{"ChannelId": options.ChannelId, "Time": options.Time}) @@ -937,6 +958,56 @@ func (s *SqlPostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFr return list, nil } +func (s *SqlPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, _ /* allowFromCache */ bool) ([]*model.Post, error) { + if options.Limit < 0 || options.Limit > 1000 { + return nil, store.NewErrInvalidInput("Post", "", options.Limit) + } + + order := " ASC" + if options.SortDescending { + order = " DESC" + } + + query := s.getQueryBuilder(). + Select("*"). + From("Posts"). + Where(sq.GtOrEq{"UpdateAt": options.Since}). + Where(sq.Eq{"ChannelId": options.ChannelId}). + Limit(uint64(options.Limit)). + OrderBy("CreateAt"+order, "DeleteAt", "Id") + + if options.Until > 0 { + query = query.Where(sq.LtOrEq{"UpdateAt": options.Until}) + } + + if !options.IncludeDeleted { + query = query.Where(sq.Eq{"DeleteAt": 0}) + } + + if options.ExcludeRemoteId != "" { + query = query.Where(sq.NotEq{"COALESCE(Posts.RemoteId,'')": options.ExcludeRemoteId}) + } + + if options.Offset > 0 { + query = query.Offset(uint64(options.Offset)) + } + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "getpostssinceforsync_tosql") + } + + var posts []*model.Post + + _, err = s.GetReplica().Select(&posts, queryString, args...) + + if err != nil { + return nil, errors.Wrapf(err, "failed to find Posts with channelId=%s", options.ChannelId) + } + + return posts, nil +} + func (s *SqlPostStore) GetPostsBefore(options model.GetPostsOptions) (*model.PostList, error) { return s.getPostsAround(true, options) } diff --git a/store/sqlstore/reaction_store.go b/store/sqlstore/reaction_store.go index 0acf2979aa9..6ef1ea2db37 100644 --- a/store/sqlstore/reaction_store.go +++ b/store/sqlstore/reaction_store.go @@ -4,6 +4,8 @@ package sqlstore import ( + sq "github.com/Masterminds/squirrel" + "github.com/mattermost/mattermost-server/v5/model" "github.com/mattermost/mattermost-server/v5/shared/mlog" "github.com/mattermost/mattermost-server/v5/store" @@ -24,6 +26,7 @@ func newSqlReactionStore(sqlStore *SqlStore) store.ReactionStore { table.ColMap("UserId").SetMaxSize(26) table.ColMap("PostId").SetMaxSize(26) table.ColMap("EmojiName").SetMaxSize(64) + table.ColMap("RemoteId").SetMaxSize(26) } return s @@ -75,26 +78,56 @@ func (s *SqlReactionStore) Delete(reaction *model.Reaction) (*model.Reaction, er return reaction, nil } +// GetForPost returns all reactions associated with `postId` that are not deleted. func (s *SqlReactionStore) GetForPost(postId string, allowFromCache bool) ([]*model.Reaction, error) { - var reactions []*model.Reaction + queryString, args, err := s.getQueryBuilder(). + Select("UserId", "PostId", "EmojiName", "CreateAt", "COALESCE(UpdateAt, CreateAt) As UpdateAt", + "COALESCE(DeleteAt, 0) As DeleteAt", "RemoteId"). + From("Reactions"). + Where(sq.Eq{"PostId": postId}). + Where(sq.Eq{"COALESCE(DeleteAt, 0)": 0}). + OrderBy("CreateAt"). + ToSql() - if _, err := s.GetReplica().Select(&reactions, - `SELECT - UserId, - PostId, - EmojiName, - CreateAt, - COALESCE(UpdateAt, CreateAt) As UpdateAt, - COALESCE(DeleteAt, 0) As DeleteAt - FROM - Reactions - WHERE - PostId = :PostId AND COALESCE(DeleteAt, 0) = 0 - ORDER BY - CreateAt`, map[string]interface{}{"PostId": postId}); err != nil { - return nil, errors.Wrapf(err, "failed to get Reactions with postId=%s", postId) + if err != nil { + return nil, errors.Wrap(err, "reactions_getforpost_tosql") } + var reactions []*model.Reaction + if _, err := s.GetReplica().Select(&reactions, queryString, args...); err != nil { + return nil, errors.Wrapf(err, "failed to get Reactions with postId=%s", postId) + } + return reactions, nil +} + +// GetForPostSince returns all reactions associated with `postId` updated after `since`. +func (s *SqlReactionStore) GetForPostSince(postId string, since int64, excludeRemoteId string, inclDeleted bool) ([]*model.Reaction, error) { + query := s.getQueryBuilder(). + Select("UserId", "PostId", "EmojiName", "CreateAt", "COALESCE(UpdateAt, CreateAt) As UpdateAt", + "COALESCE(DeleteAt, 0) As DeleteAt", "RemoteId"). + From("Reactions"). + Where(sq.Eq{"PostId": postId}). + Where(sq.Gt{"UpdateAt": since}) + + if excludeRemoteId != "" { + query = query.Where(sq.NotEq{"COALESCE(RemoteId, '')": excludeRemoteId}) + } + + if !inclDeleted { + query = query.Where(sq.Eq{"COALESCE(DeleteAt, 0)": 0}) + } + + query.OrderBy("CreateAt") + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "reactions_getforpostsince_tosql") + } + + var reactions []*model.Reaction + if _, err := s.GetReplica().Select(&reactions, queryString, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find reactions") + } return reactions, nil } @@ -109,7 +142,8 @@ func (s *SqlReactionStore) BulkGetForPosts(postIds []string) ([]*model.Reaction, EmojiName, CreateAt, COALESCE(UpdateAt, CreateAt) As UpdateAt, - COALESCE(DeleteAt, 0) As DeleteAt + COALESCE(DeleteAt, 0) As DeleteAt, + RemoteId FROM Reactions WHERE @@ -133,16 +167,17 @@ func (s *SqlReactionStore) DeleteAllWithEmojiName(emojiName string) error { if _, err := s.GetReplica().Select(&reactions, `SELECT - UserId, - PostId, - EmojiName, - CreateAt, - COALESCE(UpdateAt, CreateAt) As UpdateAt, - COALESCE(DeleteAt, 0) As DeleteAt - FROM - Reactions - WHERE - EmojiName = :EmojiName AND COALESCE(DeleteAt, 0) = 0`, params); err != nil { + UserId, + PostId, + EmojiName, + CreateAt, + COALESCE(UpdateAt, CreateAt) As UpdateAt, + COALESCE(DeleteAt, 0) As DeleteAt, + RemoteId + FROM + Reactions + WHERE + EmojiName = :EmojiName AND COALESCE(DeleteAt, 0) = 0`, params); err != nil { return errors.Wrapf(err, "failed to get Reactions with emojiName=%s", emojiName) } @@ -201,28 +236,29 @@ func (s *SqlReactionStore) saveReactionAndUpdatePost(transaction *gorp.Transacti "EmojiName": reaction.EmojiName, "CreateAt": reaction.CreateAt, "UpdateAt": reaction.UpdateAt, + "RemoteId": reaction.RemoteId, } if s.DriverName() == model.DATABASE_DRIVER_MYSQL { if _, err := transaction.Exec( `INSERT INTO Reactions - (UserId, PostId, EmojiName, CreateAt, UpdateAt, DeleteAt) + (UserId, PostId, EmojiName, CreateAt, UpdateAt, DeleteAt, RemoteId) VALUES - (:UserId, :PostId, :EmojiName, :CreateAt, :UpdateAt, 0) + (:UserId, :PostId, :EmojiName, :CreateAt, :UpdateAt, 0, :RemoteId) ON DUPLICATE KEY UPDATE - UpdateAt = :UpdateAt, DeleteAt = 0`, params); err != nil { + UpdateAt = :UpdateAt, DeleteAt = 0, RemoteId = :RemoteId`, params); err != nil { return err } } else if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { if _, err := transaction.Exec( `INSERT INTO Reactions - (UserId, PostId, EmojiName, CreateAt, UpdateAt, DeleteAt) + (UserId, PostId, EmojiName, CreateAt, UpdateAt, DeleteAt, RemoteId) VALUES - (:UserId, :PostId, :EmojiName, :CreateAt, :UpdateAt, 0) + (:UserId, :PostId, :EmojiName, :CreateAt, :UpdateAt, 0, :RemoteId) ON CONFLICT (UserId, PostId, EmojiName) - DO UPDATE SET UpdateAt = :UpdateAt, DeleteAt = 0`, params); err != nil { + DO UPDATE SET UpdateAt = :UpdateAt, DeleteAt = 0, RemoteId = :RemoteId`, params); err != nil { return err } } @@ -237,13 +273,14 @@ func deleteReactionAndUpdatePost(transaction *gorp.Transaction, reaction *model. "CreateAt": reaction.CreateAt, "UpdateAt": reaction.UpdateAt, "DeleteAt": reaction.UpdateAt, // DeleteAt = UpdateAt + "RemoteId": reaction.RemoteId, } if _, err := transaction.Exec( `UPDATE Reactions SET - UpdateAt = :UpdateAt, DeleteAt = :DeleteAt + UpdateAt = :UpdateAt, DeleteAt = :DeleteAt, RemoteId = :RemoteId WHERE PostId = :PostId AND UserId = :UserId AND diff --git a/store/sqlstore/remote_cluster_store.go b/store/sqlstore/remote_cluster_store.go new file mode 100644 index 00000000000..62bba253afa --- /dev/null +++ b/store/sqlstore/remote_cluster_store.go @@ -0,0 +1,186 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "fmt" + "strings" + + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/store" +) + +type sqlRemoteClusterStore struct { + *SqlStore +} + +func newSqlRemoteClusterStore(sqlStore *SqlStore) store.RemoteClusterStore { + s := &sqlRemoteClusterStore{sqlStore} + + for _, db := range sqlStore.GetAllConns() { + table := db.AddTableWithName(model.RemoteCluster{}, "RemoteClusters").SetKeys(false, "RemoteId") + table.ColMap("RemoteId").SetMaxSize(26) + table.ColMap("RemoteTeamId").SetMaxSize(26) + table.ColMap("DisplayName").SetMaxSize(64) + table.ColMap("SiteURL").SetMaxSize(512) + table.ColMap("Token").SetMaxSize(26) + table.ColMap("RemoteToken").SetMaxSize(26) + table.ColMap("Topics").SetMaxSize(512) + table.ColMap("CreatorId").SetMaxSize(26) + } + return s +} + +func (s sqlRemoteClusterStore) Save(remoteCluster *model.RemoteCluster) (*model.RemoteCluster, error) { + remoteCluster.PreSave() + if err := remoteCluster.IsValid(); err != nil { + return nil, err + } + + if err := s.GetMaster().Insert(remoteCluster); err != nil { + return nil, errors.Wrap(err, "failed to save RemoteCluster") + } + return remoteCluster, nil +} + +func (s sqlRemoteClusterStore) Update(remoteCluster *model.RemoteCluster) (*model.RemoteCluster, error) { + remoteCluster.PreUpdate() + if err := remoteCluster.IsValid(); err != nil { + return nil, err + } + + if _, err := s.GetMaster().Update(remoteCluster); err != nil { + return nil, errors.Wrap(err, "failed to update RemoteCluster") + } + return remoteCluster, nil +} + +func (s sqlRemoteClusterStore) Delete(remoteId string) (bool, error) { + squery, args, err := s.getQueryBuilder(). + Delete("RemoteClusters"). + Where(sq.Eq{"RemoteId": remoteId}). + ToSql() + if err != nil { + return false, errors.Wrap(err, "delete_remote_cluster_tosql") + } + + result, err := s.GetMaster().Exec(squery, args...) + if err != nil { + return false, errors.Wrap(err, "failed to delete RemoteCluster") + } + + count, err := result.RowsAffected() + if err != nil { + return false, errors.Wrap(err, "failed to determine rows affected") + } + + return count > 0, nil +} + +func (s sqlRemoteClusterStore) Get(remoteId string) (*model.RemoteCluster, error) { + query := s.getQueryBuilder(). + Select("*"). + From("RemoteClusters"). + Where(sq.Eq{"RemoteId": remoteId}) + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "remote_cluster_get_tosql") + } + + var rc model.RemoteCluster + if err := s.GetReplica().SelectOne(&rc, queryString, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find RemoteCluster") + } + return &rc, nil +} + +func (s sqlRemoteClusterStore) GetAll(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, error) { + query := s.getQueryBuilder(). + Select("rc.*"). + From("RemoteClusters rc") + + if filter.InChannel != "" { + query = query.Where("rc.RemoteId IN (SELECT scr.RemoteId FROM SharedChannelRemotes scr WHERE scr.ChannelId = ?)", filter.InChannel) + } + + if filter.NotInChannel != "" { + query = query.Where("rc.RemoteId NOT IN (SELECT scr.RemoteId FROM SharedChannelRemotes scr WHERE scr.ChannelId = ?)", filter.NotInChannel) + } + + if filter.ExcludeOffline { + query = query.Where(sq.Gt{"rc.LastPingAt": model.GetMillis() - model.RemoteOfflineAfterMillis}) + } + + if filter.CreatorId != "" { + query = query.Where(sq.Eq{"rc.CreatorId": filter.CreatorId}) + } + + if filter.OnlyConfirmed { + query = query.Where(sq.NotEq{"rc.SiteURL": ""}) + } + + if filter.Topic != "" { + trimmed := strings.TrimSpace(filter.Topic) + if trimmed == "" || trimmed == "*" { + return nil, errors.New("invalid topic") + } + queryTopic := fmt.Sprintf("%% %s %%", trimmed) + query = query.Where(sq.Or{sq.Like{"rc.Topics": queryTopic}, sq.Eq{"rc.Topics": "*"}}) + } + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "remote_cluster_getall_tosql") + } + + var list []*model.RemoteCluster + if _, err := s.GetReplica().Select(&list, queryString, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find RemoteClusters") + } + return list, nil +} + +func (s sqlRemoteClusterStore) UpdateTopics(remoteClusterid string, topics string) (*model.RemoteCluster, error) { + rc, err := s.Get(remoteClusterid) + if err != nil { + return nil, err + } + rc.Topics = topics + + rc.PreUpdate() + + if _, err = s.GetMaster().Update(rc); err != nil { + return nil, err + } + return rc, nil +} + +func (s sqlRemoteClusterStore) SetLastPingAt(remoteClusterId string) error { + query := s.getQueryBuilder(). + Update("RemoteClusters"). + Set("LastPingAt", model.GetMillis()). + Where(sq.Eq{"RemoteId": remoteClusterId}) + + queryString, args, err := query.ToSql() + if err != nil { + return errors.Wrap(err, "remote_cluster_tosql") + } + + if _, err := s.GetMaster().Exec(queryString, args...); err != nil { + return errors.Wrap(err, "failed to update RemoteCluster") + } + return nil +} + +func (s *sqlRemoteClusterStore) createIndexesIfNotExists() { + uniquenessColumns := []string{"SiteUrl", "RemoteTeamId"} + if s.DriverName() == model.DATABASE_DRIVER_MYSQL { + uniquenessColumns = []string{"RemoteTeamId", "SiteUrl(168)"} + } + s.CreateUniqueCompositeIndexIfNotExists(RemoteClusterSiteURLUniqueIndex, "RemoteClusters", uniquenessColumns) +} diff --git a/store/sqlstore/remote_cluster_store_test.go b/store/sqlstore/remote_cluster_store_test.go new file mode 100644 index 00000000000..a6330b74a82 --- /dev/null +++ b/store/sqlstore/remote_cluster_store_test.go @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "testing" + + "github.com/mattermost/mattermost-server/v5/store/storetest" +) + +func TestRemoteClusterStore(t *testing.T) { + StoreTest(t, storetest.TestRemoteClusterStore) +} diff --git a/store/sqlstore/shared_channel_store.go b/store/sqlstore/shared_channel_store.go new file mode 100644 index 00000000000..3d93c80a8f0 --- /dev/null +++ b/store/sqlstore/shared_channel_store.go @@ -0,0 +1,712 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "database/sql" + "fmt" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/store" + + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" +) + +type SqlSharedChannelStore struct { + *SqlStore +} + +func newSqlSharedChannelStore(sqlStore *SqlStore) store.SharedChannelStore { + s := &SqlSharedChannelStore{ + SqlStore: sqlStore, + } + + for _, db := range sqlStore.GetAllConns() { + tableSharedChannels := db.AddTableWithName(model.SharedChannel{}, "SharedChannels").SetKeys(false, "ChannelId") + tableSharedChannels.ColMap("ChannelId").SetMaxSize(26) + tableSharedChannels.ColMap("TeamId").SetMaxSize(26) + tableSharedChannels.ColMap("CreatorId").SetMaxSize(26) + tableSharedChannels.ColMap("ShareName").SetMaxSize(64) + tableSharedChannels.SetUniqueTogether("ShareName", "TeamId") + tableSharedChannels.ColMap("ShareDisplayName").SetMaxSize(64) + tableSharedChannels.ColMap("SharePurpose").SetMaxSize(250) + tableSharedChannels.ColMap("ShareHeader").SetMaxSize(1024) + tableSharedChannels.ColMap("RemoteId").SetMaxSize(26) + + tableSharedChannelRemotes := db.AddTableWithName(model.SharedChannelRemote{}, "SharedChannelRemotes").SetKeys(false, "Id", "ChannelId") + tableSharedChannelRemotes.ColMap("Id").SetMaxSize(26) + tableSharedChannelRemotes.ColMap("ChannelId").SetMaxSize(26) + tableSharedChannelRemotes.ColMap("Description").SetMaxSize(64) + tableSharedChannelRemotes.ColMap("CreatorId").SetMaxSize(26) + tableSharedChannelRemotes.ColMap("RemoteId").SetMaxSize(26) + tableSharedChannelRemotes.SetUniqueTogether("ChannelId", "RemoteId") + + tableSharedChannelUsers := db.AddTableWithName(model.SharedChannelUser{}, "SharedChannelUsers").SetKeys(false, "Id") + tableSharedChannelUsers.ColMap("Id").SetMaxSize(26) + tableSharedChannelUsers.ColMap("UserId").SetMaxSize(26) + tableSharedChannelUsers.ColMap("RemoteId").SetMaxSize(26) + tableSharedChannelUsers.SetUniqueTogether("UserId", "RemoteId") + + tableSharedChannelFiles := db.AddTableWithName(model.SharedChannelAttachment{}, "SharedChannelAttachments").SetKeys(false, "Id") + tableSharedChannelFiles.ColMap("Id").SetMaxSize(26) + tableSharedChannelFiles.ColMap("FileId").SetMaxSize(26) + tableSharedChannelFiles.ColMap("RemoteId").SetMaxSize(26) + tableSharedChannelFiles.SetUniqueTogether("FileId", "RemoteId") + } + + return s +} + +func (s SqlSharedChannelStore) createIndexesIfNotExists() { + s.CreateIndexIfNotExists("idx_sharedchannelusers_user_id", "SharedChannelUsers", "UserId") + s.CreateIndexIfNotExists("idx_sharedchannelusers_remote_id", "SharedChannelUsers", "RemoteId") +} + +// Save inserts a new shared channel record. +func (s SqlSharedChannelStore) Save(sc *model.SharedChannel) (*model.SharedChannel, error) { + sc.PreSave() + if err := sc.IsValid(); err != nil { + return nil, err + } + + // make sure the shared channel is associated with a real channel. + channel, err := s.stores.channel.Get(sc.ChannelId, true) + if err != nil { + return nil, fmt.Errorf("invalid channel: %w", err) + } + + transaction, err := s.GetMaster().Begin() + if err != nil { + return nil, errors.Wrap(err, "begin_transaction") + } + defer finalizeTransaction(transaction) + + if err := transaction.Insert(sc); err != nil { + return nil, errors.Wrapf(err, "save_shared_channel: ChannelId=%s", sc.ChannelId) + } + + // set `Shared` flag in Channels table if needed + if channel.Shared == nil || !*channel.Shared { + if err := s.stores.channel.SetShared(channel.Id, true); err != nil { + return nil, err + } + } + + if err := transaction.Commit(); err != nil { + return nil, errors.Wrap(err, "commit_transaction") + } + return sc, nil +} + +// Get fetches a shared channel by channel_id. +func (s SqlSharedChannelStore) Get(channelId string) (*model.SharedChannel, error) { + var sc model.SharedChannel + + query := s.getQueryBuilder(). + Select("*"). + From("SharedChannels"). + Where(sq.Eq{"SharedChannels.ChannelId": channelId}) + + squery, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrapf(err, "getsharedchannel_tosql") + } + + if err := s.GetReplica().SelectOne(&sc, squery, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("SharedChannel", channelId) + } + return nil, errors.Wrapf(err, "failed to find shared channel with ChannelId=%s", channelId) + } + return &sc, nil +} + +// HasChannel returns whether a given channelID is a shared channel or not. +func (s SqlSharedChannelStore) HasChannel(channelID string) (bool, error) { + builder := s.getQueryBuilder(). + Select("1"). + Prefix("SELECT EXISTS ("). + From("SharedChannels"). + Where(sq.Eq{"SharedChannels.ChannelId": channelID}). + Suffix(")") + + query, args, err := builder.ToSql() + if err != nil { + return false, errors.Wrapf(err, "get_shared_channel_exists_tosql") + } + + var exists bool + if err := s.GetReplica().SelectOne(&exists, query, args...); err != nil { + return exists, errors.Wrapf(err, "failed to get shared channel for channel_id=%s", channelID) + } + return exists, nil +} + +// GetAll fetches a paginated list of shared channels filtered by SharedChannelSearchOpts. +func (s SqlSharedChannelStore) GetAll(offset, limit int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, error) { + if opts.ExcludeHome && opts.ExcludeRemote { + return nil, errors.New("cannot exclude home and remote shared channels") + } + + safeConv := func(offset, limit int) (uint64, uint64, error) { + if offset < 0 { + return 0, 0, errors.New("offset must be positive integer") + } + if limit < 0 { + return 0, 0, errors.New("limit must be positive integer") + } + return uint64(offset), uint64(limit), nil + } + + safeOffset, safeLimit, err := safeConv(offset, limit) + if err != nil { + return nil, err + } + + query := s.getSharedChannelsQuery(opts, false) + query = query.OrderBy("sc.ShareDisplayName, sc.ShareName").Limit(safeLimit).Offset(safeOffset) + + squery, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "failed to create query") + } + + var channels []*model.SharedChannel + _, err = s.GetReplica().Select(&channels, squery, args...) + if err != nil { + return nil, errors.Wrap(err, "failed to get shared channels") + } + return channels, nil +} + +// GetAllCount returns the number of shared channels that would be fetched using SharedChannelSearchOpts. +func (s SqlSharedChannelStore) GetAllCount(opts model.SharedChannelFilterOpts) (int64, error) { + if opts.ExcludeHome && opts.ExcludeRemote { + return 0, errors.New("cannot exclude home and remote shared channels") + } + + query := s.getSharedChannelsQuery(opts, true) + squery, args, err := query.ToSql() + if err != nil { + return 0, errors.Wrap(err, "failed to create query") + } + + count, err := s.GetReplica().SelectInt(squery, args...) + if err != nil { + return 0, errors.Wrap(err, "failed to count channels") + } + return count, nil +} + +func (s SqlSharedChannelStore) getSharedChannelsQuery(opts model.SharedChannelFilterOpts, forCount bool) sq.SelectBuilder { + var selectStr string + if forCount { + selectStr = "count(sc.ChannelId)" + } else { + selectStr = "sc.*" + } + + query := s.getQueryBuilder(). + Select(selectStr). + From("SharedChannels AS sc") + + if opts.TeamId != "" { + query = query.Where(sq.Eq{"sc.TeamId": opts.TeamId}) + } + + if opts.CreatorId != "" { + query = query.Where(sq.Eq{"sc.CreatorId": opts.CreatorId}) + } + + if opts.ExcludeHome { + query = query.Where(sq.NotEq{"sc.Home": true}) + } + + if opts.ExcludeRemote { + query = query.Where(sq.Eq{"sc.Home": true}) + } + + return query +} + +// Update updates the shared channel. +func (s SqlSharedChannelStore) Update(sc *model.SharedChannel) (*model.SharedChannel, error) { + if err := sc.IsValid(); err != nil { + return nil, err + } + + count, err := s.GetMaster().Update(sc) + if err != nil { + return nil, errors.Wrapf(err, "failed to update shared channel with channelId=%s", sc.ChannelId) + } + + if count != 1 { + return nil, fmt.Errorf("expected number of shared channels to be updated is 1 but was %d", count) + } + return sc, nil +} + +// Delete deletes a single shared channel plus associated SharedChannelRemotes. +// Returns true if shared channel found and deleted, false if not found. +func (s SqlSharedChannelStore) Delete(channelId string) (bool, error) { + transaction, err := s.GetMaster().Begin() + if err != nil { + return false, errors.Wrap(err, "DeleteSharedChannel: begin_transaction") + } + defer finalizeTransaction(transaction) + + squery, args, err := s.getQueryBuilder(). + Delete("SharedChannels"). + Where(sq.Eq{"SharedChannels.ChannelId": channelId}). + ToSql() + if err != nil { + return false, errors.Wrap(err, "delete_shared_channel_tosql") + } + + result, err := transaction.Exec(squery, args...) + if err != nil { + return false, errors.Wrap(err, "failed to delete SharedChannel") + } + + // Also remove remotes from SharedChannelRemotes (if any). + squery, args, err = s.getQueryBuilder(). + Delete("SharedChannelRemotes"). + Where(sq.Eq{"ChannelId": channelId}). + ToSql() + if err != nil { + return false, errors.Wrap(err, "delete_shared_channel_remotes_tosql") + } + + _, err = transaction.Exec(squery, args...) + if err != nil { + return false, errors.Wrap(err, "failed to delete SharedChannelRemotes") + } + + count, err := result.RowsAffected() + if err != nil { + return false, errors.Wrap(err, "failed to determine rows affected") + } + + if count > 0 { + // unset the channel's Shared flag + if err = s.Channel().SetShared(channelId, false); err != nil { + return false, errors.Wrap(err, "error unsetting channel share flag") + } + } + + if err = transaction.Commit(); err != nil { + return false, errors.Wrap(err, "commit_transaction") + } + + return count > 0, nil +} + +// SaveRemote inserts a new shared channel remote record. +func (s SqlSharedChannelStore) SaveRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + remote.PreSave() + if err := remote.IsValid(); err != nil { + return nil, err + } + + // make sure the shared channel remote is associated with a real channel. + if _, err := s.stores.channel.Get(remote.ChannelId, true); err != nil { + return nil, fmt.Errorf("invalid channel: %w", err) + } + + if err := s.GetMaster().Insert(remote); err != nil { + return nil, errors.Wrapf(err, "save_shared_channel_remote: channel_id=%s, id=%s", remote.ChannelId, remote.Id) + } + return remote, nil +} + +// Update updates the shared channel remote. +func (s SqlSharedChannelStore) UpdateRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + if err := remote.IsValid(); err != nil { + return nil, err + } + + count, err := s.GetMaster().Update(remote) + if err != nil { + return nil, errors.Wrapf(err, "failed to update shared channel remote with remoteId=%s", remote.Id) + } + + if count != 1 { + return nil, fmt.Errorf("expected number of shared channel remotes to be updated is 1 but was %d", count) + } + return remote, nil +} + +// GetRemote fetches a shared channel remote by id. +func (s SqlSharedChannelStore) GetRemote(id string) (*model.SharedChannelRemote, error) { + var remote model.SharedChannelRemote + + query := s.getQueryBuilder(). + Select("*"). + From("SharedChannelRemotes"). + Where(sq.Eq{"SharedChannelRemotes.Id": id}) + + squery, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrapf(err, "get_shared_channel_remote_tosql") + } + + if err := s.GetReplica().SelectOne(&remote, squery, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("SharedChannelRemote", id) + } + return nil, errors.Wrapf(err, "failed to find shared channel remote with id=%s", id) + } + return &remote, nil +} + +// GetRemoteByIds fetches a shared channel remote by channel id and remote cluster id. +func (s SqlSharedChannelStore) GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) { + var remote model.SharedChannelRemote + + query := s.getQueryBuilder(). + Select("*"). + From("SharedChannelRemotes"). + Where(sq.Eq{"SharedChannelRemotes.ChannelId": channelId}). + Where(sq.Eq{"SharedChannelRemotes.RemoteId": remoteId}) + + squery, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrapf(err, "get_shared_channel_remote_by_ids_tosql") + } + + if err := s.GetReplica().SelectOne(&remote, squery, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("SharedChannelRemote", fmt.Sprintf("channelId=%s, remoteId=%s", channelId, remoteId)) + } + return nil, errors.Wrapf(err, "failed to find shared channel remote with channelId=%s, remoteId=%s", channelId, remoteId) + } + return &remote, nil +} + +// GetRemotes fetches all shared channel remotes associated with channel_id. +func (s SqlSharedChannelStore) GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + var remotes []*model.SharedChannelRemote + + query := s.getQueryBuilder(). + Select("*"). + From("SharedChannelRemotes") + + if opts.ChannelId != "" { + query = query.Where(sq.Eq{"ChannelId": opts.ChannelId}) + } + + if opts.RemoteId != "" { + query = query.Where(sq.Eq{"RemoteId": opts.RemoteId}) + } + + if !opts.InclUnconfirmed { + query = query.Where(sq.Eq{"IsInviteConfirmed": true}) + } + + squery, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrapf(err, "get_shared_channel_remotes_tosql") + } + + if _, err := s.GetReplica().Select(&remotes, squery, args...); err != nil { + if err != sql.ErrNoRows { + return nil, errors.Wrapf(err, "failed to get shared channel remotes for channel_id=%s; remote_id=%s", + opts.ChannelId, opts.RemoteId) + } + } + return remotes, nil +} + +// HasRemote returns whether a given remoteId and channelId are present in the shared channel remotes or not. +func (s SqlSharedChannelStore) HasRemote(channelID string, remoteId string) (bool, error) { + builder := s.getQueryBuilder(). + Select("1"). + Prefix("SELECT EXISTS ("). + From("SharedChannelRemotes"). + Where(sq.Eq{"RemoteId": remoteId}). + Where(sq.Eq{"ChannelId": channelID}). + Suffix(")") + + query, args, err := builder.ToSql() + if err != nil { + return false, errors.Wrapf(err, "get_shared_channel_hasremote_tosql") + } + + var hasRemote bool + if err := s.GetReplica().SelectOne(&hasRemote, query, args...); err != nil { + return hasRemote, errors.Wrapf(err, "failed to get channel remotes for channel_id=%s", channelID) + } + return hasRemote, nil +} + +// GetRemoteForUser returns a remote cluster for the given userId only if the user belongs to at least one channel +// shared with the remote. +func (s SqlSharedChannelStore) GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) { + builder := s.getQueryBuilder(). + Select("rc.*"). + From("RemoteClusters AS rc"). + Join("SharedChannelRemotes AS scr ON rc.RemoteId = scr.RemoteId"). + Join("ChannelMembers AS cm ON scr.ChannelId = cm.ChannelId"). + Where(sq.Eq{"rc.RemoteId": remoteId}). + Where(sq.Eq{"cm.UserId": userId}) + + query, args, err := builder.ToSql() + if err != nil { + return nil, errors.Wrapf(err, "get_remote_for_user_tosql") + } + + var rc model.RemoteCluster + if err := s.GetReplica().SelectOne(&rc, query, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("RemoteCluster", remoteId) + } + return nil, errors.Wrapf(err, "failed to get remote for user_id=%s", userId) + } + return &rc, nil +} + +// UpdateRemoteNextSyncAt updates the NextSyncAt timestamp for the specified SharedChannelRemote. +func (s SqlSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { + squery, args, err := s.getQueryBuilder(). + Update("SharedChannelRemotes"). + Set("NextSyncAt", syncTime). + Where(sq.Eq{"Id": id}). + ToSql() + if err != nil { + return errors.Wrap(err, "update_shared_channel_remote_next_sync_at_tosql") + } + + result, err := s.GetMaster().Exec(squery, args...) + if err != nil { + return errors.Wrap(err, "failed to update NextSyncAt for SharedChannelRemote") + } + + count, err := result.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to determine rows affected") + } + if count == 0 { + return fmt.Errorf("id not found: %s", id) + } + return nil +} + +// DeleteRemote deletes a single shared channel remote. +// Returns true if remote found and deleted, false if not found. +func (s SqlSharedChannelStore) DeleteRemote(id string) (bool, error) { + squery, args, err := s.getQueryBuilder(). + Delete("SharedChannelRemotes"). + Where(sq.Eq{"Id": id}). + ToSql() + if err != nil { + return false, errors.Wrap(err, "delete_shared_channel_remote_tosql") + } + + result, err := s.GetMaster().Exec(squery, args...) + if err != nil { + return false, errors.Wrap(err, "failed to delete SharedChannelRemote") + } + + count, err := result.RowsAffected() + if err != nil { + return false, errors.Wrap(err, "failed to determine rows affected") + } + + return count > 0, nil +} + +// GetRemotesStatus returns the status for each remote invited to the +// specified shared channel. +func (s SqlSharedChannelStore) GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) { + var status []*model.SharedChannelRemoteStatus + + query := s.getQueryBuilder(). + Select("scr.ChannelId, rc.DisplayName, rc.SiteURL, rc.LastPingAt, scr.NextSyncAt, scr.Description, sc.ReadOnly, scr.IsInviteAccepted"). + From("SharedChannelRemotes scr, RemoteClusters rc, SharedChannels sc"). + Where("scr.RemoteId = rc.RemoteId"). + Where("scr.ChannelId = sc.ChannelId"). + Where(sq.Eq{"scr.ChannelId": channelId}) + + squery, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrapf(err, "get_shared_channel_remotes_status_tosql") + } + + if _, err := s.GetReplica().Select(&status, squery, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("SharedChannelRemoteStatus", channelId) + } + return nil, errors.Wrapf(err, "failed to get shared channel remote status for channel_id=%s", channelId) + } + return status, nil +} + +// SaveUser inserts a new shared channel user record to the SharedChannelUsers table. +func (s SqlSharedChannelStore) SaveUser(scUser *model.SharedChannelUser) (*model.SharedChannelUser, error) { + scUser.PreSave() + if err := scUser.IsValid(); err != nil { + return nil, err + } + + if err := s.GetMaster().Insert(scUser); err != nil { + return nil, errors.Wrapf(err, "save_shared_channel_user: user_id=%s, remote_id=%s", scUser.UserId, scUser.RemoteId) + } + return scUser, nil +} + +// GetUser fetches a shared channel user based on user_id and remoteId. +func (s SqlSharedChannelStore) GetUser(userId string, remoteId string) (*model.SharedChannelUser, error) { + var scu model.SharedChannelUser + + squery, args, err := s.getQueryBuilder(). + Select("*"). + From("SharedChannelUsers"). + Where(sq.Eq{"SharedChannelUsers.UserId": userId}). + Where(sq.Eq{"SharedChannelUsers.RemoteId": remoteId}). + ToSql() + + if err != nil { + return nil, errors.Wrapf(err, "getsharedchanneluser_tosql") + } + + if err := s.GetReplica().SelectOne(&scu, squery, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("SharedChannelUser", userId) + } + return nil, errors.Wrapf(err, "failed to find shared channel user with UserId=%s, RemoteId=%s", userId, remoteId) + } + return &scu, nil +} + +// UpdateUserLastSyncAt updates the LastSyncAt timestamp for the specified SharedChannelUser. +func (s SqlSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { + squery, args, err := s.getQueryBuilder(). + Update("SharedChannelUsers"). + Set("LastSyncAt", syncTime). + Where(sq.Eq{"Id": id}). + ToSql() + if err != nil { + return errors.Wrap(err, "update_shared_channel_user_last_sync_at_tosql") + } + + result, err := s.GetMaster().Exec(squery, args...) + if err != nil { + return errors.Wrap(err, "failed to update LastSycnAt for SharedChannelUser") + } + + count, err := result.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to determine rows affected") + } + if count == 0 { + return fmt.Errorf("id not found: %s", id) + } + return nil +} + +// SaveAttachment inserts a new shared channel file attachment record to the SharedChannelFiles table. +func (s SqlSharedChannelStore) SaveAttachment(attachment *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) { + attachment.PreSave() + if err := attachment.IsValid(); err != nil { + return nil, err + } + + if err := s.GetMaster().Insert(attachment); err != nil { + return nil, errors.Wrapf(err, "save_shared_channel_attachment: file_id=%s, remote_id=%s", attachment.FileId, attachment.RemoteId) + } + return attachment, nil +} + +// UpsertAttachment inserts a new shared channel file attachment record to the SharedChannelFiles table or updates its +// LastSyncAt. +func (s SqlSharedChannelStore) UpsertAttachment(attachment *model.SharedChannelAttachment) (string, error) { + attachment.PreSave() + if err := attachment.IsValid(); err != nil { + return "", err + } + + params := map[string]interface{}{ + "Id": attachment.Id, + "FileId": attachment.FileId, + "RemoteId": attachment.RemoteId, + "CreateAt": attachment.CreateAt, + "LastSyncAt": attachment.LastSyncAt, + } + + if s.DriverName() == model.DATABASE_DRIVER_MYSQL { + if _, err := s.GetMaster().Exec( + `INSERT INTO + SharedChannelAttachments + (Id, FileId, RemoteId, CreateAt, LastSyncAt) + VALUES + (:Id, :FileId, :RemoteId, :CreateAt, :LastSyncAt) + ON DUPLICATE KEY UPDATE + LastSyncAt = :LastSyncAt`, params); err != nil { + return "", err + } + } else if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { + if _, err := s.GetMaster().Exec( + `INSERT INTO + SharedChannelAttachments + (Id, FileId, RemoteId, CreateAt, LastSyncAt) + VALUES + (:Id, :FileId, :RemoteId, :CreateAt, :LastSyncAt) + ON CONFLICT (Id) + DO UPDATE SET LastSyncAt = :LastSyncAt`, params); err != nil { + return "", err + } + } + return attachment.Id, nil +} + +// GetAttachment fetches a shared channel file attachment record based on file_id and remoteId. +func (s SqlSharedChannelStore) GetAttachment(fileId string, remoteId string) (*model.SharedChannelAttachment, error) { + var attachment model.SharedChannelAttachment + + squery, args, err := s.getQueryBuilder(). + Select("*"). + From("SharedChannelAttachments"). + Where(sq.Eq{"SharedChannelAttachments.FileId": fileId}). + Where(sq.Eq{"SharedChannelAttachments.RemoteId": remoteId}). + ToSql() + + if err != nil { + return nil, errors.Wrapf(err, "getsharedchannelattachment_tosql") + } + + if err := s.GetReplica().SelectOne(&attachment, squery, args...); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("SharedChannelAttachment", fileId) + } + return nil, errors.Wrapf(err, "failed to find shared channel attachment with FileId=%s, RemoteId=%s", fileId, remoteId) + } + return &attachment, nil +} + +// UpdateAttachmentLastSyncAt updates the LastSyncAt timestamp for the specified SharedChannelAttachment. +func (s SqlSharedChannelStore) UpdateAttachmentLastSyncAt(id string, syncTime int64) error { + squery, args, err := s.getQueryBuilder(). + Update("SharedChannelAttachments"). + Set("LastSyncAt", syncTime). + Where(sq.Eq{"Id": id}). + ToSql() + if err != nil { + return errors.Wrap(err, "update_shared_channel_attachment_last_sync_at_tosql") + } + + result, err := s.GetMaster().Exec(squery, args...) + if err != nil { + return errors.Wrap(err, "failed to update LastSycnAt for SharedChannelAttachment") + } + + count, err := result.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to determine rows affected") + } + if count == 0 { + return fmt.Errorf("id not found: %s", id) + } + return nil +} diff --git a/store/sqlstore/shared_channel_store_test.go b/store/sqlstore/shared_channel_store_test.go new file mode 100644 index 00000000000..3f67e05b92d --- /dev/null +++ b/store/sqlstore/shared_channel_store_test.go @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "testing" + + "github.com/mattermost/mattermost-server/v5/store/storetest" +) + +func TestSharedChannelStore(t *testing.T) { + StoreTestWithSqlStore(t, storetest.TestSharedChannelStore) +} diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index 0340f410a18..2baa6e23eb2 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -101,6 +101,7 @@ type SqlStoreStores struct { bot store.BotStore audit store.AuditStore cluster store.ClusterDiscoveryStore + remoteCluster store.RemoteClusterStore compliance store.ComplianceStore session store.SessionStore oauth store.OAuthStore @@ -127,6 +128,7 @@ type SqlStoreStores struct { group store.GroupStore UserTermsOfService store.UserTermsOfServiceStore linkMetadata store.LinkMetadataStore + sharedchannel store.SharedChannelStore } type SqlStore struct { @@ -186,6 +188,7 @@ func New(settings model.SqlSettings, metrics einterfaces.MetricsInterface) *SqlS store.stores.bot = newSqlBotStore(store, metrics) store.stores.audit = newSqlAuditStore(store) store.stores.cluster = newSqlClusterDiscoveryStore(store) + store.stores.remoteCluster = newSqlRemoteClusterStore(store) store.stores.compliance = newSqlComplianceStore(store) store.stores.session = newSqlSessionStore(store) store.stores.oauth = newSqlOAuthStore(store) @@ -208,6 +211,7 @@ func New(settings model.SqlSettings, metrics einterfaces.MetricsInterface) *SqlS store.stores.TermsOfService = newSqlTermsOfServiceStore(store, metrics) store.stores.UserTermsOfService = newSqlUserTermsOfServiceStore(store) store.stores.linkMetadata = newSqlLinkMetadataStore(store) + store.stores.sharedchannel = newSqlSharedChannelStore(store) store.stores.reaction = newSqlReactionStore(store) store.stores.role = newSqlRoleStore(store) store.stores.scheme = newSqlSchemeStore(store) @@ -258,8 +262,10 @@ func New(settings model.SqlSettings, metrics einterfaces.MetricsInterface) *SqlS store.stores.productNotices.(SqlProductNoticesStore).createIndexesIfNotExists() store.stores.UserTermsOfService.(SqlUserTermsOfServiceStore).createIndexesIfNotExists() store.stores.linkMetadata.(*SqlLinkMetadataStore).createIndexesIfNotExists() + store.stores.sharedchannel.(*SqlSharedChannelStore).createIndexesIfNotExists() store.stores.group.(*SqlGroupStore).createIndexesIfNotExists() store.stores.scheme.(*SqlSchemeStore).createIndexesIfNotExists() + store.stores.remoteCluster.(*sqlRemoteClusterStore).createIndexesIfNotExists() store.stores.preference.(*SqlPreferenceStore).deleteUnusedFeatures() return store @@ -1212,6 +1218,10 @@ func (ss *SqlStore) ClusterDiscovery() store.ClusterDiscoveryStore { return ss.stores.cluster } +func (ss *SqlStore) RemoteCluster() store.RemoteClusterStore { + return ss.stores.remoteCluster +} + func (ss *SqlStore) Compliance() store.ComplianceStore { return ss.stores.compliance } @@ -1316,6 +1326,10 @@ func (ss *SqlStore) LinkMetadata() store.LinkMetadataStore { return ss.stores.linkMetadata } +func (ss *SqlStore) SharedChannel() store.SharedChannelStore { + return ss.stores.sharedchannel +} + func (ss *SqlStore) DropAllTables() { ss.master.TruncateTables() } diff --git a/store/sqlstore/upgrade.go b/store/sqlstore/upgrade.go index 2234168ccc1..ed283cf85dc 100644 --- a/store/sqlstore/upgrade.go +++ b/store/sqlstore/upgrade.go @@ -961,6 +961,8 @@ func upgradeDatabaseToVersion531(sqlStore *SqlStore) { } } +const RemoteClusterSiteURLUniqueIndex = "remote_clusters_site_url_unique" + func hasMissingMigrationsVersion532(sqlStore *SqlStore) bool { scIdInfo, err := sqlStore.GetColumnInfo("Posts", "FileIds") if err != nil { @@ -987,6 +989,7 @@ func upgradeDatabaseToVersion532(sqlStore *SqlStore) { } if shouldPerformUpgrade(sqlStore, Version5310, Version5320) { sqlStore.CreateColumnIfNotExists("ThreadMemberships", "UnreadMentions", "bigint", "bigint", "0") + // Shared channels support sqlStore.CreateColumnIfNotExistsNoDefault("Channels", "Shared", "tinyint(1)", "boolean") sqlStore.CreateColumnIfNotExistsNoDefault("Reactions", "UpdateAt", "bigint", "bigint") sqlStore.CreateColumnIfNotExistsNoDefault("Reactions", "DeleteAt", "bigint", "bigint") @@ -1011,6 +1014,24 @@ func upgradeDatabaseToVersion535(sqlStore *SqlStore) { sqlStore.CreateColumnIfNotExists("SidebarCategories", "Collapsed", "tinyint(1)", "boolean", "0") + // Shared channels support + sqlStore.CreateColumnIfNotExistsNoDefault("Reactions", "RemoteId", "VARCHAR(26)", "VARCHAR(26)") + sqlStore.CreateColumnIfNotExistsNoDefault("Users", "RemoteId", "VARCHAR(26)", "VARCHAR(26)") + sqlStore.CreateColumnIfNotExistsNoDefault("Posts", "RemoteId", "VARCHAR(26)", "VARCHAR(26)") + sqlStore.CreateColumnIfNotExistsNoDefault("FileInfo", "RemoteId", "VARCHAR(26)", "VARCHAR(26)") + sqlStore.CreateColumnIfNotExists("UploadSessions", "RemoteId", "VARCHAR(26)", "VARCHAR(26)", "") + sqlStore.CreateColumnIfNotExists("UploadSessions", "ReqFileId", "VARCHAR(26)", "VARCHAR(26)", "") + if _, err := sqlStore.GetMaster().Exec("UPDATE UploadSessions SET RemoteId='', ReqFileId='' WHERE RemoteId IS NULL"); err != nil { + mlog.Error("Error updating RemoteId,ReqFileId in UploadsSession table", mlog.Err(err)) + } + uniquenessColumns := []string{"SiteUrl", "RemoteTeamId"} + if sqlStore.DriverName() == model.DATABASE_DRIVER_MYSQL { + uniquenessColumns = []string{"RemoteTeamId", "SiteUrl(168)"} + } + sqlStore.CreateUniqueCompositeIndexIfNotExists(RemoteClusterSiteURLUniqueIndex, "RemoteClusters", uniquenessColumns) + + sqlStore.CreateColumnIfNotExistsNoDefault("Channels", "TotalMsgCountRoot", "bigint", "bigint") + // note: setting default 0 on pre-5.0 tables causes test-db-migration script to fail, so this column will be added to ignore list sqlStore.CreateColumnIfNotExists("ChannelMembers", "MentionCountRoot", "bigint", "bigint", "0") sqlStore.AlterColumnDefaultIfExists("ChannelMembers", "MentionCountRoot", model.NewString("0"), model.NewString("0")) diff --git a/store/sqlstore/upload_session_store.go b/store/sqlstore/upload_session_store.go index 12e2c0ebf1d..980de2abbde 100644 --- a/store/sqlstore/upload_session_store.go +++ b/store/sqlstore/upload_session_store.go @@ -29,6 +29,8 @@ func newSqlUploadSessionStore(sqlStore *SqlStore) store.UploadSessionStore { table.ColMap("ChannelId").SetMaxSize(26) table.ColMap("Filename").SetMaxSize(256) table.ColMap("Path").SetMaxSize(512) + table.ColMap("RemoteId").SetMaxSize(26) + table.ColMap("ReqFileId").SetMaxSize(26) } return s } diff --git a/store/sqlstore/user_store.go b/store/sqlstore/user_store.go index 2110ed45c57..c5d5724bd2f 100644 --- a/store/sqlstore/user_store.go +++ b/store/sqlstore/user_store.go @@ -53,7 +53,7 @@ func newSqlUserStore(sqlStore *SqlStore, metrics einterfaces.MetricsInterface) s // note: we are providing field names explicitly here to maintain order of columns (needed when using raw queries) us.usersQuery = us.getQueryBuilder(). Select("u.Id", "u.CreateAt", "u.UpdateAt", "u.DeleteAt", "u.Username", "u.Password", "u.AuthData", "u.AuthService", "u.Email", "u.EmailVerified", "u.Nickname", "u.FirstName", "u.LastName", "u.Position", "u.Roles", "u.AllowMarketing", "u.Props", "u.NotifyProps", "u.LastPasswordUpdate", "u.LastPictureUpdate", "u.FailedAttempts", "u.Locale", "u.Timezone", "u.MfaActive", "u.MfaSecret", - "b.UserId IS NOT NULL AS IsBot", "COALESCE(b.Description, '') AS BotDescription", "COALESCE(b.LastIconUpdate, 0) AS BotLastIconUpdate"). + "b.UserId IS NOT NULL AS IsBot", "COALESCE(b.Description, '') AS BotDescription", "COALESCE(b.LastIconUpdate, 0) AS BotLastIconUpdate", "u.RemoteId"). From("Users u"). LeftJoin("Bots b ON ( b.UserId = u.Id )") @@ -73,6 +73,7 @@ func newSqlUserStore(sqlStore *SqlStore, metrics einterfaces.MetricsInterface) s table.ColMap("NotifyProps").SetMaxSize(2000) table.ColMap("Locale").SetMaxSize(5) table.ColMap("MfaSecret").SetMaxSize(128) + table.ColMap("RemoteId").SetMaxSize(26) table.ColMap("Position").SetMaxSize(128) table.ColMap("Timezone").SetMaxSize(256) } @@ -101,7 +102,7 @@ func (us SqlUserStore) createIndexesIfNotExists() { } func (us SqlUserStore) Save(user *model.User) (*model.User, error) { - if user.Id != "" { + if user.Id != "" && !user.IsRemote() { return nil, store.NewErrInvalidInput("User", "id", user.Id) } @@ -358,7 +359,7 @@ func (us SqlUserStore) Get(ctx context.Context, id string) (*model.User, error) &user.Nickname, &user.FirstName, &user.LastName, &user.Position, &user.Roles, &user.AllowMarketing, &props, ¬ifyProps, &user.LastPasswordUpdate, &user.LastPictureUpdate, &user.FailedAttempts, &user.Locale, &timezone, &user.MfaActive, &user.MfaSecret, - &user.IsBot, &user.BotDescription, &user.BotLastIconUpdate) + &user.IsBot, &user.BotDescription, &user.BotLastIconUpdate, &user.RemoteId) if err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("User", id) @@ -727,7 +728,7 @@ func (us SqlUserStore) GetAllProfilesInChannel(ctx context.Context, channelID st for rows.Next() { var user model.User var props, notifyProps, timezone []byte - if err = rows.Scan(&user.Id, &user.CreateAt, &user.UpdateAt, &user.DeleteAt, &user.Username, &user.Password, &user.AuthData, &user.AuthService, &user.Email, &user.EmailVerified, &user.Nickname, &user.FirstName, &user.LastName, &user.Position, &user.Roles, &user.AllowMarketing, &props, ¬ifyProps, &user.LastPasswordUpdate, &user.LastPictureUpdate, &user.FailedAttempts, &user.Locale, &timezone, &user.MfaActive, &user.MfaSecret, &user.IsBot, &user.BotDescription, &user.BotLastIconUpdate); err != nil { + if err = rows.Scan(&user.Id, &user.CreateAt, &user.UpdateAt, &user.DeleteAt, &user.Username, &user.Password, &user.AuthData, &user.AuthService, &user.Email, &user.EmailVerified, &user.Nickname, &user.FirstName, &user.LastName, &user.Position, &user.Roles, &user.AllowMarketing, &props, ¬ifyProps, &user.LastPasswordUpdate, &user.LastPictureUpdate, &user.FailedAttempts, &user.Locale, &timezone, &user.MfaActive, &user.MfaSecret, &user.IsBot, &user.BotDescription, &user.BotLastIconUpdate, &user.RemoteId); err != nil { return nil, errors.Wrap(err, "failed to scan values from rows into User entity") } if err = json.Unmarshal(props, &user.Props); err != nil { diff --git a/store/store.go b/store/store.go index 56cd322211d..b9042003c40 100644 --- a/store/store.go +++ b/store/store.go @@ -28,6 +28,7 @@ type Store interface { Bot() BotStore Audit() AuditStore ClusterDiscovery() ClusterDiscoveryStore + RemoteCluster() RemoteClusterStore Compliance() ComplianceStore Session() SessionStore OAuth() OAuthStore @@ -54,6 +55,7 @@ type Store interface { Group() GroupStore UserTermsOfService() UserTermsOfServiceStore LinkMetadata() LinkMetadataStore + SharedChannel() SharedChannelStore MarkSystemRanUnitTests() Close() LockToMaster() @@ -135,7 +137,7 @@ type TeamStore interface { type ChannelStore interface { Save(channel *model.Channel, maxChannelsPerTeam int64) (*model.Channel, error) - CreateDirectChannel(userId *model.User, otherUserId *model.User) (*model.Channel, error) + CreateDirectChannel(userId *model.User, otherUserId *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) SaveDirectChannel(channel *model.Channel, member1 *model.ChannelMember, member2 *model.ChannelMember) (*model.Channel, error) Update(channel *model.Channel) (*model.Channel, error) UpdateSidebarChannelCategoryOnMove(channel *model.Channel, newTeamID string) error @@ -240,6 +242,10 @@ type ChannelStore interface { // GroupSyncedChannelCount returns the count of non-deleted group-constrained channels. GroupSyncedChannelCount() (int64, error) + + SetShared(channelId string, shared bool) error + // GetTeamForChannel returns the team for a given channelID. + GetTeamForChannel(channelID string) (*model.Team, error) } type ChannelMemberHistoryStore interface { @@ -276,7 +282,7 @@ type PostStore interface { Save(post *model.Post) (*model.Post, error) Update(newPost *model.Post, oldPost *model.Post) (*model.Post, error) Get(ctx context.Context, id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool, userID string) (*model.PostList, error) - GetSingle(id string) (*model.Post, error) + GetSingle(id string, inclDeleted bool) (*model.Post, error) Delete(postID string, time int64, deleteByID string) error PermanentDeleteByUser(userId string) error PermanentDeleteByChannel(channelID string) error @@ -311,6 +317,7 @@ type PostStore interface { GetDirectPostParentsForExportAfter(limit int, afterID string) ([]*model.DirectPostForExport, error) SearchPostsInTeamForUser(paramsList []*model.SearchParams, userId, teamID string, page, perPage int) (*model.PostSearchResults, error) GetOldestEntityCreationTime() (int64, error) + GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) } type UserStore interface { @@ -427,6 +434,16 @@ type ClusterDiscoveryStore interface { Cleanup() error } +type RemoteClusterStore interface { + Save(rc *model.RemoteCluster) (*model.RemoteCluster, error) + Update(rc *model.RemoteCluster) (*model.RemoteCluster, error) + Delete(remoteClusterId string) (bool, error) + Get(remoteClusterId string) (*model.RemoteCluster, error) + GetAll(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, error) + UpdateTopics(remoteClusterId string, topics string) (*model.RemoteCluster, error) + SetLastPingAt(remoteClusterId string) error +} + type ComplianceStore interface { Save(compliance *model.Compliance) (*model.Compliance, error) Update(compliance *model.Compliance) (*model.Compliance, error) @@ -598,6 +615,7 @@ type ReactionStore interface { Save(reaction *model.Reaction) (*model.Reaction, error) Delete(reaction *model.Reaction) (*model.Reaction, error) GetForPost(postID string, allowFromCache bool) ([]*model.Reaction, error) + GetForPostSince(postId string, since int64, excludeRemoteId string, inclDeleted bool) ([]*model.Reaction, error) DeleteAllWithEmojiName(emojiName string) error PermanentDeleteBatch(endTime int64, limit int64) (int64, error) BulkGetForPosts(postIds []string) ([]*model.Reaction, error) @@ -787,6 +805,36 @@ type LinkMetadataStore interface { Get(url string, timestamp int64) (*model.LinkMetadata, error) } +type SharedChannelStore interface { + Save(sc *model.SharedChannel) (*model.SharedChannel, error) + Get(channelId string) (*model.SharedChannel, error) + HasChannel(channelID string) (bool, error) + GetAll(offset, limit int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, error) + GetAllCount(opts model.SharedChannelFilterOpts) (int64, error) + Update(sc *model.SharedChannel) (*model.SharedChannel, error) + Delete(channelId string) (bool, error) + + SaveRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) + UpdateRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) + GetRemote(id string) (*model.SharedChannelRemote, error) + HasRemote(channelID string, remoteId string) (bool, error) + GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) + GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) + GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) + UpdateRemoteNextSyncAt(id string, syncTime int64) error + DeleteRemote(remoteId string) (bool, error) + GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) + + SaveUser(remote *model.SharedChannelUser) (*model.SharedChannelUser, error) + GetUser(userId string, remoteId string) (*model.SharedChannelUser, error) + UpdateUserLastSyncAt(id string, syncTime int64) error + + SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) + UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) + GetAttachment(fileId string, remoteId string) (*model.SharedChannelAttachment, error) + UpdateAttachmentLastSyncAt(id string, syncTime int64) error +} + // ChannelSearchOpts contains options for searching channels. // // NotAssociatedToGroup will exclude channels that have associated, active GroupChannels records. diff --git a/store/storetest/channel_store.go b/store/storetest/channel_store.go index 3396fca513e..00f78811c78 100644 --- a/store/storetest/channel_store.go +++ b/store/storetest/channel_store.go @@ -111,6 +111,8 @@ func TestChannelStore(t *testing.T, ss store.Store, s SqlStore) { t.Run("UpdateSidebarCategories", func(t *testing.T) { testUpdateSidebarCategories(t, ss) }) t.Run("DeleteSidebarCategory", func(t *testing.T) { testDeleteSidebarCategory(t, ss, s) }) t.Run("UpdateSidebarChannelsByPreferences", func(t *testing.T) { testUpdateSidebarChannelsByPreferences(t, ss) }) + t.Run("SetShared", func(t *testing.T) { testSetShared(t, ss) }) + t.Run("GetTeamForChannel", func(t *testing.T) { testGetTeamForChannel(t, ss) }) } func testChannelStoreSave(t *testing.T, ss store.Store) { @@ -6848,3 +6850,60 @@ func testGroupSyncedChannelCount(t *testing.T, ss store.Store) { require.NoError(t, err) require.GreaterOrEqual(t, countAfter, count+1) } + +func testSetShared(t *testing.T, ss store.Store) { + channel := &model.Channel{ + TeamId: model.NewId(), + DisplayName: "test_share_flag", + Name: "test_share_flag", + Type: model.CHANNEL_OPEN, + } + channelSaved, err := ss.Channel().Save(channel, 999) + require.NoError(t, err) + + t.Run("Check default", func(t *testing.T) { + assert.False(t, channelSaved.IsShared()) + }) + + t.Run("Set Shared flag", func(t *testing.T) { + err := ss.Channel().SetShared(channelSaved.Id, true) + require.NoError(t, err) + + channelMod, err := ss.Channel().Get(channelSaved.Id, false) + require.NoError(t, err) + + assert.True(t, channelMod.IsShared()) + }) + + t.Run("Set Shared for invalid id", func(t *testing.T) { + err := ss.Channel().SetShared(model.NewId(), true) + require.Error(t, err) + }) +} + +func testGetTeamForChannel(t *testing.T, ss store.Store) { + team, err := ss.Team().Save(&model.Team{ + Name: "myteam", + DisplayName: "DisplayName", + Email: MakeEmail(), + Type: model.TEAM_OPEN, + }) + require.NoError(t, err) + + channel := &model.Channel{ + TeamId: team.Id, + DisplayName: "test_share_flag", + Name: "test_share_flag", + Type: model.CHANNEL_OPEN, + } + channelSaved, err := ss.Channel().Save(channel, 999) + require.NoError(t, err) + + got, err := ss.Channel().GetTeamForChannel(channelSaved.Id) + require.NoError(t, err) + assert.Equal(t, team.Id, got.Id) + + _, err = ss.Channel().GetTeamForChannel("notfound") + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) +} diff --git a/store/storetest/mocks/ChannelStore.go b/store/storetest/mocks/ChannelStore.go index a5e95337478..1a49bbd60ac 100644 --- a/store/storetest/mocks/ChannelStore.go +++ b/store/storetest/mocks/ChannelStore.go @@ -167,13 +167,20 @@ func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userI return r0, r1, r2 } -// CreateDirectChannel provides a mock function with given fields: userId, otherUserId -func (_m *ChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User) (*model.Channel, error) { - ret := _m.Called(userId, otherUserId) +// CreateDirectChannel provides a mock function with given fields: userId, otherUserId, channelOptions +func (_m *ChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) { + _va := make([]interface{}, len(channelOptions)) + for _i := range channelOptions { + _va[_i] = channelOptions[_i] + } + var _ca []interface{} + _ca = append(_ca, userId, otherUserId) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 *model.Channel - if rf, ok := ret.Get(0).(func(*model.User, *model.User) *model.Channel); ok { - r0 = rf(userId, otherUserId) + if rf, ok := ret.Get(0).(func(*model.User, *model.User, ...model.ChannelOption) *model.Channel); ok { + r0 = rf(userId, otherUserId, channelOptions...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Channel) @@ -181,8 +188,8 @@ func (_m *ChannelStore) CreateDirectChannel(userId *model.User, otherUserId *mod } var r1 error - if rf, ok := ret.Get(1).(func(*model.User, *model.User) error); ok { - r1 = rf(userId, otherUserId) + if rf, ok := ret.Get(1).(func(*model.User, *model.User, ...model.ChannelOption) error); ok { + r1 = rf(userId, otherUserId, channelOptions...) } else { r1 = ret.Error(1) } @@ -1273,6 +1280,29 @@ func (_m *ChannelStore) GetTeamChannels(teamID string) (*model.ChannelList, erro return r0, r1 } +// GetTeamForChannel provides a mock function with given fields: channelID +func (_m *ChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) { + ret := _m.Called(channelID) + + var r0 *model.Team + if rf, ok := ret.Get(0).(func(string) *model.Team); ok { + r0 = rf(channelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Team) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GroupSyncedChannelCount provides a mock function with given fields: func (_m *ChannelStore) GroupSyncedChannelCount() (int64, error) { ret := _m.Called() @@ -1771,6 +1801,20 @@ func (_m *ChannelStore) SetDeleteAt(channelID string, deleteAt int64, updateAt i return r0 } +// SetShared provides a mock function with given fields: channelId, shared +func (_m *ChannelStore) SetShared(channelId string, shared bool) error { + ret := _m.Called(channelId, shared) + + var r0 error + if rf, ok := ret.Get(0).(func(string, bool) error); ok { + r0 = rf(channelId, shared) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Update provides a mock function with given fields: channel func (_m *ChannelStore) Update(channel *model.Channel) (*model.Channel, error) { ret := _m.Called(channel) diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index 7e168be423e..0c073ef4caa 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -538,6 +538,29 @@ func (_m *PostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFrom return r0, r1 } +// GetPostsSinceForSync provides a mock function with given fields: options, allowFromCache +func (_m *PostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { + ret := _m.Called(options, allowFromCache) + + var r0 []*model.Post + if rf, ok := ret.Get(0).(func(model.GetPostsSinceForSyncOptions, bool) []*model.Post); ok { + r0 = rf(options, allowFromCache) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.Post) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(model.GetPostsSinceForSyncOptions, bool) error); ok { + r1 = rf(options, allowFromCache) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetRepliesForExport provides a mock function with given fields: parentID func (_m *PostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExport, error) { ret := _m.Called(parentID) @@ -561,13 +584,13 @@ func (_m *PostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExpo return r0, r1 } -// GetSingle provides a mock function with given fields: id -func (_m *PostStore) GetSingle(id string) (*model.Post, error) { - ret := _m.Called(id) +// GetSingle provides a mock function with given fields: id, inclDeleted +func (_m *PostStore) GetSingle(id string, inclDeleted bool) (*model.Post, error) { + ret := _m.Called(id, inclDeleted) var r0 *model.Post - if rf, ok := ret.Get(0).(func(string) *model.Post); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(string, bool) *model.Post); ok { + r0 = rf(id, inclDeleted) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Post) @@ -575,8 +598,8 @@ func (_m *PostStore) GetSingle(id string) (*model.Post, error) { } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(string, bool) error); ok { + r1 = rf(id, inclDeleted) } else { r1 = ret.Error(1) } diff --git a/store/storetest/mocks/ReactionStore.go b/store/storetest/mocks/ReactionStore.go index a50c8894896..c9e78c5198f 100644 --- a/store/storetest/mocks/ReactionStore.go +++ b/store/storetest/mocks/ReactionStore.go @@ -97,6 +97,29 @@ func (_m *ReactionStore) GetForPost(postID string, allowFromCache bool) ([]*mode return r0, r1 } +// GetForPostSince provides a mock function with given fields: postId, since, excludeRemoteId, inclDeleted +func (_m *ReactionStore) GetForPostSince(postId string, since int64, excludeRemoteId string, inclDeleted bool) ([]*model.Reaction, error) { + ret := _m.Called(postId, since, excludeRemoteId, inclDeleted) + + var r0 []*model.Reaction + if rf, ok := ret.Get(0).(func(string, int64, string, bool) []*model.Reaction); ok { + r0 = rf(postId, since, excludeRemoteId, inclDeleted) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.Reaction) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, int64, string, bool) error); ok { + r1 = rf(postId, since, excludeRemoteId, inclDeleted) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // PermanentDeleteBatch provides a mock function with given fields: endTime, limit func (_m *ReactionStore) PermanentDeleteBatch(endTime int64, limit int64) (int64, error) { ret := _m.Called(endTime, limit) diff --git a/store/storetest/mocks/RemoteClusterStore.go b/store/storetest/mocks/RemoteClusterStore.go new file mode 100644 index 00000000000..dfa1fc1c349 --- /dev/null +++ b/store/storetest/mocks/RemoteClusterStore.go @@ -0,0 +1,165 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +// Regenerate this file using `make store-mocks`. + +package mocks + +import ( + model "github.com/mattermost/mattermost-server/v5/model" + mock "github.com/stretchr/testify/mock" +) + +// RemoteClusterStore is an autogenerated mock type for the RemoteClusterStore type +type RemoteClusterStore struct { + mock.Mock +} + +// Delete provides a mock function with given fields: remoteClusterId +func (_m *RemoteClusterStore) Delete(remoteClusterId string) (bool, error) { + ret := _m.Called(remoteClusterId) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(remoteClusterId) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(remoteClusterId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: remoteClusterId +func (_m *RemoteClusterStore) Get(remoteClusterId string) (*model.RemoteCluster, error) { + ret := _m.Called(remoteClusterId) + + var r0 *model.RemoteCluster + if rf, ok := ret.Get(0).(func(string) *model.RemoteCluster); ok { + r0 = rf(remoteClusterId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.RemoteCluster) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(remoteClusterId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: filter +func (_m *RemoteClusterStore) GetAll(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, error) { + ret := _m.Called(filter) + + var r0 []*model.RemoteCluster + if rf, ok := ret.Get(0).(func(model.RemoteClusterQueryFilter) []*model.RemoteCluster); ok { + r0 = rf(filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.RemoteCluster) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(model.RemoteClusterQueryFilter) error); ok { + r1 = rf(filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: rc +func (_m *RemoteClusterStore) Save(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + ret := _m.Called(rc) + + var r0 *model.RemoteCluster + if rf, ok := ret.Get(0).(func(*model.RemoteCluster) *model.RemoteCluster); ok { + r0 = rf(rc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.RemoteCluster) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.RemoteCluster) error); ok { + r1 = rf(rc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SetLastPingAt provides a mock function with given fields: remoteClusterId +func (_m *RemoteClusterStore) SetLastPingAt(remoteClusterId string) error { + ret := _m.Called(remoteClusterId) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(remoteClusterId) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Update provides a mock function with given fields: rc +func (_m *RemoteClusterStore) Update(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + ret := _m.Called(rc) + + var r0 *model.RemoteCluster + if rf, ok := ret.Get(0).(func(*model.RemoteCluster) *model.RemoteCluster); ok { + r0 = rf(rc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.RemoteCluster) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.RemoteCluster) error); ok { + r1 = rf(rc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateTopics provides a mock function with given fields: remoteClusterId, topics +func (_m *RemoteClusterStore) UpdateTopics(remoteClusterId string, topics string) (*model.RemoteCluster, error) { + ret := _m.Called(remoteClusterId, topics) + + var r0 *model.RemoteCluster + if rf, ok := ret.Get(0).(func(string, string) *model.RemoteCluster); ok { + r0 = rf(remoteClusterId, topics) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.RemoteCluster) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(remoteClusterId, topics) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/store/storetest/mocks/SharedChannelStore.go b/store/storetest/mocks/SharedChannelStore.go new file mode 100644 index 00000000000..7aa6a4ba421 --- /dev/null +++ b/store/storetest/mocks/SharedChannelStore.go @@ -0,0 +1,528 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +// Regenerate this file using `make store-mocks`. + +package mocks + +import ( + model "github.com/mattermost/mattermost-server/v5/model" + mock "github.com/stretchr/testify/mock" +) + +// SharedChannelStore is an autogenerated mock type for the SharedChannelStore type +type SharedChannelStore struct { + mock.Mock +} + +// Delete provides a mock function with given fields: channelId +func (_m *SharedChannelStore) Delete(channelId string) (bool, error) { + ret := _m.Called(channelId) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(channelId) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteRemote provides a mock function with given fields: remoteId +func (_m *SharedChannelStore) DeleteRemote(remoteId string) (bool, error) { + ret := _m.Called(remoteId) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(remoteId) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(remoteId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: channelId +func (_m *SharedChannelStore) Get(channelId string) (*model.SharedChannel, error) { + ret := _m.Called(channelId) + + var r0 *model.SharedChannel + if rf, ok := ret.Get(0).(func(string) *model.SharedChannel); ok { + r0 = rf(channelId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannel) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: offset, limit, opts +func (_m *SharedChannelStore) GetAll(offset int, limit int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, error) { + ret := _m.Called(offset, limit, opts) + + var r0 []*model.SharedChannel + if rf, ok := ret.Get(0).(func(int, int, model.SharedChannelFilterOpts) []*model.SharedChannel); ok { + r0 = rf(offset, limit, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.SharedChannel) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int, int, model.SharedChannelFilterOpts) error); ok { + r1 = rf(offset, limit, opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAllCount provides a mock function with given fields: opts +func (_m *SharedChannelStore) GetAllCount(opts model.SharedChannelFilterOpts) (int64, error) { + ret := _m.Called(opts) + + var r0 int64 + if rf, ok := ret.Get(0).(func(model.SharedChannelFilterOpts) int64); ok { + r0 = rf(opts) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(model.SharedChannelFilterOpts) error); ok { + r1 = rf(opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAttachment provides a mock function with given fields: fileId, remoteId +func (_m *SharedChannelStore) GetAttachment(fileId string, remoteId string) (*model.SharedChannelAttachment, error) { + ret := _m.Called(fileId, remoteId) + + var r0 *model.SharedChannelAttachment + if rf, ok := ret.Get(0).(func(string, string) *model.SharedChannelAttachment); ok { + r0 = rf(fileId, remoteId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelAttachment) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(fileId, remoteId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRemote provides a mock function with given fields: id +func (_m *SharedChannelStore) GetRemote(id string) (*model.SharedChannelRemote, error) { + ret := _m.Called(id) + + var r0 *model.SharedChannelRemote + if rf, ok := ret.Get(0).(func(string) *model.SharedChannelRemote); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelRemote) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRemoteByIds provides a mock function with given fields: channelId, remoteId +func (_m *SharedChannelStore) GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) { + ret := _m.Called(channelId, remoteId) + + var r0 *model.SharedChannelRemote + if rf, ok := ret.Get(0).(func(string, string) *model.SharedChannelRemote); ok { + r0 = rf(channelId, remoteId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelRemote) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(channelId, remoteId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRemoteForUser provides a mock function with given fields: remoteId, userId +func (_m *SharedChannelStore) GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) { + ret := _m.Called(remoteId, userId) + + var r0 *model.RemoteCluster + if rf, ok := ret.Get(0).(func(string, string) *model.RemoteCluster); ok { + r0 = rf(remoteId, userId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.RemoteCluster) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(remoteId, userId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRemotes provides a mock function with given fields: opts +func (_m *SharedChannelStore) GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + ret := _m.Called(opts) + + var r0 []*model.SharedChannelRemote + if rf, ok := ret.Get(0).(func(model.SharedChannelRemoteFilterOpts) []*model.SharedChannelRemote); ok { + r0 = rf(opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.SharedChannelRemote) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(model.SharedChannelRemoteFilterOpts) error); ok { + r1 = rf(opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRemotesStatus provides a mock function with given fields: channelId +func (_m *SharedChannelStore) GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) { + ret := _m.Called(channelId) + + var r0 []*model.SharedChannelRemoteStatus + if rf, ok := ret.Get(0).(func(string) []*model.SharedChannelRemoteStatus); ok { + r0 = rf(channelId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.SharedChannelRemoteStatus) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetUser provides a mock function with given fields: userId, remoteId +func (_m *SharedChannelStore) GetUser(userId string, remoteId string) (*model.SharedChannelUser, error) { + ret := _m.Called(userId, remoteId) + + var r0 *model.SharedChannelUser + if rf, ok := ret.Get(0).(func(string, string) *model.SharedChannelUser); ok { + r0 = rf(userId, remoteId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelUser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(userId, remoteId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HasChannel provides a mock function with given fields: channelID +func (_m *SharedChannelStore) HasChannel(channelID string) (bool, error) { + ret := _m.Called(channelID) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(channelID) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HasRemote provides a mock function with given fields: channelID, remoteId +func (_m *SharedChannelStore) HasRemote(channelID string, remoteId string) (bool, error) { + ret := _m.Called(channelID, remoteId) + + var r0 bool + if rf, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = rf(channelID, remoteId) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(channelID, remoteId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: sc +func (_m *SharedChannelStore) Save(sc *model.SharedChannel) (*model.SharedChannel, error) { + ret := _m.Called(sc) + + var r0 *model.SharedChannel + if rf, ok := ret.Get(0).(func(*model.SharedChannel) *model.SharedChannel); ok { + r0 = rf(sc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannel) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannel) error); ok { + r1 = rf(sc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SaveAttachment provides a mock function with given fields: remote +func (_m *SharedChannelStore) SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) { + ret := _m.Called(remote) + + var r0 *model.SharedChannelAttachment + if rf, ok := ret.Get(0).(func(*model.SharedChannelAttachment) *model.SharedChannelAttachment); ok { + r0 = rf(remote) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelAttachment) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannelAttachment) error); ok { + r1 = rf(remote) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SaveRemote provides a mock function with given fields: remote +func (_m *SharedChannelStore) SaveRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + ret := _m.Called(remote) + + var r0 *model.SharedChannelRemote + if rf, ok := ret.Get(0).(func(*model.SharedChannelRemote) *model.SharedChannelRemote); ok { + r0 = rf(remote) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelRemote) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannelRemote) error); ok { + r1 = rf(remote) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SaveUser provides a mock function with given fields: remote +func (_m *SharedChannelStore) SaveUser(remote *model.SharedChannelUser) (*model.SharedChannelUser, error) { + ret := _m.Called(remote) + + var r0 *model.SharedChannelUser + if rf, ok := ret.Get(0).(func(*model.SharedChannelUser) *model.SharedChannelUser); ok { + r0 = rf(remote) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelUser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannelUser) error); ok { + r1 = rf(remote) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Update provides a mock function with given fields: sc +func (_m *SharedChannelStore) Update(sc *model.SharedChannel) (*model.SharedChannel, error) { + ret := _m.Called(sc) + + var r0 *model.SharedChannel + if rf, ok := ret.Get(0).(func(*model.SharedChannel) *model.SharedChannel); ok { + r0 = rf(sc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannel) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannel) error); ok { + r1 = rf(sc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateAttachmentLastSyncAt provides a mock function with given fields: id, syncTime +func (_m *SharedChannelStore) UpdateAttachmentLastSyncAt(id string, syncTime int64) error { + ret := _m.Called(id, syncTime) + + var r0 error + if rf, ok := ret.Get(0).(func(string, int64) error); ok { + r0 = rf(id, syncTime) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateRemote provides a mock function with given fields: remote +func (_m *SharedChannelStore) UpdateRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + ret := _m.Called(remote) + + var r0 *model.SharedChannelRemote + if rf, ok := ret.Get(0).(func(*model.SharedChannelRemote) *model.SharedChannelRemote); ok { + r0 = rf(remote) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.SharedChannelRemote) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannelRemote) error); ok { + r1 = rf(remote) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateRemoteNextSyncAt provides a mock function with given fields: id, syncTime +func (_m *SharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { + ret := _m.Called(id, syncTime) + + var r0 error + if rf, ok := ret.Get(0).(func(string, int64) error); ok { + r0 = rf(id, syncTime) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateUserLastSyncAt provides a mock function with given fields: id, syncTime +func (_m *SharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { + ret := _m.Called(id, syncTime) + + var r0 error + if rf, ok := ret.Get(0).(func(string, int64) error); ok { + r0 = rf(id, syncTime) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpsertAttachment provides a mock function with given fields: remote +func (_m *SharedChannelStore) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) { + ret := _m.Called(remote) + + var r0 string + if rf, ok := ret.Get(0).(func(*model.SharedChannelAttachment) string); ok { + r0 = rf(remote) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(*model.SharedChannelAttachment) error); ok { + r1 = rf(remote) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/store/storetest/mocks/Store.go b/store/storetest/mocks/Store.go index b3da958e01e..e38db2b1d58 100644 --- a/store/storetest/mocks/Store.go +++ b/store/storetest/mocks/Store.go @@ -432,6 +432,22 @@ func (_m *Store) RecycleDBConnections(d time.Duration) { _m.Called(d) } +// RemoteCluster provides a mock function with given fields: +func (_m *Store) RemoteCluster() store.RemoteClusterStore { + ret := _m.Called() + + var r0 store.RemoteClusterStore + if rf, ok := ret.Get(0).(func() store.RemoteClusterStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.RemoteClusterStore) + } + } + + return r0 +} + // ReplicaLagAbs provides a mock function with given fields: func (_m *Store) ReplicaLagAbs() error { ret := _m.Called() @@ -513,6 +529,22 @@ func (_m *Store) SetContext(_a0 context.Context) { _m.Called(_a0) } +// SharedChannel provides a mock function with given fields: +func (_m *Store) SharedChannel() store.SharedChannelStore { + ret := _m.Called() + + var r0 store.SharedChannelStore + if rf, ok := ret.Get(0).(func() store.SharedChannelStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.SharedChannelStore) + } + } + + return r0 +} + // Status provides a mock function with given fields: func (_m *Store) Status() store.StatusStore { ret := _m.Called() diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index cbf428aeca2..97316898e04 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -138,7 +138,7 @@ func testPostStoreSave(t *testing.T, ss store.Store) { _, err = ss.Post().Save(&replyPost) require.NoError(t, err) - rrootPost, err := ss.Post().GetSingle(rootPost.Id) + rrootPost, err := ss.Post().GetSingle(rootPost.Id, false) require.NoError(t, err) assert.Greater(t, rrootPost.UpdateAt, rootPost.UpdateAt) }) @@ -226,7 +226,7 @@ func testPostStoreSaveMultiple(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, -1, errIdx) for _, post := range newPosts { - storedPost, err := ss.Post().GetSingle(post.Id) + storedPost, err := ss.Post().GetSingle(post.Id, false) assert.NoError(t, err) assert.Equal(t, post.ChannelId, storedPost.ChannelId) assert.Equal(t, post.Message, storedPost.Message) @@ -273,13 +273,13 @@ func testPostStoreSaveMultiple(t *testing.T, ss store.Store) { require.Error(t, err) require.Equal(t, 1, errIdx) require.Nil(t, newPosts) - storedPost, err := ss.Post().GetSingle(p3.Id) + storedPost, err := ss.Post().GetSingle(p3.Id, false) assert.NoError(t, err) assert.Equal(t, p3.ChannelId, storedPost.ChannelId) assert.Equal(t, p3.Message, storedPost.Message) assert.Equal(t, p3.UserId, storedPost.UserId) - storedPost, err = ss.Post().GetSingle(p4.Id) + storedPost, err = ss.Post().GetSingle(p4.Id, false) assert.Error(t, err) assert.Nil(t, storedPost) }) @@ -299,7 +299,7 @@ func testPostStoreSaveMultiple(t *testing.T, ss store.Store) { _, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost, &replyPost}) require.NoError(t, err) - rrootPost, err := ss.Post().GetSingle(rootPost.Id) + rrootPost, err := ss.Post().GetSingle(rootPost.Id, false) require.NoError(t, err) assert.Equal(t, rrootPost.UpdateAt, rootPost.UpdateAt) @@ -318,7 +318,7 @@ func testPostStoreSaveMultiple(t *testing.T, ss store.Store) { _, _, err = ss.Post().SaveMultiple([]*model.Post{&replyPost2, &replyPost3}) require.NoError(t, err) - rrootPost2, err := ss.Post().GetSingle(rootPost.Id) + rrootPost2, err := ss.Post().GetSingle(rootPost.Id, false) require.NoError(t, err) assert.Greater(t, rrootPost2.UpdateAt, rrootPost.UpdateAt) }) @@ -459,14 +459,33 @@ func testPostStoreGetSingle(t *testing.T, ss store.Store) { o1.UserId = model.NewId() o1.Message = "zz" + model.NewId() + "b" + o2 := &model.Post{} + o2.ChannelId = o1.ChannelId + o2.UserId = o1.UserId + o2.Message = "zz" + model.NewId() + "c" + o1, err := ss.Post().Save(o1) require.NoError(t, err) - post, err := ss.Post().GetSingle(o1.Id) + o2, err = ss.Post().Save(o2) + require.NoError(t, err) + + err = ss.Post().Delete(o2.Id, model.GetMillis(), o2.UserId) + require.NoError(t, err) + + post, err := ss.Post().GetSingle(o1.Id, false) require.NoError(t, err) require.Equal(t, post.CreateAt, o1.CreateAt, "invalid returned post") - _, err = ss.Post().GetSingle("123") + post, err = ss.Post().GetSingle(o2.Id, false) + require.Error(t, err, "should not return deleted post") + + post, err = ss.Post().GetSingle(o2.Id, true) + require.NoError(t, err) + require.Equal(t, post.CreateAt, o2.CreateAt, "invalid returned post") + require.NotZero(t, post.DeleteAt, "DeleteAt should be non-zero") + + _, err = ss.Post().GetSingle("123", false) require.Error(t, err, "Missing id should have failed") } diff --git a/store/storetest/reaction_store.go b/store/storetest/reaction_store.go index 1d1d46f9cf1..f683d2f13b8 100644 --- a/store/storetest/reaction_store.go +++ b/store/storetest/reaction_store.go @@ -5,6 +5,7 @@ package storetest import ( "context" + "errors" "sync" "testing" "time" @@ -21,6 +22,7 @@ func TestReactionStore(t *testing.T, ss store.Store, s SqlStore) { t.Run("ReactionSave", func(t *testing.T) { testReactionSave(t, ss) }) t.Run("ReactionDelete", func(t *testing.T) { testReactionDelete(t, ss) }) t.Run("ReactionGetForPost", func(t *testing.T) { testReactionGetForPost(t, ss) }) + t.Run("ReactionGetForPostSince", func(t *testing.T) { testReactionGetForPostSince(t, ss, s) }) t.Run("ReactionDeleteAllWithEmojiName", func(t *testing.T) { testReactionDeleteAllWithEmojiName(t, ss, s) }) t.Run("PermanentDeleteBatch", func(t *testing.T) { testReactionStorePermanentDeleteBatch(t, ss) }) t.Run("ReactionBulkGetForPosts", func(t *testing.T) { testReactionBulkGetForPosts(t, ss) }) @@ -270,6 +272,160 @@ func testReactionGetForPost(t *testing.T, ss store.Store) { } } +func testReactionGetForPostSince(t *testing.T, ss store.Store, s SqlStore) { + now := model.GetMillis() + later := now + 1800000 // add 30 minutes + remoteId := model.NewId() + + postId := model.NewId() + userId := model.NewId() + reactions := []*model.Reaction{ + { + UserId: userId, + PostId: postId, + EmojiName: "smile", + UpdateAt: later, + }, + { + UserId: model.NewId(), + PostId: postId, + EmojiName: "smile", + }, + { + UserId: userId, + PostId: postId, + EmojiName: "sad", + UpdateAt: later, + RemoteId: &remoteId, + }, + { + UserId: userId, + PostId: model.NewId(), + EmojiName: "angry", + }, + { + UserId: userId, + PostId: postId, + EmojiName: "angry", + DeleteAt: now + 1, + UpdateAt: later, + }, + } + + for _, reaction := range reactions { + delete := reaction.DeleteAt + update := reaction.UpdateAt + + _, err := ss.Reaction().Save(reaction) + require.Nil(t, err) + + if delete > 0 { + _, err = ss.Reaction().Delete(reaction) + require.Nil(t, err) + } + if update > 0 { + err = forceUpdateAt(reaction, update, s) + require.Nil(t, err) + } + err = forceNULL(reaction, s) // test COALESCE + require.Nil(t, err) + } + + t.Run("reactions since", func(t *testing.T) { + // should return 2 reactions that are not deleted for post + returned, err := ss.Reaction().GetForPostSince(postId, later-1, "", false) + require.Nil(t, err) + require.Len(t, returned, 2, "should've returned 2 non-deleted reactions") + for _, r := range returned { + assert.Zero(t, r.DeleteAt, "should not have returned deleted reaction") + } + + }) + + t.Run("reactions since, incl deleted", func(t *testing.T) { + // should return 3 reactions for post, including one deleted + returned, err := ss.Reaction().GetForPostSince(postId, later-1, "", true) + require.Nil(t, err) + require.Len(t, returned, 3, "should've returned 3 reactions") + var count int + for _, r := range returned { + if r.DeleteAt > 0 { + count++ + } + } + assert.Equal(t, 1, count, "should not have returned 1 deleted reaction") + + }) + + t.Run("reactions since, filter remoteId", func(t *testing.T) { + // should return 1 reactions that are not deleted for post and have no remoteId + returned, err := ss.Reaction().GetForPostSince(postId, later-1, remoteId, false) + require.Nil(t, err) + require.Len(t, returned, 1, "should've returned 1 filtered reactions") + for _, r := range returned { + assert.Zero(t, r.DeleteAt, "should not have returned deleted reaction") + } + }) + + t.Run("reactions since, invalid post", func(t *testing.T) { + // should return 0 reactions for invalid post + returned, err := ss.Reaction().GetForPostSince(model.NewId(), later-1, "", true) + require.Nil(t, err) + require.Empty(t, returned, "should've returned 0 reactions") + }) + + t.Run("reactions since, far future", func(t *testing.T) { + // should return 0 reactions for since far in the future + returned, err := ss.Reaction().GetForPostSince(postId, later*2, "", true) + require.Nil(t, err) + require.Empty(t, returned, "should've returned 0 reactions") + }) +} + +func forceUpdateAt(reaction *model.Reaction, updateAt int64, s SqlStore) error { + params := map[string]interface{}{ + "UserId": reaction.UserId, + "PostId": reaction.PostId, + "EmojiName": reaction.EmojiName, + "UpdateAt": updateAt, + } + + sqlResult, err := s.GetMaster().Exec(` + UPDATE + Reactions + SET + UpdateAt=:UpdateAt + WHERE + UserId = :UserId AND + PostId = :PostId AND + EmojiName = :EmojiName`, params, + ) + + if err != nil { + return err + } + + rows, err := sqlResult.RowsAffected() + if err != nil { + return err + } + + if rows != 1 { + return errors.New("expected one row affected") + } + return nil +} + +func forceNULL(reaction *model.Reaction, s SqlStore) error { + if _, err := s.GetMaster().Exec(`UPDATE Reactions SET UpdateAt = NULL WHERE UpdateAt = 0`); err != nil { + return err + } + if _, err := s.GetMaster().Exec(`UPDATE Reactions SET DeleteAt = NULL WHERE DeleteAt = 0`); err != nil { + return err + } + return nil +} + func testReactionDeleteAllWithEmojiName(t *testing.T, ss store.Store, s SqlStore) { emojiToDelete := model.NewId() @@ -321,28 +477,16 @@ func testReactionDeleteAllWithEmojiName(t *testing.T, ss store.Store, s SqlStore for _, reaction := range reactions { _, err := ss.Reaction().Save(reaction) - require.NoError(t, err) + require.Nil(t, err) + + // make at least one Reaction record contain NULL for Update and DeleteAt to simulate post schema upgrade case. + if reaction.EmojiName == emojiToDelete { + err = forceNULL(reaction, s) + require.Nil(t, err) + } } - // make at least one Reaction record contain NULL for Update and DeleteAt to simulate post schema upgrade case. - sqlResult, err := s.GetMaster().Exec(` - UPDATE - Reactions - SET - UpdateAt=NULL, DeleteAt=NULL - WHERE - UserId = :UserId AND PostId = :PostId AND EmojiName = :EmojiName`, - map[string]interface{}{ - "UserId": userId, - "PostId": post.Id, - "EmojiName": emojiToDelete, - }) - require.NoError(t, err) - rowsAffected, err := sqlResult.RowsAffected() - require.NoError(t, err) - require.NotZero(t, rowsAffected) - - err = ss.Reaction().DeleteAllWithEmojiName(emojiToDelete) + err := ss.Reaction().DeleteAllWithEmojiName(emojiToDelete) require.NoError(t, err) // check that the reactions were deleted diff --git a/store/storetest/remote_cluster_store.go b/store/storetest/remote_cluster_store.go new file mode 100644 index 00000000000..405106b2608 --- /dev/null +++ b/store/storetest/remote_cluster_store.go @@ -0,0 +1,521 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package storetest + +import ( + "strings" + "testing" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/store" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRemoteClusterStore(t *testing.T, ss store.Store) { + t.Run("RemoteClusterGetAllInChannel", func(t *testing.T) { testRemoteClusterGetAllInChannel(t, ss) }) + t.Run("RemoteClusterGetAllNotInChannel", func(t *testing.T) { testRemoteClusterGetAllNotInChannel(t, ss) }) + t.Run("RemoteClusterSave", func(t *testing.T) { testRemoteClusterSave(t, ss) }) + t.Run("RemoteClusterDelete", func(t *testing.T) { testRemoteClusterDelete(t, ss) }) + t.Run("RemoteClusterGet", func(t *testing.T) { testRemoteClusterGet(t, ss) }) + t.Run("RemoteClusterGetAll", func(t *testing.T) { testRemoteClusterGetAll(t, ss) }) + t.Run("RemoteClusterGetByTopic", func(t *testing.T) { testRemoteClusterGetByTopic(t, ss) }) + t.Run("RemoteClusterUpdateTopics", func(t *testing.T) { testRemoteClusterUpdateTopics(t, ss) }) +} + +func testRemoteClusterSave(t *testing.T, ss store.Store) { + + t.Run("Save", func(t *testing.T) { + rc := &model.RemoteCluster{ + DisplayName: "some remote", + SiteURL: "somewhere.com", + CreatorId: model.NewId(), + } + + rcSaved, err := ss.RemoteCluster().Save(rc) + require.Nil(t, err) + require.Equal(t, rc.DisplayName, rcSaved.DisplayName) + require.Equal(t, rc.SiteURL, rcSaved.SiteURL) + require.Greater(t, rc.CreateAt, int64(0)) + require.Equal(t, rc.LastPingAt, int64(0)) + }) + + t.Run("Save missing display name", func(t *testing.T) { + rc := &model.RemoteCluster{ + SiteURL: "somewhere.com", + CreatorId: model.NewId(), + } + _, err := ss.RemoteCluster().Save(rc) + require.NotNil(t, err) + }) + + t.Run("Save missing creator id", func(t *testing.T) { + rc := &model.RemoteCluster{ + DisplayName: "some remote", + SiteURL: "somewhere.com", + } + _, err := ss.RemoteCluster().Save(rc) + require.NotNil(t, err) + }) +} + +func testRemoteClusterDelete(t *testing.T, ss store.Store) { + t.Run("Delete", func(t *testing.T) { + rc := &model.RemoteCluster{ + DisplayName: "shortlived remote", + SiteURL: "nowhere.com", + CreatorId: model.NewId(), + } + rcSaved, err := ss.RemoteCluster().Save(rc) + require.Nil(t, err) + + deleted, err := ss.RemoteCluster().Delete(rcSaved.RemoteId) + require.Nil(t, err) + require.True(t, deleted) + }) + + t.Run("Delete nonexistent", func(t *testing.T) { + deleted, err := ss.RemoteCluster().Delete(model.NewId()) + require.Nil(t, err) + require.False(t, deleted) + }) +} + +func testRemoteClusterGet(t *testing.T, ss store.Store) { + t.Run("Get", func(t *testing.T) { + rc := &model.RemoteCluster{ + DisplayName: "shortlived remote", + SiteURL: "nowhere.com", + CreatorId: model.NewId(), + } + rcSaved, err := ss.RemoteCluster().Save(rc) + require.Nil(t, err) + + rcGet, err := ss.RemoteCluster().Get(rcSaved.RemoteId) + require.Nil(t, err) + require.Equal(t, rcSaved.RemoteId, rcGet.RemoteId) + }) + + t.Run("Get not found", func(t *testing.T) { + _, err := ss.RemoteCluster().Get(model.NewId()) + require.NotNil(t, err) + }) +} + +func testRemoteClusterGetAll(t *testing.T, ss store.Store) { + require.NoError(t, clearRemoteClusters(ss)) + + userId := model.NewId() + now := model.GetMillis() + pingLongAgo := model.GetMillis() - (model.RemoteOfflineAfterMillis * 3) + + data := []*model.RemoteCluster{ + {DisplayName: "offline remote", CreatorId: userId, SiteURL: "somewhere.com", LastPingAt: pingLongAgo, Topics: " shared incident "}, + {DisplayName: "some online remote", CreatorId: userId, SiteURL: "nowhere.com", LastPingAt: now, Topics: " shared incident "}, + {DisplayName: "another online remote", CreatorId: model.NewId(), SiteURL: "underwhere.com", LastPingAt: now, Topics: ""}, + {DisplayName: "another offline remote", CreatorId: model.NewId(), SiteURL: "knowhere.com", LastPingAt: pingLongAgo, Topics: " shared "}, + {DisplayName: "brand new offline remote", CreatorId: userId, SiteURL: "", LastPingAt: 0, Topics: " bogus shared stuff "}, + } + + idsAll := make([]string, 0) + idsOnline := make([]string, 0) + idsOffline := make([]string, 0) + idsShareTopic := make([]string, 0) + + for _, item := range data { + online := item.LastPingAt == now + saved, err := ss.RemoteCluster().Save(item) + require.Nil(t, err) + idsAll = append(idsAll, saved.RemoteId) + if online { + idsOnline = append(idsOnline, saved.RemoteId) + } else { + idsOffline = append(idsOffline, saved.RemoteId) + } + if strings.Contains(saved.Topics, " shared ") { + idsShareTopic = append(idsShareTopic, saved.RemoteId) + } + } + + t.Run("GetAll", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{} + remotes, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + // make sure all the test data remotes were returned. + ids := getIds(remotes) + assert.ElementsMatch(t, ids, idsAll) + }) + + t.Run("GetAll online only", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + ExcludeOffline: true, + } + remotes, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + // make sure all the online remotes were returned. + ids := getIds(remotes) + assert.ElementsMatch(t, ids, idsOnline) + }) + + t.Run("GetAll by topic", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + Topic: "shared", + } + remotes, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + // make sure only correct topic returned + ids := getIds(remotes) + assert.ElementsMatch(t, ids, idsShareTopic) + }) + + t.Run("GetAll online by topic", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + ExcludeOffline: true, + Topic: "shared", + } + remotes, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + // make sure only online remotes were returned. + ids := getIds(remotes) + assert.Subset(t, idsOnline, ids) + // make sure correct topic returned + assert.Subset(t, idsShareTopic, ids) + assert.Len(t, ids, 1) + }) + + t.Run("GetAll by Creator", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + CreatorId: userId, + } + remotes, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + // make sure only correct creator returned + assert.Len(t, remotes, 3) + for _, rc := range remotes { + assert.Equal(t, userId, rc.CreatorId) + } + }) + + t.Run("GetAll by Confirmed", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + OnlyConfirmed: true, + } + remotes, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + // make sure only confirmed returned + assert.Len(t, remotes, 4) + for _, rc := range remotes { + assert.NotEmpty(t, rc.SiteURL) + } + }) +} + +func testRemoteClusterGetAllInChannel(t *testing.T, ss store.Store) { + require.NoError(t, clearRemoteClusters(ss)) + now := model.GetMillis() + + userId := model.NewId() + + channel1, err := createTestChannel(ss, "channel_1") + require.Nil(t, err) + + channel2, err := createTestChannel(ss, "channel_2") + require.Nil(t, err) + + channel3, err := createTestChannel(ss, "channel_3") + require.Nil(t, err) + + // Create shared channels + scData := []*model.SharedChannel{ + {ChannelId: channel1.Id, TeamId: model.NewId(), Home: true, ShareName: "test_chan_1", CreatorId: model.NewId()}, + {ChannelId: channel2.Id, TeamId: model.NewId(), Home: true, ShareName: "test_chan_2", CreatorId: model.NewId()}, + {ChannelId: channel3.Id, TeamId: model.NewId(), Home: true, ShareName: "test_chan_3", CreatorId: model.NewId()}, + } + for _, item := range scData { + _, err := ss.SharedChannel().Save(item) + require.Nil(t, err) + } + + // Create some remote clusters + rcData := []*model.RemoteCluster{ + {DisplayName: "AAAA Inc", CreatorId: userId, SiteURL: "aaaa.com", RemoteId: model.NewId(), LastPingAt: now}, + {DisplayName: "BBBB Inc", CreatorId: userId, SiteURL: "bbbb.com", RemoteId: model.NewId(), LastPingAt: 0}, + {DisplayName: "CCCC Inc", CreatorId: userId, SiteURL: "cccc.com", RemoteId: model.NewId(), LastPingAt: now}, + {DisplayName: "DDDD Inc", CreatorId: userId, SiteURL: "dddd.com", RemoteId: model.NewId(), LastPingAt: now}, + {DisplayName: "EEEE Inc", CreatorId: userId, SiteURL: "eeee.com", RemoteId: model.NewId(), LastPingAt: 0}, + } + for _, item := range rcData { + _, err := ss.RemoteCluster().Save(item) + require.Nil(t, err) + } + + // Create some shared channel remotes + scrData := []*model.SharedChannelRemote{ + {ChannelId: channel1.Id, Description: "AAA Inc Share", RemoteId: rcData[0].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel1.Id, Description: "BBB Inc Share", RemoteId: rcData[1].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel2.Id, Description: "CCC Inc Share", RemoteId: rcData[2].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel2.Id, Description: "DDD Inc Share", RemoteId: rcData[3].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel2.Id, Description: "EEE Inc Share", RemoteId: rcData[4].RemoteId, CreatorId: model.NewId()}, + } + for _, item := range scrData { + _, err := ss.SharedChannel().SaveRemote(item) + require.Nil(t, err) + } + + t.Run("Channel 1", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + InChannel: channel1.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 2, "channel 1 should have 2 remote clusters") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[0].RemoteId, rcData[1].RemoteId}, ids) + }) + + t.Run("Channel 1 online only", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + ExcludeOffline: true, + InChannel: channel1.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 1, "channel 1 should have 1 online remote clusters") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[0].RemoteId}, ids) + }) + + t.Run("Channel 2", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + InChannel: channel2.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 3, "channel 2 should have 3 remote clusters") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[2].RemoteId, rcData[3].RemoteId, rcData[4].RemoteId}, ids) + }) + + t.Run("Channel 2 online only", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + ExcludeOffline: true, + InChannel: channel2.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 2, "channel 2 should have 2 online remote clusters") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[2].RemoteId, rcData[3].RemoteId}, ids) + }) + + t.Run("Channel 3", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + InChannel: channel3.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Empty(t, list, "channel 3 should have 0 remote clusters") + }) +} + +func testRemoteClusterGetAllNotInChannel(t *testing.T, ss store.Store) { + require.NoError(t, clearRemoteClusters(ss)) + + userId := model.NewId() + + channel1, err := createTestChannel(ss, "channel_1") + require.Nil(t, err) + + channel2, err := createTestChannel(ss, "channel_2") + require.Nil(t, err) + + channel3, err := createTestChannel(ss, "channel_3") + require.Nil(t, err) + + // Create shared channels + scData := []*model.SharedChannel{ + {ChannelId: channel1.Id, TeamId: model.NewId(), Home: true, ShareName: "test_chan_1", CreatorId: model.NewId()}, + {ChannelId: channel2.Id, TeamId: model.NewId(), Home: true, ShareName: "test_chan_2", CreatorId: model.NewId()}, + {ChannelId: channel3.Id, TeamId: model.NewId(), Home: true, ShareName: "test_chan_3", CreatorId: model.NewId()}, + } + for _, item := range scData { + _, err := ss.SharedChannel().Save(item) + require.Nil(t, err) + } + + // Create some remote clusters + rcData := []*model.RemoteCluster{ + {DisplayName: "AAAA Inc", CreatorId: userId, SiteURL: "aaaa.com", RemoteId: model.NewId()}, + {DisplayName: "BBBB Inc", CreatorId: userId, SiteURL: "bbbb.com", RemoteId: model.NewId()}, + {DisplayName: "CCCC Inc", CreatorId: userId, SiteURL: "cccc.com", RemoteId: model.NewId()}, + {DisplayName: "DDDD Inc", CreatorId: userId, SiteURL: "dddd.com", RemoteId: model.NewId()}, + {DisplayName: "EEEE Inc", CreatorId: userId, SiteURL: "eeee.com", RemoteId: model.NewId()}, + } + for _, item := range rcData { + _, err := ss.RemoteCluster().Save(item) + require.Nil(t, err) + } + + // Create some shared channel remotes + scrData := []*model.SharedChannelRemote{ + {ChannelId: channel1.Id, Description: "AAA Inc Share", RemoteId: rcData[0].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel1.Id, Description: "BBB Inc Share", RemoteId: rcData[1].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel2.Id, Description: "CCC Inc Share", RemoteId: rcData[2].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel2.Id, Description: "DDD Inc Share", RemoteId: rcData[3].RemoteId, CreatorId: model.NewId()}, + {ChannelId: channel3.Id, Description: "EEE Inc Share", RemoteId: rcData[4].RemoteId, CreatorId: model.NewId()}, + } + for _, item := range scrData { + _, err := ss.SharedChannel().SaveRemote(item) + require.Nil(t, err) + } + + t.Run("Channel 1", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + NotInChannel: channel1.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 3, "channel 1 should have 3 remote clusters that are not already members") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[2].RemoteId, rcData[3].RemoteId, rcData[4].RemoteId}, ids) + }) + + t.Run("Channel 2", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + NotInChannel: channel2.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 3, "channel 2 should have 3 remote clusters that are not already members") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[0].RemoteId, rcData[1].RemoteId, rcData[4].RemoteId}, ids) + }) + + t.Run("Channel 3", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + NotInChannel: channel3.Id, + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 4, "channel 3 should have 4 remote clusters that are not already members") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[0].RemoteId, rcData[1].RemoteId, rcData[2].RemoteId, rcData[3].RemoteId}, ids) + }) + + t.Run("Channel with no share remotes", func(t *testing.T) { + filter := model.RemoteClusterQueryFilter{ + NotInChannel: model.NewId(), + } + list, err := ss.RemoteCluster().GetAll(filter) + require.Nil(t, err) + require.Len(t, list, 5, "should have 5 remote clusters that are not already members") + ids := getIds(list) + require.ElementsMatch(t, []string{rcData[0].RemoteId, rcData[1].RemoteId, rcData[2].RemoteId, rcData[3].RemoteId, + rcData[4].RemoteId}, ids) + }) +} + +func getIds(remotes []*model.RemoteCluster) []string { + ids := make([]string, 0, len(remotes)) + for _, r := range remotes { + ids = append(ids, r.RemoteId) + } + return ids +} + +func testRemoteClusterGetByTopic(t *testing.T, ss store.Store) { + require.NoError(t, clearRemoteClusters(ss)) + + rcData := []*model.RemoteCluster{ + {DisplayName: "AAAA Inc", CreatorId: model.NewId(), SiteURL: "aaaa.com", RemoteId: model.NewId(), Topics: ""}, + {DisplayName: "BBBB Inc", CreatorId: model.NewId(), SiteURL: "bbbb.com", RemoteId: model.NewId(), Topics: " share "}, + {DisplayName: "CCCC Inc", CreatorId: model.NewId(), SiteURL: "cccc.com", RemoteId: model.NewId(), Topics: " incident share "}, + {DisplayName: "DDDD Inc", CreatorId: model.NewId(), SiteURL: "dddd.com", RemoteId: model.NewId(), Topics: " bogus "}, + {DisplayName: "EEEE Inc", CreatorId: model.NewId(), SiteURL: "eeee.com", RemoteId: model.NewId(), Topics: " logs share incident "}, + {DisplayName: "FFFF Inc", CreatorId: model.NewId(), SiteURL: "ffff.com", RemoteId: model.NewId(), Topics: " bogus incident "}, + {DisplayName: "GGGG Inc", CreatorId: model.NewId(), SiteURL: "gggg.com", RemoteId: model.NewId(), Topics: "*"}, + } + for _, item := range rcData { + _, err := ss.RemoteCluster().Save(item) + require.Nil(t, err) + } + + testData := []struct { + topic string + expectedCount int + expectError bool + }{ + {topic: "", expectedCount: 7, expectError: false}, + {topic: " ", expectedCount: 0, expectError: true}, + {topic: "share", expectedCount: 4}, + {topic: " share ", expectedCount: 4}, + {topic: "bogus", expectedCount: 3}, + {topic: "non-existent", expectedCount: 1}, + {topic: "*", expectedCount: 0, expectError: true}, // can't query with wildcard + } + + for _, tt := range testData { + filter := model.RemoteClusterQueryFilter{ + Topic: tt.topic, + } + list, err := ss.RemoteCluster().GetAll(filter) + if tt.expectError { + assert.Errorf(t, err, "expected error for topic=%s", tt.topic) + } else { + assert.NoErrorf(t, err, "expected no error for topic=%s", tt.topic) + } + assert.Lenf(t, list, tt.expectedCount, "topic=%s", tt.topic) + } +} + +func testRemoteClusterUpdateTopics(t *testing.T, ss store.Store) { + remoteId := model.NewId() + rc := &model.RemoteCluster{ + DisplayName: "Blap Inc", + SiteURL: "blap.com", + RemoteId: remoteId, + Topics: "", + CreatorId: model.NewId(), + } + + _, err := ss.RemoteCluster().Save(rc) + require.Nil(t, err) + + testData := []struct { + topics string + expected string + }{ + {topics: "", expected: ""}, + {topics: " ", expected: ""}, + {topics: "share", expected: " share "}, + {topics: " share ", expected: " share "}, + {topics: "share incident", expected: " share incident "}, + {topics: " share incident ", expected: " share incident "}, + } + + for _, tt := range testData { + _, err = ss.RemoteCluster().UpdateTopics(remoteId, tt.topics) + require.NoError(t, err) + + rcUpdated, err := ss.RemoteCluster().Get(remoteId) + require.NoError(t, err) + + require.Equal(t, tt.expected, rcUpdated.Topics) + } +} + +func clearRemoteClusters(ss store.Store) error { + list, err := ss.RemoteCluster().GetAll(model.RemoteClusterQueryFilter{}) + if err != nil { + return err + } + + for _, rc := range list { + if _, err := ss.RemoteCluster().Delete(rc.RemoteId); err != nil { + return err + } + } + return nil +} diff --git a/store/storetest/shared_channel_store.go b/store/storetest/shared_channel_store.go new file mode 100644 index 00000000000..c8c8ce022ef --- /dev/null +++ b/store/storetest/shared_channel_store.go @@ -0,0 +1,1077 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package storetest + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/store" +) + +func TestSharedChannelStore(t *testing.T, ss store.Store, s SqlStore) { + t.Run("SaveSharedChannel", func(t *testing.T) { testSaveSharedChannel(t, ss) }) + t.Run("GetSharedChannel", func(t *testing.T) { testGetSharedChannel(t, ss) }) + t.Run("HasSharedChannel", func(t *testing.T) { testHasSharedChannel(t, ss) }) + t.Run("GetSharedChannels", func(t *testing.T) { testGetSharedChannels(t, ss) }) + t.Run("UpdateSharedChannel", func(t *testing.T) { testUpdateSharedChannel(t, ss) }) + t.Run("DeleteSharedChannel", func(t *testing.T) { testDeleteSharedChannel(t, ss) }) + + t.Run("SaveSharedChannelRemote", func(t *testing.T) { testSaveSharedChannelRemote(t, ss) }) + t.Run("UpdateSharedChannelRemote", func(t *testing.T) { testUpdateSharedChannelRemote(t, ss) }) + t.Run("GetSharedChannelRemote", func(t *testing.T) { testGetSharedChannelRemote(t, ss) }) + t.Run("GetSharedChannelRemoteByIds", func(t *testing.T) { testGetSharedChannelRemoteByIds(t, ss) }) + t.Run("GetSharedChannelRemotes", func(t *testing.T) { testGetSharedChannelRemotes(t, ss) }) + t.Run("HasRemote", func(t *testing.T) { testHasRemote(t, ss) }) + t.Run("GetRemoteForUser", func(t *testing.T) { testGetRemoteForUser(t, ss) }) + t.Run("UpdateSharedChannelRemoteNextSyncAt", func(t *testing.T) { testUpdateSharedChannelRemoteNextSyncAt(t, ss) }) + t.Run("DeleteSharedChannelRemote", func(t *testing.T) { testDeleteSharedChannelRemote(t, ss) }) + + t.Run("SaveSharedChannelUser", func(t *testing.T) { testSaveSharedChannelUser(t, ss) }) + t.Run("GetSharedChannelUser", func(t *testing.T) { testGetSharedChannelUser(t, ss) }) + t.Run("UpdateSharedChannelUserLastSyncAt", func(t *testing.T) { testUpdateSharedChannelUserLastSyncAt(t, ss) }) + + t.Run("SaveSharedChannelAttachment", func(t *testing.T) { testSaveSharedChannelAttachment(t, ss) }) + t.Run("UpsertSharedChannelAttachment", func(t *testing.T) { testUpsertSharedChannelAttachment(t, ss) }) + t.Run("GetSharedChannelAttachment", func(t *testing.T) { testGetSharedChannelAttachment(t, ss) }) + t.Run("UpdateSharedChannelAttachmentLastSyncAt", func(t *testing.T) { testUpdateSharedChannelAttachmentLastSyncAt(t, ss) }) +} + +func testSaveSharedChannel(t *testing.T, ss store.Store) { + t.Run("Save shared channel (home)", func(t *testing.T) { + channel, err := createTestChannel(ss, "test_save") + require.Nil(t, err) + + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: model.NewId(), + ShareName: "testshare", + Home: true, + } + + scSaved, err := ss.SharedChannel().Save(sc) + require.NoError(t, err, "couldn't save shared channel") + + require.Equal(t, sc.ChannelId, scSaved.ChannelId) + require.Equal(t, sc.TeamId, scSaved.TeamId) + require.Equal(t, sc.CreatorId, scSaved.CreatorId) + + // ensure channel's Shared flag is set + channelMod, err := ss.Channel().Get(channel.Id, false) + require.NoError(t, err) + require.True(t, channelMod.IsShared()) + }) + + t.Run("Save shared channel (remote)", func(t *testing.T) { + channel, err := createTestChannel(ss, "test_save2") + require.Nil(t, err) + + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: model.NewId(), + ShareName: "testshare", + RemoteId: model.NewId(), + } + + scSaved, err := ss.SharedChannel().Save(sc) + require.Nil(t, err, "couldn't save shared channel", err) + + require.Equal(t, sc.ChannelId, scSaved.ChannelId) + require.Equal(t, sc.TeamId, scSaved.TeamId) + require.Equal(t, sc.CreatorId, scSaved.CreatorId) + + // ensure channel's Shared flag is set + channelMod, err := ss.Channel().Get(channel.Id, false) + require.NoError(t, err) + require.True(t, channelMod.IsShared()) + }) + + t.Run("Save invalid shared channel", func(t *testing.T) { + sc := &model.SharedChannel{ + ChannelId: "", + TeamId: model.NewId(), + CreatorId: model.NewId(), + ShareName: "testshare", + Home: true, + } + + _, err := ss.SharedChannel().Save(sc) + require.NotNil(t, err, "should error saving invalid shared channel", err) + }) + + t.Run("Save with invalid channel id", func(t *testing.T) { + sc := &model.SharedChannel{ + ChannelId: model.NewId(), + TeamId: model.NewId(), + CreatorId: model.NewId(), + ShareName: "testshare", + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().Save(sc) + require.Error(t, err, "expected error for invalid channel id") + }) +} + +func testGetSharedChannel(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_get") + require.Nil(t, err) + + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: model.NewId(), + ShareName: "testshare", + Home: true, + } + + scSaved, err := ss.SharedChannel().Save(sc) + require.Nil(t, err, "couldn't save shared channel", err) + + t.Run("Get existing shared channel", func(t *testing.T) { + sc, err := ss.SharedChannel().Get(scSaved.ChannelId) + require.Nil(t, err, "couldn't get shared channel", err) + + require.Equal(t, sc.ChannelId, scSaved.ChannelId) + require.Equal(t, sc.TeamId, scSaved.TeamId) + require.Equal(t, sc.CreatorId, scSaved.CreatorId) + }) + + t.Run("Get non-existent shared channel", func(t *testing.T) { + sc, err := ss.SharedChannel().Get(model.NewId()) + require.NotNil(t, err) + require.Nil(t, sc) + }) +} + +func testHasSharedChannel(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_get") + require.Nil(t, err) + + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: model.NewId(), + ShareName: "testshare", + Home: true, + } + + scSaved, err := ss.SharedChannel().Save(sc) + require.NoError(t, err, "couldn't save shared channel", err) + + t.Run("Get existing shared channel", func(t *testing.T) { + exists, err := ss.SharedChannel().HasChannel(scSaved.ChannelId) + require.NoError(t, err, "couldn't get shared channel", err) + assert.True(t, exists) + }) + + t.Run("Get non-existent shared channel", func(t *testing.T) { + exists, err := ss.SharedChannel().HasChannel(model.NewId()) + require.NoError(t, err) + assert.False(t, exists) + }) +} + +func testGetSharedChannels(t *testing.T, ss store.Store) { + clearSharedChannels(ss) + + creator := model.NewId() + team1 := model.NewId() + team2 := model.NewId() + rid := model.NewId() + + data := []model.SharedChannel{ + {CreatorId: creator, TeamId: team1, ShareName: "test1", Home: true}, + {CreatorId: creator, TeamId: team1, ShareName: "test2", Home: false, RemoteId: rid}, + {CreatorId: creator, TeamId: team1, ShareName: "test3", Home: false, RemoteId: rid}, + {CreatorId: creator, TeamId: team1, ShareName: "test4", Home: true}, + {CreatorId: creator, TeamId: team2, ShareName: "test5", Home: true}, + {CreatorId: creator, TeamId: team2, ShareName: "test6", Home: false, RemoteId: rid}, + {CreatorId: creator, TeamId: team2, ShareName: "test7", Home: false, RemoteId: rid}, + {CreatorId: creator, TeamId: team2, ShareName: "test8", Home: true}, + {CreatorId: creator, TeamId: team2, ShareName: "test9", Home: true}, + } + + for i, sc := range data { + channel, err := createTestChannel(ss, "test_get2_"+strconv.Itoa(i)) + require.Nil(t, err) + + sc.ChannelId = channel.Id + + _, err = ss.SharedChannel().Save(&sc) + require.Nil(t, err, "error saving shared channel") + } + + t.Run("Get shared channels home only", func(t *testing.T) { + opts := model.SharedChannelFilterOpts{ + ExcludeRemote: true, + CreatorId: creator, + } + + count, err := ss.SharedChannel().GetAllCount(opts) + require.Nil(t, err, "error getting shared channels count") + + home, err := ss.SharedChannel().GetAll(0, 100, opts) + require.Nil(t, err, "error getting shared channels") + + require.Equal(t, int(count), len(home)) + require.Len(t, home, 5, "should be 5 home channels") + for _, sc := range home { + require.True(t, sc.Home, "should be home channel") + } + }) + + t.Run("Get shared channels remote only", func(t *testing.T) { + opts := model.SharedChannelFilterOpts{ + ExcludeHome: true, + } + + count, err := ss.SharedChannel().GetAllCount(opts) + require.Nil(t, err, "error getting shared channels count") + + remotes, err := ss.SharedChannel().GetAll(0, 100, opts) + require.Nil(t, err, "error getting shared channels") + + require.Equal(t, int(count), len(remotes)) + require.Len(t, remotes, 4, "should be 4 remote channels") + for _, sc := range remotes { + require.False(t, sc.Home, "should be remote channel") + } + }) + + t.Run("Get shared channels bad opts", func(t *testing.T) { + opts := model.SharedChannelFilterOpts{ + ExcludeHome: true, + ExcludeRemote: true, + } + _, err := ss.SharedChannel().GetAll(0, 100, opts) + require.NotNil(t, err, "error expected") + }) + + t.Run("Get shared channels by team", func(t *testing.T) { + opts := model.SharedChannelFilterOpts{ + TeamId: team1, + } + + count, err := ss.SharedChannel().GetAllCount(opts) + require.Nil(t, err, "error getting shared channels count") + + remotes, err := ss.SharedChannel().GetAll(0, 100, opts) + require.Nil(t, err, "error getting shared channels") + + require.Equal(t, int(count), len(remotes)) + require.Len(t, remotes, 4, "should be 4 matching channels") + for _, sc := range remotes { + require.Equal(t, team1, sc.TeamId) + } + }) + + t.Run("Get shared channels invalid pagnation", func(t *testing.T) { + opts := model.SharedChannelFilterOpts{ + TeamId: team1, + } + + _, err := ss.SharedChannel().GetAll(-1, 100, opts) + require.NotNil(t, err) + + _, err = ss.SharedChannel().GetAll(0, -100, opts) + require.NotNil(t, err) + }) +} + +func testUpdateSharedChannel(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_update") + require.Nil(t, err) + + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: model.NewId(), + ShareName: "testshare", + Home: true, + } + + scSaved, err := ss.SharedChannel().Save(sc) + require.Nil(t, err, "couldn't save shared channel", err) + + t.Run("Update existing shared channel", func(t *testing.T) { + id := model.NewId() + scMod := scSaved // copy struct (contains basic types only) + scMod.ShareName = "newname" + scMod.ShareDisplayName = "For testing" + scMod.ShareHeader = "This is a header." + scMod.RemoteId = id + + scUpdated, err := ss.SharedChannel().Update(scMod) + require.Nil(t, err, "couldn't update shared channel", err) + + require.Equal(t, "newname", scUpdated.ShareName) + require.Equal(t, "For testing", scUpdated.ShareDisplayName) + require.Equal(t, "This is a header.", scUpdated.ShareHeader) + require.Equal(t, id, scUpdated.RemoteId) + }) + + t.Run("Update non-existent shared channel", func(t *testing.T) { + sc := &model.SharedChannel{ + ChannelId: model.NewId(), + TeamId: model.NewId(), + CreatorId: model.NewId(), + ShareName: "missingshare", + } + _, err := ss.SharedChannel().Update(sc) + require.NotNil(t, err, "should error when updating non-existent shared channel", err) + }) +} + +func testDeleteSharedChannel(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_delete") + require.Nil(t, err) + + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: model.NewId(), + ShareName: "testshare", + RemoteId: model.NewId(), + } + + _, err = ss.SharedChannel().Save(sc) + require.Nil(t, err, "couldn't save shared channel", err) + + // add some remotes + for i := 0; i < 10; i++ { + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "remote_" + strconv.Itoa(i), + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + _, err := ss.SharedChannel().SaveRemote(remote) + require.Nil(t, err, "couldn't add remote", err) + } + + t.Run("Delete existing shared channel", func(t *testing.T) { + deleted, err := ss.SharedChannel().Delete(channel.Id) + require.Nil(t, err, "delete existing shared channel should not error", err) + require.True(t, deleted, "expected true from delete shared channel") + + sc, err := ss.SharedChannel().Get(channel.Id) + require.NotNil(t, err) + require.Nil(t, sc) + + // make sure the remotes were deleted. + remotes, err := ss.SharedChannel().GetRemotes(model.SharedChannelRemoteFilterOpts{ChannelId: channel.Id}) + require.Nil(t, err) + require.Len(t, remotes, 0, "expected empty remotes list") + + // ensure channel's Shared flag is unset + channelMod, err := ss.Channel().Get(channel.Id, false) + require.NoError(t, err) + require.False(t, channelMod.IsShared()) + }) + + t.Run("Delete non-existent shared channel", func(t *testing.T) { + deleted, err := ss.SharedChannel().Delete(model.NewId()) + require.Nil(t, err, "delete non-existent shared channel should not error", err) + require.False(t, deleted, "expected false from delete shared channel") + }) +} + +func testSaveSharedChannelRemote(t *testing.T, ss store.Store) { + t.Run("Save shared channel remote", func(t *testing.T) { + channel, err := createTestChannel(ss, "test_save_remote") + require.Nil(t, err) + + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + remoteSaved, err := ss.SharedChannel().SaveRemote(remote) + require.Nil(t, err, "couldn't save shared channel remote", err) + + require.Equal(t, remote.ChannelId, remoteSaved.ChannelId) + require.Equal(t, remote.CreatorId, remoteSaved.CreatorId) + }) + + t.Run("Save invalid shared channel remote", func(t *testing.T) { + remote := &model.SharedChannelRemote{ + ChannelId: "", + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().SaveRemote(remote) + require.NotNil(t, err, "should error saving invalid remote", err) + }) + + t.Run("Save shared channel remote with invalid channel id", func(t *testing.T) { + remote := &model.SharedChannelRemote{ + ChannelId: model.NewId(), + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().SaveRemote(remote) + require.Error(t, err, "expected error for invalid channel id") + }) +} + +func testUpdateSharedChannelRemote(t *testing.T, ss store.Store) { + t.Run("Update shared channel remote", func(t *testing.T) { + channel, err := createTestChannel(ss, "test_update_remote") + require.Nil(t, err) + + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "test_remote_update", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + remoteSaved, err := ss.SharedChannel().SaveRemote(remote) + require.Nil(t, err, "couldn't save shared channel remote", err) + + remoteSaved.IsInviteAccepted = true + remoteSaved.IsInviteConfirmed = true + remoteSaved.Description = "new_desc" + + remoteUpdated, err := ss.SharedChannel().UpdateRemote(remoteSaved) + require.Nil(t, err, "couldn't update shared channel remote", err) + + require.Equal(t, true, remoteUpdated.IsInviteAccepted) + require.Equal(t, true, remoteUpdated.IsInviteConfirmed) + require.Equal(t, "new_desc", remoteUpdated.Description) + }) + + t.Run("Update invalid shared channel remote", func(t *testing.T) { + remote := &model.SharedChannelRemote{ + ChannelId: "", + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().UpdateRemote(remote) + require.NotNil(t, err, "should error updating invalid remote", err) + }) + + t.Run("Update shared channel remote with invalid channel id", func(t *testing.T) { + remote := &model.SharedChannelRemote{ + ChannelId: model.NewId(), + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().UpdateRemote(remote) + require.Error(t, err, "expected error for invalid channel id") + }) +} + +func testGetSharedChannelRemote(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_remote_get") + require.Nil(t, err) + + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + remoteSaved, err := ss.SharedChannel().SaveRemote(remote) + require.Nil(t, err, "couldn't save remote", err) + + t.Run("Get existing shared channel remote", func(t *testing.T) { + r, err := ss.SharedChannel().GetRemote(remoteSaved.Id) + require.Nil(t, err, "could not get shared channel remote", err) + + require.Equal(t, remoteSaved.Id, r.Id) + require.Equal(t, remoteSaved.ChannelId, r.ChannelId) + require.Equal(t, remoteSaved.Description, r.Description) + require.Equal(t, remoteSaved.CreatorId, r.CreatorId) + require.Equal(t, remoteSaved.RemoteId, r.RemoteId) + }) + + t.Run("Get non-existent shared channel remote", func(t *testing.T) { + r, err := ss.SharedChannel().GetRemote(model.NewId()) + require.NotNil(t, err) + require.Nil(t, r) + }) +} + +func testGetSharedChannelRemoteByIds(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_remote_get_by_ids") + require.Nil(t, err) + + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "test_remote_by_ids", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + remoteSaved, err := ss.SharedChannel().SaveRemote(remote) + require.Nil(t, err, "could not save remote", err) + + t.Run("Get existing shared channel remote by ids", func(t *testing.T) { + r, err := ss.SharedChannel().GetRemoteByIds(remoteSaved.ChannelId, remoteSaved.RemoteId) + require.Nil(t, err, "couldn't get shared channel remote by ids", err) + + require.Equal(t, remoteSaved.Id, r.Id) + require.Equal(t, remoteSaved.ChannelId, r.ChannelId) + require.Equal(t, remoteSaved.Description, r.Description) + require.Equal(t, remoteSaved.CreatorId, r.CreatorId) + require.Equal(t, remoteSaved.RemoteId, r.RemoteId) + }) + + t.Run("Get non-existent shared channel remote by ids", func(t *testing.T) { + r, err := ss.SharedChannel().GetRemoteByIds(model.NewId(), model.NewId()) + require.NotNil(t, err) + require.Nil(t, r) + }) +} + +func testGetSharedChannelRemotes(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_remotes_get2") + require.Nil(t, err) + + creator := model.NewId() + remoteId := model.NewId() + + data := []model.SharedChannelRemote{ + {ChannelId: channel.Id, CreatorId: creator, Description: "r1", RemoteId: model.NewId(), IsInviteConfirmed: true}, + {ChannelId: channel.Id, CreatorId: creator, Description: "r2", RemoteId: model.NewId(), IsInviteConfirmed: true}, + {ChannelId: channel.Id, CreatorId: creator, Description: "r3", RemoteId: model.NewId(), IsInviteConfirmed: true}, + {CreatorId: creator, Description: "r4", RemoteId: remoteId, IsInviteConfirmed: true}, + {CreatorId: creator, Description: "r5", RemoteId: remoteId, IsInviteConfirmed: true}, + {CreatorId: creator, Description: "r6", RemoteId: remoteId}, + } + + for i, r := range data { + if r.ChannelId == "" { + c, err := createTestChannel(ss, "test_remotes_get2_"+strconv.Itoa(i)) + require.Nil(t, err) + r.ChannelId = c.Id + } + _, err := ss.SharedChannel().SaveRemote(&r) + require.Nil(t, err, "error saving shared channel remote") + } + + t.Run("Get shared channel remotes by channel_id", func(t *testing.T) { + opts := model.SharedChannelRemoteFilterOpts{ + ChannelId: channel.Id, + } + remotes, err := ss.SharedChannel().GetRemotes(opts) + require.Nil(t, err, "should not error", err) + require.Len(t, remotes, 3) + for _, r := range remotes { + require.Contains(t, []string{"r1", "r2", "r3"}, r.Description) + } + }) + + t.Run("Get shared channel remotes by invalid channel_id", func(t *testing.T) { + opts := model.SharedChannelRemoteFilterOpts{ + ChannelId: model.NewId(), + } + remotes, err := ss.SharedChannel().GetRemotes(opts) + require.Nil(t, err, "should not error", err) + require.Len(t, remotes, 0) + }) + + t.Run("Get shared channel remotes by remote_id", func(t *testing.T) { + opts := model.SharedChannelRemoteFilterOpts{ + RemoteId: remoteId, + } + remotes, err := ss.SharedChannel().GetRemotes(opts) + require.Nil(t, err, "should not error", err) + require.Len(t, remotes, 2) // only confirmed invitations + for _, r := range remotes { + require.Contains(t, []string{"r4", "r5"}, r.Description) + } + }) + + t.Run("Get shared channel remotes by invalid remote_id", func(t *testing.T) { + opts := model.SharedChannelRemoteFilterOpts{ + RemoteId: model.NewId(), + } + remotes, err := ss.SharedChannel().GetRemotes(opts) + require.Nil(t, err, "should not error", err) + require.Len(t, remotes, 0) + }) + + t.Run("Get shared channel remotes by remote_id including unconfirmed", func(t *testing.T) { + opts := model.SharedChannelRemoteFilterOpts{ + RemoteId: remoteId, + InclUnconfirmed: true, + } + remotes, err := ss.SharedChannel().GetRemotes(opts) + require.Nil(t, err, "should not error", err) + require.Len(t, remotes, 3) // only confirmed invitations + for _, r := range remotes { + require.Contains(t, []string{"r4", "r5", "r6"}, r.Description) + } + }) +} + +func testHasRemote(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_remotes_get2") + require.Nil(t, err) + + remote1 := model.NewId() + remote2 := model.NewId() + + creator := model.NewId() + data := []model.SharedChannelRemote{ + {ChannelId: channel.Id, CreatorId: creator, Description: "r1", RemoteId: remote1}, + {ChannelId: channel.Id, CreatorId: creator, Description: "r2", RemoteId: remote2}, + } + + for _, r := range data { + _, err := ss.SharedChannel().SaveRemote(&r) + require.Nil(t, err, "error saving shared channel remote") + } + + t.Run("has remote", func(t *testing.T) { + has, err := ss.SharedChannel().HasRemote(channel.Id, remote1) + require.NoError(t, err) + assert.True(t, has) + + has, err = ss.SharedChannel().HasRemote(channel.Id, remote2) + require.NoError(t, err) + assert.True(t, has) + }) + + t.Run("wrong channel id ", func(t *testing.T) { + has, err := ss.SharedChannel().HasRemote(model.NewId(), remote1) + require.NoError(t, err) + assert.False(t, has) + }) + + t.Run("wrong remote id", func(t *testing.T) { + has, err := ss.SharedChannel().HasRemote(channel.Id, model.NewId()) + require.NoError(t, err) + assert.False(t, has) + }) +} + +func testGetRemoteForUser(t *testing.T, ss store.Store) { + // add remotes, and users to simulated shared channels. + teamId := model.NewId() + channel, err := createSharedTestChannel(ss, "share_test_channel", true) + require.NoError(t, err) + remotes := []*model.RemoteCluster{ + {RemoteId: model.NewId(), SiteURL: model.NewId(), CreatorId: model.NewId(), RemoteTeamId: teamId, DisplayName: "Test Remote 1"}, + {RemoteId: model.NewId(), SiteURL: model.NewId(), CreatorId: model.NewId(), RemoteTeamId: teamId, DisplayName: "Test Remote 2"}, + {RemoteId: model.NewId(), SiteURL: model.NewId(), CreatorId: model.NewId(), RemoteTeamId: teamId, DisplayName: "Test Remote 3"}, + } + var channelRemotes []*model.SharedChannelRemote + for _, rc := range remotes { + _, err := ss.RemoteCluster().Save(rc) + require.NoError(t, err) + + scr := &model.SharedChannelRemote{Id: model.NewId(), CreatorId: rc.CreatorId, ChannelId: channel.Id, RemoteId: rc.RemoteId} + scr, err = ss.SharedChannel().SaveRemote(scr) + require.NoError(t, err) + channelRemotes = append(channelRemotes, scr) + } + users := []string{model.NewId(), model.NewId(), model.NewId()} + for _, id := range users { + member := &model.ChannelMember{ + ChannelId: channel.Id, + UserId: id, + NotifyProps: model.GetDefaultChannelNotifyProps(), + SchemeGuest: false, + SchemeUser: true, + } + _, err := ss.Channel().SaveMember(member) + require.NoError(t, err) + } + + t.Run("user is member", func(t *testing.T) { + for _, rc := range remotes { + for _, userId := range users { + rcFound, err := ss.SharedChannel().GetRemoteForUser(rc.RemoteId, userId) + assert.NoError(t, err, "remote should be found for user") + assert.Equal(t, rc.RemoteId, rcFound.RemoteId, "remoteIds should match") + } + } + }) + + t.Run("user is not a member", func(t *testing.T) { + for _, rc := range remotes { + rcFound, err := ss.SharedChannel().GetRemoteForUser(rc.RemoteId, model.NewId()) + assert.Error(t, err, "remote should not be found for user") + assert.Nil(t, rcFound) + } + }) + + t.Run("unknown remote id", func(t *testing.T) { + rcFound, err := ss.SharedChannel().GetRemoteForUser(model.NewId(), users[0]) + assert.Error(t, err, "remote should not be found for unknown remote id") + assert.Nil(t, rcFound) + }) +} + +func testUpdateSharedChannelRemoteNextSyncAt(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_remote_update_next_sync_at") + require.NoError(t, err) + + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + remoteSaved, err := ss.SharedChannel().SaveRemote(remote) + require.NoError(t, err, "couldn't save remote", err) + + future := model.GetMillis() + 3600000 // 1 hour in the future + + t.Run("Update NextSyncAt for remote", func(t *testing.T) { + err := ss.SharedChannel().UpdateRemoteNextSyncAt(remoteSaved.Id, future) + require.Nil(t, err, "update NextSyncAt should not error", err) + + r, err := ss.SharedChannel().GetRemote(remoteSaved.Id) + require.NoError(t, err) + require.Equal(t, future, r.NextSyncAt) + }) + + t.Run("Update NextSyncAt for non-existent shared channel remote", func(t *testing.T) { + err := ss.SharedChannel().UpdateRemoteNextSyncAt(model.NewId(), future) + require.Error(t, err, "update non-existent remote should error", err) + }) +} + +func testDeleteSharedChannelRemote(t *testing.T, ss store.Store) { + channel, err := createTestChannel(ss, "test_remote_delete") + require.NoError(t, err) + + remote := &model.SharedChannelRemote{ + ChannelId: channel.Id, + Description: "test_remote", + CreatorId: model.NewId(), + RemoteId: model.NewId(), + } + + remoteSaved, err := ss.SharedChannel().SaveRemote(remote) + require.Nil(t, err, "couldn't save remote", err) + + t.Run("Delete existing shared channel remote", func(t *testing.T) { + deleted, err := ss.SharedChannel().DeleteRemote(remoteSaved.Id) + require.Nil(t, err, "delete existing remote should not error", err) + require.True(t, deleted, "expected true from delete remote") + + r, err := ss.SharedChannel().GetRemote(remoteSaved.Id) + require.NotNil(t, err) + require.Nil(t, r) + }) + + t.Run("Delete non-existent shared channel remote", func(t *testing.T) { + deleted, err := ss.SharedChannel().DeleteRemote(model.NewId()) + require.Nil(t, err, "delete non-existent remote should not error", err) + require.False(t, deleted, "expected false from delete remote") + }) +} + +func createTestChannel(ss store.Store, name string) (*model.Channel, error) { + channel, err := createSharedTestChannel(ss, name, false) + return channel, err +} + +func createSharedTestChannel(ss store.Store, name string, shared bool) (*model.Channel, error) { + channel := &model.Channel{ + TeamId: model.NewId(), + Type: model.CHANNEL_OPEN, + Name: name, + DisplayName: name + " display name", + Header: name + " header", + Purpose: name + "purpose", + CreatorId: model.NewId(), + Shared: model.NewBool(shared), + } + channel, err := ss.Channel().Save(channel, 10000) + if err != nil { + return nil, err + } + + if shared { + sc := &model.SharedChannel{ + ChannelId: channel.Id, + TeamId: channel.TeamId, + CreatorId: channel.CreatorId, + ShareName: channel.Name, + Home: true, + } + _, err = ss.SharedChannel().Save(sc) + if err != nil { + return nil, err + } + } + return channel, nil +} + +func clearSharedChannels(ss store.Store) error { + opts := model.SharedChannelFilterOpts{} + all, err := ss.SharedChannel().GetAll(0, 1000, opts) + if err != nil { + return err + } + + for _, sc := range all { + if _, err := ss.SharedChannel().Delete(sc.ChannelId); err != nil { + return err + } + } + return nil +} + +func testSaveSharedChannelUser(t *testing.T, ss store.Store) { + t.Run("Save shared channel user", func(t *testing.T) { + scUser := &model.SharedChannelUser{ + UserId: model.NewId(), + RemoteId: model.NewId(), + } + + userSaved, err := ss.SharedChannel().SaveUser(scUser) + require.Nil(t, err, "couldn't save shared channel user", err) + + require.Equal(t, scUser.UserId, userSaved.UserId) + require.Equal(t, scUser.RemoteId, userSaved.RemoteId) + }) + + t.Run("Save invalid shared channel user", func(t *testing.T) { + scUser := &model.SharedChannelUser{ + UserId: "", + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().SaveUser(scUser) + require.NotNil(t, err, "should error saving invalid user", err) + }) + + t.Run("Save shared channel user with invalid remote id", func(t *testing.T) { + scUser := &model.SharedChannelUser{ + UserId: model.NewId(), + RemoteId: "bogus", + } + + _, err := ss.SharedChannel().SaveUser(scUser) + require.Error(t, err, "expected error for invalid remote id") + }) +} + +func testGetSharedChannelUser(t *testing.T, ss store.Store) { + scUser := &model.SharedChannelUser{ + UserId: model.NewId(), + RemoteId: model.NewId(), + } + + userSaved, err := ss.SharedChannel().SaveUser(scUser) + require.Nil(t, err, "could not save user", err) + + t.Run("Get existing shared channel user", func(t *testing.T) { + r, err := ss.SharedChannel().GetUser(userSaved.UserId, userSaved.RemoteId) + require.Nil(t, err, "couldn't get shared channel user", err) + + require.Equal(t, userSaved.Id, r.Id) + require.Equal(t, userSaved.UserId, r.UserId) + require.Equal(t, userSaved.RemoteId, r.RemoteId) + require.Equal(t, userSaved.CreateAt, r.CreateAt) + }) + + t.Run("Get non-existent shared channel user", func(t *testing.T) { + u, err := ss.SharedChannel().GetUser(model.NewId(), model.NewId()) + require.NotNil(t, err) + require.Nil(t, u) + }) +} + +func testUpdateSharedChannelUserLastSyncAt(t *testing.T, ss store.Store) { + scUser := &model.SharedChannelUser{ + UserId: model.NewId(), + RemoteId: model.NewId(), + } + + userSaved, err := ss.SharedChannel().SaveUser(scUser) + require.NoError(t, err, "couldn't save user", err) + + future := model.GetMillis() + 3600000 // 1 hour in the future + + t.Run("Update LastSyncAt for user", func(t *testing.T) { + err := ss.SharedChannel().UpdateUserLastSyncAt(userSaved.Id, future) + require.Nil(t, err, "updateLastSyncAt should not error", err) + + u, err := ss.SharedChannel().GetUser(userSaved.UserId, userSaved.RemoteId) + require.NoError(t, err) + require.Equal(t, future, u.LastSyncAt) + }) + + t.Run("Update LastSyncAt for non-existent shared channel user", func(t *testing.T) { + err := ss.SharedChannel().UpdateUserLastSyncAt(model.NewId(), future) + require.Error(t, err, "update non-existent user should error", err) + }) +} + +func testSaveSharedChannelAttachment(t *testing.T, ss store.Store) { + t.Run("Save shared channel attachment", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: model.NewId(), + } + + saved, err := ss.SharedChannel().SaveAttachment(attachment) + require.Nil(t, err, "couldn't save shared channel attachment", err) + + require.Equal(t, attachment.FileId, saved.FileId) + require.Equal(t, attachment.RemoteId, saved.RemoteId) + }) + + t.Run("Save invalid shared channel attachment", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: "", + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().SaveAttachment(attachment) + require.NotNil(t, err, "should error saving invalid attachment", err) + }) + + t.Run("Save shared channel attachment with invalid remote id", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: "bogus", + } + + _, err := ss.SharedChannel().SaveAttachment(attachment) + require.Error(t, err, "expected error for invalid remote id") + }) +} + +func testUpsertSharedChannelAttachment(t *testing.T, ss store.Store) { + t.Run("Upsert new shared channel attachment", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: model.NewId(), + } + + _, err := ss.SharedChannel().UpsertAttachment(attachment) + require.NoError(t, err, "couldn't upsert shared channel attachment", err) + + saved, err := ss.SharedChannel().GetAttachment(attachment.FileId, attachment.RemoteId) + require.NoError(t, err, "couldn't get shared channel attachment", err) + + require.NotZero(t, saved.CreateAt) + require.Equal(t, saved.CreateAt, saved.LastSyncAt) + }) + + t.Run("Upsert existing shared channel attachment", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: model.NewId(), + } + + saved, err := ss.SharedChannel().SaveAttachment(attachment) + require.Nil(t, err, "couldn't save shared channel attachment", err) + + // make sure enough time passed that GetMillis returns a different value + time.Sleep(1 * time.Millisecond) + + _, err = ss.SharedChannel().UpsertAttachment(saved) + require.NoError(t, err, "couldn't upsert shared channel attachment", err) + + updated, err := ss.SharedChannel().GetAttachment(attachment.FileId, attachment.RemoteId) + require.NoError(t, err, "couldn't get shared channel attachment", err) + + require.NotZero(t, updated.CreateAt) + require.Greater(t, updated.LastSyncAt, updated.CreateAt) + }) + + t.Run("Upsert invalid shared channel attachment", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: "", + RemoteId: model.NewId(), + } + + id, err := ss.SharedChannel().UpsertAttachment(attachment) + require.NotNil(t, err, "should error upserting invalid attachment", err) + require.Empty(t, id) + }) + + t.Run("Upsert shared channel attachment with invalid remote id", func(t *testing.T) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: "bogus", + } + + id, err := ss.SharedChannel().UpsertAttachment(attachment) + require.Error(t, err, "expected error for invalid remote id") + require.Empty(t, id) + }) +} + +func testGetSharedChannelAttachment(t *testing.T, ss store.Store) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: model.NewId(), + } + + saved, err := ss.SharedChannel().SaveAttachment(attachment) + require.Nil(t, err, "could not save attachment", err) + + t.Run("Get existing shared channel attachment", func(t *testing.T) { + r, err := ss.SharedChannel().GetAttachment(saved.FileId, saved.RemoteId) + require.Nil(t, err, "couldn't get shared channel attachment", err) + + require.Equal(t, saved.Id, r.Id) + require.Equal(t, saved.FileId, r.FileId) + require.Equal(t, saved.RemoteId, r.RemoteId) + require.Equal(t, saved.CreateAt, r.CreateAt) + }) + + t.Run("Get non-existent shared channel attachment", func(t *testing.T) { + u, err := ss.SharedChannel().GetAttachment(model.NewId(), model.NewId()) + require.NotNil(t, err) + require.Nil(t, u) + }) +} + +func testUpdateSharedChannelAttachmentLastSyncAt(t *testing.T, ss store.Store) { + attachment := &model.SharedChannelAttachment{ + FileId: model.NewId(), + RemoteId: model.NewId(), + } + + saved, err := ss.SharedChannel().SaveAttachment(attachment) + require.NoError(t, err, "couldn't save attachment", err) + + future := model.GetMillis() + 3600000 // 1 hour in the future + + t.Run("Update LastSyncAt for attachment", func(t *testing.T) { + err := ss.SharedChannel().UpdateAttachmentLastSyncAt(saved.Id, future) + require.Nil(t, err, "updateLastSyncAt should not error", err) + + f, err := ss.SharedChannel().GetAttachment(saved.FileId, saved.RemoteId) + require.NoError(t, err) + require.Equal(t, future, f.LastSyncAt) + }) + + t.Run("Update LastSyncAt for non-existent shared channel attachment", func(t *testing.T) { + err := ss.SharedChannel().UpdateAttachmentLastSyncAt(model.NewId(), future) + require.Error(t, err, "update non-existent attachment should error", err) + }) +} diff --git a/store/storetest/store.go b/store/storetest/store.go index 4d6171e7551..58a4ef590b6 100644 --- a/store/storetest/store.go +++ b/store/storetest/store.go @@ -23,6 +23,7 @@ type Store struct { BotStore mocks.BotStore AuditStore mocks.AuditStore ClusterDiscoveryStore mocks.ClusterDiscoveryStore + RemoteClusterStore mocks.RemoteClusterStore ComplianceStore mocks.ComplianceStore SessionStore mocks.SessionStore OAuthStore mocks.OAuthStore @@ -49,6 +50,7 @@ type Store struct { GroupStore mocks.GroupStore UserTermsOfServiceStore mocks.UserTermsOfServiceStore LinkMetadataStore mocks.LinkMetadataStore + SharedChannelStore mocks.SharedChannelStore ProductNoticesStore mocks.ProductNoticesStore context context.Context } @@ -63,6 +65,7 @@ func (s *Store) Bot() store.BotStore { return &s.B func (s *Store) ProductNotices() store.ProductNoticesStore { return &s.ProductNoticesStore } func (s *Store) Audit() store.AuditStore { return &s.AuditStore } func (s *Store) ClusterDiscovery() store.ClusterDiscoveryStore { return &s.ClusterDiscoveryStore } +func (s *Store) RemoteCluster() store.RemoteClusterStore { return &s.RemoteClusterStore } func (s *Store) Compliance() store.ComplianceStore { return &s.ComplianceStore } func (s *Store) Session() store.SessionStore { return &s.SessionStore } func (s *Store) OAuth() store.OAuthStore { return &s.OAuthStore } @@ -89,19 +92,20 @@ func (s *Store) UserTermsOfService() store.UserTermsOfServiceStore { return &s.U func (s *Store) ChannelMemberHistory() store.ChannelMemberHistoryStore { return &s.ChannelMemberHistoryStore } -func (s *Store) Group() store.GroupStore { return &s.GroupStore } -func (s *Store) LinkMetadata() store.LinkMetadataStore { return &s.LinkMetadataStore } -func (s *Store) MarkSystemRanUnitTests() { /* do nothing */ } -func (s *Store) Close() { /* do nothing */ } -func (s *Store) LockToMaster() { /* do nothing */ } -func (s *Store) UnlockFromMaster() { /* do nothing */ } -func (s *Store) DropAllTables() { /* do nothing */ } -func (s *Store) GetDbVersion(bool) (string, error) { return "", nil } -func (s *Store) RecycleDBConnections(time.Duration) {} -func (s *Store) TotalMasterDbConnections() int { return 1 } -func (s *Store) TotalReadDbConnections() int { return 1 } -func (s *Store) TotalSearchDbConnections() int { return 1 } -func (s *Store) GetCurrentSchemaVersion() string { return "" } +func (s *Store) Group() store.GroupStore { return &s.GroupStore } +func (s *Store) LinkMetadata() store.LinkMetadataStore { return &s.LinkMetadataStore } +func (s *Store) SharedChannel() store.SharedChannelStore { return &s.SharedChannelStore } +func (s *Store) MarkSystemRanUnitTests() { /* do nothing */ } +func (s *Store) Close() { /* do nothing */ } +func (s *Store) LockToMaster() { /* do nothing */ } +func (s *Store) UnlockFromMaster() { /* do nothing */ } +func (s *Store) DropAllTables() { /* do nothing */ } +func (s *Store) GetDbVersion(bool) (string, error) { return "", nil } +func (s *Store) RecycleDBConnections(time.Duration) {} +func (s *Store) TotalMasterDbConnections() int { return 1 } +func (s *Store) TotalReadDbConnections() int { return 1 } +func (s *Store) TotalSearchDbConnections() int { return 1 } +func (s *Store) GetCurrentSchemaVersion() string { return "" } func (s *Store) CheckIntegrity() <-chan model.IntegrityCheckResult { return make(chan model.IntegrityCheckResult) } @@ -117,6 +121,7 @@ func (s *Store) AssertExpectations(t mock.TestingT) bool { &s.BotStore, &s.AuditStore, &s.ClusterDiscoveryStore, + &s.RemoteClusterStore, &s.ComplianceStore, &s.SessionStore, &s.OAuthStore, @@ -140,5 +145,6 @@ func (s *Store) AssertExpectations(t mock.TestingT) bool { &s.SchemeStore, &s.ThreadStore, &s.ProductNoticesStore, + &s.SharedChannelStore, ) } diff --git a/store/storetest/thread_store.go b/store/storetest/thread_store.go index 44a46c527fc..7a45c11df1e 100644 --- a/store/storetest/thread_store.go +++ b/store/storetest/thread_store.go @@ -145,7 +145,7 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) { thread1, err := ss.Thread().Get(newPosts[0].RootId) require.NoError(t, err) - rrootPost, err := ss.Post().GetSingle(rootPost.Id) + rrootPost, err := ss.Post().GetSingle(rootPost.Id, false) require.NoError(t, err) require.Equal(t, rrootPost.UpdateAt, rootPost.UpdateAt) @@ -164,7 +164,7 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) { _, _, err = ss.Post().SaveMultiple([]*model.Post{&replyPost2, &replyPost3}) require.NoError(t, err) - rrootPost2, err := ss.Post().GetSingle(rootPost.Id) + rrootPost2, err := ss.Post().GetSingle(rootPost.Id, false) require.NoError(t, err) require.Greater(t, rrootPost2.UpdateAt, rrootPost.UpdateAt) diff --git a/store/timerlayer/timerlayer.go b/store/timerlayer/timerlayer.go index 716434eba32..5329b326c4d 100644 --- a/store/timerlayer/timerlayer.go +++ b/store/timerlayer/timerlayer.go @@ -38,9 +38,11 @@ type TimerLayer struct { PreferenceStore store.PreferenceStore ProductNoticesStore store.ProductNoticesStore ReactionStore store.ReactionStore + RemoteClusterStore store.RemoteClusterStore RoleStore store.RoleStore SchemeStore store.SchemeStore SessionStore store.SessionStore + SharedChannelStore store.SharedChannelStore StatusStore store.StatusStore SystemStore store.SystemStore TeamStore store.TeamStore @@ -134,6 +136,10 @@ func (s *TimerLayer) Reaction() store.ReactionStore { return s.ReactionStore } +func (s *TimerLayer) RemoteCluster() store.RemoteClusterStore { + return s.RemoteClusterStore +} + func (s *TimerLayer) Role() store.RoleStore { return s.RoleStore } @@ -146,6 +152,10 @@ func (s *TimerLayer) Session() store.SessionStore { return s.SessionStore } +func (s *TimerLayer) SharedChannel() store.SharedChannelStore { + return s.SharedChannelStore +} + func (s *TimerLayer) Status() store.StatusStore { return s.StatusStore } @@ -290,6 +300,11 @@ type TimerLayerReactionStore struct { Root *TimerLayer } +type TimerLayerRemoteClusterStore struct { + store.RemoteClusterStore + Root *TimerLayer +} + type TimerLayerRoleStore struct { store.RoleStore Root *TimerLayer @@ -305,6 +320,11 @@ type TimerLayerSessionStore struct { Root *TimerLayer } +type TimerLayerSharedChannelStore struct { + store.SharedChannelStore + Root *TimerLayer +} + type TimerLayerStatusStore struct { store.StatusStore Root *TimerLayer @@ -615,10 +635,10 @@ func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int return result, resultVar1, err } -func (s *TimerLayerChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User) (*model.Channel, error) { +func (s *TimerLayerChannelStore) CreateDirectChannel(userId *model.User, otherUserId *model.User, channelOptions ...model.ChannelOption) (*model.Channel, error) { start := timemodule.Now() - result, err := s.ChannelStore.CreateDirectChannel(userId, otherUserId) + result, err := s.ChannelStore.CreateDirectChannel(userId, otherUserId, channelOptions...) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -1415,6 +1435,22 @@ func (s *TimerLayerChannelStore) GetTeamChannels(teamID string) (*model.ChannelL return result, err } +func (s *TimerLayerChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) { + start := timemodule.Now() + + result, err := s.ChannelStore.GetTeamForChannel(channelID) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelStore.GetTeamForChannel", success, elapsed) + } + return result, err +} + func (s *TimerLayerChannelStore) GroupSyncedChannelCount() (int64, error) { start := timemodule.Now() @@ -1920,6 +1956,22 @@ func (s *TimerLayerChannelStore) SetDeleteAt(channelID string, deleteAt int64, u return err } +func (s *TimerLayerChannelStore) SetShared(channelId string, shared bool) error { + start := timemodule.Now() + + err := s.ChannelStore.SetShared(channelId, shared) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelStore.SetShared", success, elapsed) + } + return err +} + func (s *TimerLayerChannelStore) Update(channel *model.Channel) (*model.Channel, error) { start := timemodule.Now() @@ -4796,6 +4848,22 @@ func (s *TimerLayerPostStore) GetPostsSince(options model.GetPostsSinceOptions, return result, err } +func (s *TimerLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { + start := timemodule.Now() + + result, err := s.PostStore.GetPostsSinceForSync(options, allowFromCache) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("PostStore.GetPostsSinceForSync", success, elapsed) + } + return result, err +} + func (s *TimerLayerPostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExport, error) { start := timemodule.Now() @@ -4812,10 +4880,10 @@ func (s *TimerLayerPostStore) GetRepliesForExport(parentID string) ([]*model.Rep return result, err } -func (s *TimerLayerPostStore) GetSingle(id string) (*model.Post, error) { +func (s *TimerLayerPostStore) GetSingle(id string, inclDeleted bool) (*model.Post, error) { start := timemodule.Now() - result, err := s.PostStore.GetSingle(id) + result, err := s.PostStore.GetSingle(id, inclDeleted) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -5275,6 +5343,22 @@ func (s *TimerLayerReactionStore) GetForPost(postID string, allowFromCache bool) return result, err } +func (s *TimerLayerReactionStore) GetForPostSince(postId string, since int64, excludeRemoteId string, inclDeleted bool) ([]*model.Reaction, error) { + start := timemodule.Now() + + result, err := s.ReactionStore.GetForPostSince(postId, since, excludeRemoteId, inclDeleted) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ReactionStore.GetForPostSince", success, elapsed) + } + return result, err +} + func (s *TimerLayerReactionStore) PermanentDeleteBatch(endTime int64, limit int64) (int64, error) { start := timemodule.Now() @@ -5307,6 +5391,118 @@ func (s *TimerLayerReactionStore) Save(reaction *model.Reaction) (*model.Reactio return result, err } +func (s *TimerLayerRemoteClusterStore) Delete(remoteClusterId string) (bool, error) { + start := timemodule.Now() + + result, err := s.RemoteClusterStore.Delete(remoteClusterId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.Delete", success, elapsed) + } + return result, err +} + +func (s *TimerLayerRemoteClusterStore) Get(remoteClusterId string) (*model.RemoteCluster, error) { + start := timemodule.Now() + + result, err := s.RemoteClusterStore.Get(remoteClusterId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.Get", success, elapsed) + } + return result, err +} + +func (s *TimerLayerRemoteClusterStore) GetAll(filter model.RemoteClusterQueryFilter) ([]*model.RemoteCluster, error) { + start := timemodule.Now() + + result, err := s.RemoteClusterStore.GetAll(filter) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.GetAll", success, elapsed) + } + return result, err +} + +func (s *TimerLayerRemoteClusterStore) Save(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + start := timemodule.Now() + + result, err := s.RemoteClusterStore.Save(rc) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.Save", success, elapsed) + } + return result, err +} + +func (s *TimerLayerRemoteClusterStore) SetLastPingAt(remoteClusterId string) error { + start := timemodule.Now() + + err := s.RemoteClusterStore.SetLastPingAt(remoteClusterId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.SetLastPingAt", success, elapsed) + } + return err +} + +func (s *TimerLayerRemoteClusterStore) Update(rc *model.RemoteCluster) (*model.RemoteCluster, error) { + start := timemodule.Now() + + result, err := s.RemoteClusterStore.Update(rc) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.Update", success, elapsed) + } + return result, err +} + +func (s *TimerLayerRemoteClusterStore) UpdateTopics(remoteClusterId string, topics string) (*model.RemoteCluster, error) { + start := timemodule.Now() + + result, err := s.RemoteClusterStore.UpdateTopics(remoteClusterId, topics) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("RemoteClusterStore.UpdateTopics", success, elapsed) + } + return result, err +} + func (s *TimerLayerRoleStore) AllChannelSchemeRoles() ([]*model.Role, error) { start := timemodule.Now() @@ -5850,6 +6046,390 @@ func (s *TimerLayerSessionStore) UpdateRoles(userId string, roles string) (strin return result, err } +func (s *TimerLayerSharedChannelStore) Delete(channelId string) (bool, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.Delete(channelId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.Delete", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) DeleteRemote(remoteId string) (bool, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.DeleteRemote(remoteId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.DeleteRemote", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) Get(channelId string) (*model.SharedChannel, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.Get(channelId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.Get", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetAll(offset int, limit int, opts model.SharedChannelFilterOpts) ([]*model.SharedChannel, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetAll(offset, limit, opts) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetAll", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetAllCount(opts model.SharedChannelFilterOpts) (int64, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetAllCount(opts) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetAllCount", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetAttachment(fileId string, remoteId string) (*model.SharedChannelAttachment, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetAttachment(fileId, remoteId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetAttachment", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetRemote(id string) (*model.SharedChannelRemote, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetRemote(id) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetRemote", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetRemoteByIds(channelId, remoteId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetRemoteByIds", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetRemoteForUser(remoteId, userId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetRemoteForUser", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetRemotes(opts) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetRemotes", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetRemotesStatus(channelId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetRemotesStatus", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetUser(userId string, remoteId string) (*model.SharedChannelUser, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetUser(userId, remoteId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetUser", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) HasChannel(channelID string) (bool, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.HasChannel(channelID) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.HasChannel", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) HasRemote(channelID string, remoteId string) (bool, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.HasRemote(channelID, remoteId) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.HasRemote", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) Save(sc *model.SharedChannel) (*model.SharedChannel, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.Save(sc) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.Save", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.SaveAttachment(remote) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.SaveAttachment", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) SaveRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.SaveRemote(remote) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.SaveRemote", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) SaveUser(remote *model.SharedChannelUser) (*model.SharedChannelUser, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.SaveUser(remote) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.SaveUser", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) Update(sc *model.SharedChannel) (*model.SharedChannel, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.Update(sc) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.Update", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) UpdateAttachmentLastSyncAt(id string, syncTime int64) error { + start := timemodule.Now() + + err := s.SharedChannelStore.UpdateAttachmentLastSyncAt(id, syncTime) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateAttachmentLastSyncAt", success, elapsed) + } + return err +} + +func (s *TimerLayerSharedChannelStore) UpdateRemote(remote *model.SharedChannelRemote) (*model.SharedChannelRemote, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.UpdateRemote(remote) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateRemote", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { + start := timemodule.Now() + + err := s.SharedChannelStore.UpdateRemoteNextSyncAt(id, syncTime) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateRemoteNextSyncAt", success, elapsed) + } + return err +} + +func (s *TimerLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { + start := timemodule.Now() + + err := s.SharedChannelStore.UpdateUserLastSyncAt(id, syncTime) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateUserLastSyncAt", success, elapsed) + } + return err +} + +func (s *TimerLayerSharedChannelStore) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.UpsertAttachment(remote) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpsertAttachment", success, elapsed) + } + return result, err +} + func (s *TimerLayerStatusStore) Get(userId string) (*model.Status, error) { start := timemodule.Now() @@ -9219,9 +9799,11 @@ func New(childStore store.Store, metrics einterfaces.MetricsInterface) *TimerLay newStore.PreferenceStore = &TimerLayerPreferenceStore{PreferenceStore: childStore.Preference(), Root: &newStore} newStore.ProductNoticesStore = &TimerLayerProductNoticesStore{ProductNoticesStore: childStore.ProductNotices(), Root: &newStore} newStore.ReactionStore = &TimerLayerReactionStore{ReactionStore: childStore.Reaction(), Root: &newStore} + newStore.RemoteClusterStore = &TimerLayerRemoteClusterStore{RemoteClusterStore: childStore.RemoteCluster(), Root: &newStore} newStore.RoleStore = &TimerLayerRoleStore{RoleStore: childStore.Role(), Root: &newStore} newStore.SchemeStore = &TimerLayerSchemeStore{SchemeStore: childStore.Scheme(), Root: &newStore} newStore.SessionStore = &TimerLayerSessionStore{SessionStore: childStore.Session(), Root: &newStore} + newStore.SharedChannelStore = &TimerLayerSharedChannelStore{SharedChannelStore: childStore.SharedChannel(), Root: &newStore} newStore.StatusStore = &TimerLayerStatusStore{StatusStore: childStore.Status(), Root: &newStore} newStore.SystemStore = &TimerLayerSystemStore{SystemStore: childStore.System(), Root: &newStore} newStore.TeamStore = &TimerLayerTeamStore{TeamStore: childStore.Team(), Root: &newStore} diff --git a/testlib/cluster.go b/testlib/cluster.go index 25094c4384c..cd17aa7bc64 100644 --- a/testlib/cluster.go +++ b/testlib/cluster.go @@ -74,6 +74,19 @@ func (c *FakeClusterInterface) GetMessages() []*model.ClusterMessage { return c.messages } +func (c *FakeClusterInterface) SelectMessages(filterCond func(message *model.ClusterMessage) bool) []*model.ClusterMessage { + c.mut.RLock() + defer c.mut.RUnlock() + + filteredMessages := []*model.ClusterMessage{} + for _, msg := range c.messages { + if filterCond(msg) { + filteredMessages = append(filteredMessages, msg) + } + } + return filteredMessages +} + func (c *FakeClusterInterface) ClearMessages() { c.mut.Lock() defer c.mut.Unlock() diff --git a/web/context.go b/web/context.go index a7f9692d103..b10da5c7b79 100644 --- a/web/context.go +++ b/web/context.go @@ -131,6 +131,13 @@ func (c *Context) CloudKeyRequired() { } } +func (c *Context) RemoteClusterTokenRequired() { + if license := c.App.Srv().License(); license == nil || !*license.Features.RemoteClusterService || c.App.Session().Props[model.SESSION_PROP_TYPE] != model.SESSION_TYPE_REMOTECLUSTER_TOKEN { + c.Err = model.NewAppError("", "api.context.session_expired.app_error", nil, "TokenRequired", http.StatusUnauthorized) + return + } +} + func (c *Context) MfaRequired() { // Must be licensed for MFA and have it configured for enforcement if license := c.App.Srv().License(); license == nil || !*license.Features.MFA || !*c.App.Config().ServiceSettings.EnableMultifactorAuthentication || !*c.App.Config().ServiceSettings.EnforceMultifactorAuthentication { @@ -209,6 +216,18 @@ func (c *Context) SetServerBusyError() { c.Err = NewServerBusyError() } +func (c *Context) SetInvalidRemoteIdError(id string) { + c.Err = NewInvalidRemoteIdError(id) +} + +func (c *Context) SetInvalidRemoteClusterTokenError() { + c.Err = NewInvalidRemoteClusterTokenError() +} + +func (c *Context) SetJSONEncodingError() { + c.Err = NewJSONEncodingError() +} + func (c *Context) SetCommandNotFoundError() { c.Err = model.NewAppError("GetCommand", "store.sql_command.save.get.app_error", nil, "", http.StatusNotFound) } @@ -246,6 +265,21 @@ func NewServerBusyError() *model.AppError { return err } +func NewInvalidRemoteIdError(parameter string) *model.AppError { + err := model.NewAppError("Context", "api.context.remote_id_invalid.app_error", map[string]interface{}{"RemoteId": parameter}, "", http.StatusBadRequest) + return err +} + +func NewInvalidRemoteClusterTokenError() *model.AppError { + err := model.NewAppError("Context", "api.context.remote_id_invalid.app_error", nil, "", http.StatusUnauthorized) + return err +} + +func NewJSONEncodingError() *model.AppError { + err := model.NewAppError("Context", "api.context.json_encoding.app_error", nil, "", http.StatusInternalServerError) + return err +} + func (c *Context) SetPermissionError(permissions ...*model.Permission) { c.Err = c.App.MakePermissionError(permissions) } @@ -685,3 +719,7 @@ func (c *Context) RequireInvoiceId() *Context { return c } + +func (c *Context) GetRemoteID(r *http.Request) string { + return r.Header.Get(model.HEADER_REMOTECLUSTER_ID) +} diff --git a/web/handlers.go b/web/handlers.go index 15675878f98..8324bbfd487 100644 --- a/web/handlers.go +++ b/web/handlers.go @@ -70,16 +70,17 @@ func (w *Web) NewStaticHandler(h func(*Context, http.ResponseWriter, *http.Reque } type Handler struct { - GetGlobalAppOptions app.AppOptionCreator - HandleFunc func(*Context, http.ResponseWriter, *http.Request) - HandlerName string - RequireSession bool - RequireCloudKey bool - TrustRequester bool - RequireMfa bool - IsStatic bool - IsLocal bool - DisableWhenBusy bool + GetGlobalAppOptions app.AppOptionCreator + HandleFunc func(*Context, http.ResponseWriter, *http.Request) + HandlerName string + RequireSession bool + RequireCloudKey bool + RequireRemoteClusterToken bool + TrustRequester bool + RequireMfa bool + IsStatic bool + IsLocal bool + DisableWhenBusy bool cspShaDirective string } @@ -204,7 +205,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, tokenLocation := app.ParseAuthTokenFromRequest(r) - if token != "" && tokenLocation != app.TokenLocationCloudHeader { + if token != "" && tokenLocation != app.TokenLocationCloudHeader && tokenLocation != app.TokenLocationRemoteClusterHeader { session, err := c.App.GetSession(token) defer app.ReturnSessionToPool(session) @@ -237,6 +238,21 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { c.App.SetSession(session) } + } else if token != "" && c.App.Srv().License() != nil && *c.App.Srv().License().Features.RemoteClusterService && tokenLocation == app.TokenLocationRemoteClusterHeader { + // Get the remote cluster + if remoteId := c.GetRemoteID(r); remoteId == "" { + c.Logger.Warn("Missing remote cluster id") // + c.Err = model.NewAppError("ServeHTTP", "api.context.remote_id_missing.app_error", nil, "", http.StatusUnauthorized) + } else { + // Check the token is correct for the remote cluster id. + session, err := c.App.GetRemoteClusterSession(token, remoteId) + if err != nil { + c.Logger.Warn("Invalid remote cluster token", mlog.Err(err)) + c.Err = err + } else { + c.App.SetSession(session) + } + } } c.Logger = c.App.Log().With( @@ -263,6 +279,10 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.CloudKeyRequired() } + if c.Err == nil && h.RequireRemoteClusterToken { + c.RemoteClusterTokenRequired() + } + if c.Err == nil && h.IsLocal { // if the connection is local, RemoteAddr shouldn't have the // shape IP:PORT (it will be "@" in Linux, for example)