From 58dce5930ef70562e020385471b4c5903dd32006 Mon Sep 17 00:00:00 2001 From: Martin Raymond Kraft Date: Fri, 12 Mar 2021 12:37:30 -0500 Subject: [PATCH] [MM-31094] Replication Lag (#16888) * MM-31094: Adds tooling to develop and test using a MySQL instance with replication lag. Adds some lazy lookups to fallback to master if results are not found. * MM-31094: Removes mysql-read-replica from default docker services. * MM-31094: Switches (store..SessionStore).Get and (store.TeamStore).GetMember to using context.Context. * MM-31094: Updates (store.UsersStore).Get to use context. * MM-31094: Updates (store.PostStore).Get to use context. * MM-31094: Removes feature flag and config setting. * MM-31094: Rolls back some master reads. * MM-31094: Rolls a non-cache read. * MM-31094: Removes feature flag from the store. * MM-31094: Removes unused constant and struct field. * MM-31094: Removes some old feature flag references. * MM-31094: Fixes some tests. * MM-31094: App layers fix. * MM-31094: Fixes mocks. * MM-31094: Don't reparse flag. * MM-31094: No reparse. * MM-31094: Removed unused FeatureFlags field. * MM-31094: Removes unnecessary feature flags variable declarations. * MM-31094: Fixes copy-paste error. * MM-31094: Fixes logical error. * MM-30194: Removes test method from store. * Revert "MM-30194: Removes test method from store." This reverts commit d5a6e8529bd5f4d993824c828e239d009b05e567. * MM-31094: Conforming to make's strange syntax. * MM-31094: Configures helper for read replica with option. * MM-31094: Adds some missing ctx's. * MM-31094: WIP * MM-31094: Updates test names. * MM-31094: WIP * MM-31094: Removes unnecessary master reads. * MM-31094: ID case changes out of scope. * MM-31094: Removes unused context. * MM-31094: Switches to a helper. Removes some var naming changes. Fixes a merge error. * MM-31094: Removes SQLITE db driver ref. * MM-31094: Layer generate fix. * MM-31094: Removes unnecessary changes. * MM-31094: Moves test method. * MM-31094: Re-add previous fix. * MM-31094: Removes make command for dev. * MM-31094: Fix for login. Co-authored-by: Mattermod --- Makefile | 4 +- api4/apitestlib.go | 4 +- api4/main_test.go | 9 ++ api4/user_test.go | 39 +++++++ app/channel.go | 3 +- app/file.go | 3 +- app/main_test.go | 8 ++ app/post.go | 9 +- app/post_test.go | 40 +++++++ app/server_test.go | 2 +- app/session.go | 13 ++- app/session_test.go | 3 +- app/team.go | 7 +- app/web_hub_test.go | 2 +- build/docker-compose-generator/main.go | 21 ++-- build/docker-compose.common.yml | 22 ++++ build/docker/mysql.conf.d/replica.cnf | 3 + build/docker/mysql.conf.d/source.cnf | 4 + docker-compose.makefile.yml | 8 ++ docker-compose.yaml | 7 ++ scripts/replica-lag-set.sh | 4 + scripts/replica-mysql-config.sh | 32 ++++++ .../searchengine/bleveengine/bleve_test.go | 2 +- store/localcachelayer/layer_test.go | 2 +- store/opentracinglayer/opentracinglayer.go | 12 +- store/retrylayer/retrylayer.go | 12 +- store/searchlayer/layer_test.go | 2 +- store/searchlayer/post_layer.go | 4 +- store/sqlstore/group_store.go | 6 +- store/sqlstore/post_store.go | 4 +- store/sqlstore/session_store.go | 6 +- store/sqlstore/store.go | 24 ++-- store/sqlstore/store_test.go | 6 +- store/sqlstore/team_store.go | 4 +- store/store.go | 6 +- store/storetest/mocks/PostStore.go | 16 +-- store/storetest/mocks/SessionStore.go | 16 +-- store/storetest/mocks/TeamStore.go | 14 +-- store/storetest/oauth_store.go | 3 +- store/storetest/post_store.go | 105 +++++++++--------- store/storetest/reaction_store.go | 15 +-- store/storetest/session_store.go | 37 +++--- store/storetest/settings.go | 35 ++++-- store/storetest/team_store.go | 24 ++-- store/storetest/thread_store.go | 3 +- store/storetest/user_access_token_store.go | 7 +- store/storetest/user_store.go | 28 ++--- store/timerlayer/timerlayer.go | 12 +- testlib/helper.go | 62 ++++++++++- tests/test-data.ldif | 24 +++- 50 files changed, 519 insertions(+), 219 deletions(-) create mode 100644 build/docker/mysql.conf.d/replica.cnf create mode 100644 build/docker/mysql.conf.d/source.cnf create mode 100755 scripts/replica-lag-set.sh create mode 100755 scripts/replica-mysql-config.sh diff --git a/Makefile b/Makefile index 47a9106abde..5c511ad2aba 100644 --- a/Makefile +++ b/Makefile @@ -165,6 +165,9 @@ else ifneq (,$(findstring openldap,$(ENABLED_DOCKER_SERVICES))) cat tests/${LDAP_DATA}-data.ldif | docker-compose -f docker-compose.makefile.yml exec -T openldap bash -c 'ldapadd -x -D "cn=admin,dc=mm,dc=test,dc=com" -w mostest || true'; endif +ifneq (,$(findstring mysql-read-replica,$(ENABLED_DOCKER_SERVICES))) + ./scripts/replica-mysql-config.sh +endif endif run-haserver: run-client @@ -193,7 +196,6 @@ else docker-compose rm -v endif - plugin-checker: $(GO) run $(GOFLAGS) ./plugin/checker diff --git a/api4/apitestlib.go b/api4/apitestlib.go index 8fdfb56efdf..1fdcb80363b 100644 --- a/api4/apitestlib.go +++ b/api4/apitestlib.go @@ -1006,7 +1006,7 @@ func (th *TestHelper) MakeUserChannelAdmin(user *model.User, channel *model.Chan func (th *TestHelper) UpdateUserToTeamAdmin(user *model.User, team *model.Team) { utils.DisableDebugLogForTest() - if tm, err := th.App.Srv().Store.Team().GetMember(team.Id, user.Id); err == nil { + if tm, err := th.App.Srv().Store.Team().GetMember(context.Background(), team.Id, user.Id); err == nil { tm.SchemeAdmin = true if _, err = th.App.Srv().Store.Team().UpdateMember(tm); err != nil { utils.EnableDebugLogForTest() @@ -1023,7 +1023,7 @@ func (th *TestHelper) UpdateUserToTeamAdmin(user *model.User, team *model.Team) func (th *TestHelper) UpdateUserToNonTeamAdmin(user *model.User, team *model.Team) { utils.DisableDebugLogForTest() - if tm, err := th.App.Srv().Store.Team().GetMember(team.Id, user.Id); err == nil { + if tm, err := th.App.Srv().Store.Team().GetMember(context.Background(), team.Id, user.Id); err == nil { tm.SchemeAdmin = false if _, err = th.App.Srv().Store.Team().UpdateMember(tm); err != nil { utils.EnableDebugLogForTest() diff --git a/api4/main_test.go b/api4/main_test.go index 6ba97433ad8..8970384cc24 100644 --- a/api4/main_test.go +++ b/api4/main_test.go @@ -4,16 +4,25 @@ package api4 import ( + "flag" "testing" "github.com/mattermost/mattermost-server/v5/shared/mlog" "github.com/mattermost/mattermost-server/v5/testlib" ) +var replicaFlag bool + func TestMain(m *testing.M) { + if f := flag.Lookup("mysql-replica"); f == nil { + flag.BoolVar(&replicaFlag, "mysql-replica", false, "") + flag.Parse() + } + var options = testlib.HelperOptions{ EnableStore: true, EnableResources: true, + WithReadReplica: replicaFlag, } mlog.DisableZap() diff --git a/api4/user_test.go b/api4/user_test.go index 497f0cc5009..9f91550474a 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -3348,6 +3348,45 @@ func TestLogin(t *testing.T) { }) } +func TestLoginWithLag(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + th.Client.Logout() + + t.Run("with replication lag, caches cleared", func(t *testing.T) { + if !replicaFlag { + t.Skipf("requires test flag: -mysql-replica") + } + + if *th.App.Srv().Config().SqlSettings.DriverName != model.DATABASE_DRIVER_MYSQL { + t.Skipf("requires %q database driver", model.DATABASE_DRIVER_MYSQL) + } + + mainHelper.SQLStore.UpdateLicense(model.NewTestLicense("ldap")) + mainHelper.ToggleReplicasOff() + + err := th.App.RevokeAllSessions(th.BasicUser.Id) + require.Nil(t, err) + + mainHelper.ToggleReplicasOn() + defer mainHelper.ToggleReplicasOff() + + cmdErr := mainHelper.SetReplicationLagForTesting(5) + require.Nil(t, cmdErr) + defer mainHelper.SetReplicationLagForTesting(0) + + _, resp := th.Client.Login(th.BasicUser.Email, th.BasicUser.Password) + CheckNoError(t, resp) + + err = th.App.Srv().InvalidateAllCaches() + require.Nil(t, err) + + session, err := th.App.GetSession(th.Client.AuthToken) + require.Nil(t, err) + require.NotNil(t, session) + }) +} + func TestLoginCookies(t *testing.T) { t.Run("should return cookies with X-Requested-With header", func(t *testing.T) { th := Setup(t).InitBasic() diff --git a/app/channel.go b/app/channel.go index 54655f981d5..93b55e26416 100644 --- a/app/channel.go +++ b/app/channel.go @@ -16,6 +16,7 @@ import ( "github.com/mattermost/mattermost-server/v5/shared/i18n" "github.com/mattermost/mattermost-server/v5/shared/mlog" "github.com/mattermost/mattermost-server/v5/store" + "github.com/mattermost/mattermost-server/v5/store/sqlstore" "github.com/mattermost/mattermost-server/v5/utils" ) @@ -1379,7 +1380,7 @@ func (a *App) addUserToChannel(user *model.User, channel *model.Channel) (*model } func (a *App) AddUserToChannel(user *model.User, channel *model.Channel) (*model.ChannelMember, *model.AppError) { - teamMember, nErr := a.Srv().Store.Team().GetMember(channel.TeamId, user.Id) + teamMember, nErr := a.Srv().Store.Team().GetMember(sqlstore.WithMaster(context.Background()), channel.TeamId, user.Id) if nErr != nil { var nfErr *store.ErrNotFound switch { diff --git a/app/file.go b/app/file.go index d3deaa1ef03..7075ba33002 100644 --- a/app/file.go +++ b/app/file.go @@ -6,6 +6,7 @@ package app import ( "archive/zip" "bytes" + "context" "crypto/sha256" "encoding/base64" "errors" @@ -406,7 +407,7 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo { fileMigrationLock.Lock() defer fileMigrationLock.Unlock() - result, nErr := a.Srv().Store.Post().Get(post.Id, false, false, false) + result, nErr := a.Srv().Store.Post().Get(context.Background(), post.Id, false, false, false) if nErr != nil { mlog.Error("Unable to get post when migrating post to use FileInfos", mlog.Err(nErr), mlog.String("post_id", post.Id)) return []*model.FileInfo{} diff --git a/app/main_test.go b/app/main_test.go index 35fdd6b1bb3..ae664aa78d0 100644 --- a/app/main_test.go +++ b/app/main_test.go @@ -4,6 +4,7 @@ package app import ( + "flag" "testing" "github.com/mattermost/mattermost-server/v5/shared/mlog" @@ -11,11 +12,18 @@ import ( ) var mainHelper *testlib.MainHelper +var replicaFlag bool func TestMain(m *testing.M) { + if f := flag.Lookup("mysql-replica"); f == nil { + flag.BoolVar(&replicaFlag, "mysql-replica", false, "") + flag.Parse() + } + var options = testlib.HelperOptions{ EnableStore: true, EnableResources: true, + WithReadReplica: replicaFlag, } mlog.DisableZap() diff --git a/app/post.go b/app/post.go index 20b87512c94..d9c256146cf 100644 --- a/app/post.go +++ b/app/post.go @@ -19,6 +19,7 @@ import ( "github.com/mattermost/mattermost-server/v5/shared/i18n" "github.com/mattermost/mattermost-server/v5/shared/mlog" "github.com/mattermost/mattermost-server/v5/store" + "github.com/mattermost/mattermost-server/v5/store/sqlstore" ) const ( @@ -186,7 +187,7 @@ func (a *App) CreatePost(post *model.Post, channel *model.Channel, triggerWebhoo if post.RootId != "" { pchan = make(chan store.StoreResult, 1) go func() { - r, pErr := a.Srv().Store.Post().Get(post.RootId, false, false, false) + r, pErr := a.Srv().Store.Post().Get(sqlstore.WithMaster(context.Background()), post.RootId, false, false, false) pchan <- store.StoreResult{Data: r, NErr: pErr} close(pchan) }() @@ -537,7 +538,7 @@ func (a *App) DeleteEphemeralPost(userID, postID string) { func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model.AppError) { post.SanitizeProps() - postLists, nErr := a.Srv().Store.Post().Get(post.Id, false, false, false) + postLists, nErr := a.Srv().Store.Post().Get(context.Background(), post.Id, false, false, false) if nErr != nil { var nfErr *store.ErrNotFound var invErr *store.ErrInvalidInput @@ -742,7 +743,7 @@ func (a *App) GetSinglePost(postID string) (*model.Post, *model.AppError) { } func (a *App) GetPostThread(postID string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) { - posts, err := a.Srv().Store.Post().Get(postID, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) + posts, err := a.Srv().Store.Post().Get(context.Background(), postID, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { var nfErr *store.ErrNotFound var invErr *store.ErrInvalidInput @@ -787,7 +788,7 @@ func (a *App) GetFlaggedPostsForChannel(userID, channelID string, offset int, li } func (a *App) GetPermalinkPost(postID string, userID string) (*model.PostList, *model.AppError) { - list, nErr := a.Srv().Store.Post().Get(postID, false, false, false) + list, nErr := a.Srv().Store.Post().Get(context.Background(), postID, false, false, false) if nErr != nil { var nfErr *store.ErrNotFound var invErr *store.ErrInvalidInput diff --git a/app/post_test.go b/app/post_test.go index 552906b9f1d..a816e1bdf92 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -1970,3 +1970,43 @@ func TestCollapsedThreadFetch(t *testing.T) { require.NotEmpty(t, l.Posts[postRoot.Id].Participants[0].Email) }) } + +func TestReplyToPostWithLag(t *testing.T) { + if !replicaFlag { + t.Skipf("requires test flag -mysql-replica") + } + + th := Setup(t).InitBasic() + defer th.TearDown() + + if *th.App.Srv().Config().SqlSettings.DriverName != model.DATABASE_DRIVER_MYSQL { + t.Skipf("requires %q database driver", model.DATABASE_DRIVER_MYSQL) + } + + mainHelper.SQLStore.UpdateLicense(model.NewTestLicense("somelicense")) + + t.Run("replication lag time great than reply time", func(t *testing.T) { + err := mainHelper.SetReplicationLagForTesting(5) + require.Nil(t, err) + defer mainHelper.SetReplicationLagForTesting(0) + mainHelper.ToggleReplicasOn() + defer mainHelper.ToggleReplicasOff() + + root, err := th.App.CreatePost(&model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "root post", + }, th.BasicChannel, false, true) + require.Nil(t, err) + + reply, err := th.App.CreatePost(&model.Post{ + UserId: th.BasicUser2.Id, + ChannelId: th.BasicChannel.Id, + RootId: root.Id, + ParentId: root.Id, + Message: fmt.Sprintf("@%s", th.BasicUser2.Username), + }, th.BasicChannel, false, true) + require.Nil(t, err) + require.NotNil(t, reply) + }) +} diff --git a/app/server_test.go b/app/server_test.go index 3648d89c934..096b1e489bf 100644 --- a/app/server_test.go +++ b/app/server_test.go @@ -59,7 +59,7 @@ func TestReadReplicaDisabledBasedOnLicense(t *testing.T) { } else { dsn = os.Getenv("TEST_DATABASE_MYSQL_DSN") } - cfg.SqlSettings = *storetest.MakeSqlSettings(driverName) + cfg.SqlSettings = *storetest.MakeSqlSettings(driverName, false) if dsn != "" { cfg.SqlSettings.DataSource = &dsn } diff --git a/app/session.go b/app/session.go index b9f97f28ed0..33b97b67cd6 100644 --- a/app/session.go +++ b/app/session.go @@ -17,6 +17,7 @@ import ( "github.com/mattermost/mattermost-server/v5/model" "github.com/mattermost/mattermost-server/v5/shared/mlog" "github.com/mattermost/mattermost-server/v5/store" + "github.com/mattermost/mattermost-server/v5/store/sqlstore" ) func (a *App) CreateSession(session *model.Session) (*model.Session, *model.AppError) { @@ -84,7 +85,7 @@ func (a *App) GetSession(token string) (*model.Session, *model.AppError) { if session.Id == "" { var nErr error - if session, nErr = a.Srv().Store.Session().Get(token); nErr == nil { + if session, nErr = a.Srv().Store.Session().Get(sqlstore.WithMaster(context.Background()), token); nErr == nil { if session != nil { if session.Token != token { return nil, model.NewAppError("GetSession", "api.context.invalid_token.error", map[string]interface{}{"Token": token, "Error": ""}, "session token is different from the one in DB", http.StatusUnauthorized) @@ -287,7 +288,7 @@ func (a *App) RevokeSessionsForDeviceId(userID string, deviceID string, currentS } func (a *App) GetSessionById(sessionID string) (*model.Session, *model.AppError) { - session, err := a.Srv().Store.Session().Get(sessionID) + session, err := a.Srv().Store.Session().Get(context.Background(), sessionID) if err != nil { return nil, model.NewAppError("GetSessionById", "app.session.get.app_error", nil, err.Error(), http.StatusBadRequest) } @@ -296,7 +297,7 @@ func (a *App) GetSessionById(sessionID string) (*model.Session, *model.AppError) } func (a *App) RevokeSessionById(sessionID string) *model.AppError { - session, err := a.Srv().Store.Session().Get(sessionID) + session, err := a.Srv().Store.Session().Get(context.Background(), sessionID) if err != nil { return model.NewAppError("RevokeSessionById", "app.session.get.app_error", nil, err.Error(), http.StatusBadRequest) } @@ -534,7 +535,7 @@ func (a *App) createSessionForUserAccessToken(tokenString string) (*model.Sessio func (a *App) RevokeUserAccessToken(token *model.UserAccessToken) *model.AppError { var session *model.Session - session, _ = a.Srv().Store.Session().Get(token.Token) + session, _ = a.Srv().Store.Session().Get(context.Background(), token.Token) if err := a.Srv().Store.UserAccessToken().Delete(token.Id); err != nil { return model.NewAppError("RevokeUserAccessToken", "app.user_access_token.delete.app_error", nil, err.Error(), http.StatusInternalServerError) @@ -549,7 +550,7 @@ func (a *App) RevokeUserAccessToken(token *model.UserAccessToken) *model.AppErro func (a *App) DisableUserAccessToken(token *model.UserAccessToken) *model.AppError { var session *model.Session - session, _ = a.Srv().Store.Session().Get(token.Token) + session, _ = a.Srv().Store.Session().Get(context.Background(), token.Token) if err := a.Srv().Store.UserAccessToken().UpdateTokenDisable(token.Id); err != nil { return model.NewAppError("DisableUserAccessToken", "app.user_access_token.update_token_disable.app_error", nil, err.Error(), http.StatusInternalServerError) @@ -564,7 +565,7 @@ func (a *App) DisableUserAccessToken(token *model.UserAccessToken) *model.AppErr func (a *App) EnableUserAccessToken(token *model.UserAccessToken) *model.AppError { var session *model.Session - session, _ = a.Srv().Store.Session().Get(token.Token) + session, _ = a.Srv().Store.Session().Get(context.Background(), token.Token) err := a.Srv().Store.UserAccessToken().UpdateTokenEnable(token.Id) if err != nil { diff --git a/app/session_test.go b/app/session_test.go index f5e264773ce..2cc7a3d3b5e 100644 --- a/app/session_test.go +++ b/app/session_test.go @@ -4,6 +4,7 @@ package app import ( + "context" "fmt" "os" "testing" @@ -347,7 +348,7 @@ func TestApp_ExtendExpiryIfNeeded(t *testing.T) { require.Equal(t, session.ExpiresAt, cachedSession.ExpiresAt) // check database was updated. - storedSession, nErr := th.App.Srv().Store.Session().Get(session.Token) + storedSession, nErr := th.App.Srv().Store.Session().Get(context.Background(), session.Token) require.NoError(t, nErr) require.Equal(t, session.ExpiresAt, storedSession.ExpiresAt) }) diff --git a/app/team.go b/app/team.go index 226ecb29a2d..f4942d9b3dd 100644 --- a/app/team.go +++ b/app/team.go @@ -23,6 +23,7 @@ import ( "github.com/mattermost/mattermost-server/v5/shared/i18n" "github.com/mattermost/mattermost-server/v5/shared/mlog" "github.com/mattermost/mattermost-server/v5/store" + "github.com/mattermost/mattermost-server/v5/store/sqlstore" ) func (a *App) CreateTeam(team *model.Team) (*model.Team, *model.AppError) { @@ -355,7 +356,7 @@ func (a *App) GetSchemeRolesForTeam(teamID string) (string, string, string, *mod } func (a *App) UpdateTeamMemberRoles(teamID string, userID string, newRoles string) (*model.TeamMember, *model.AppError) { - member, nErr := a.Srv().Store.Team().GetMember(teamID, userID) + member, nErr := a.Srv().Store.Team().GetMember(context.Background(), teamID, userID) if nErr != nil { var nfErr *store.ErrNotFound switch { @@ -701,7 +702,7 @@ func (a *App) joinUserToTeam(team *model.Team, user *model.User) (*model.TeamMem tm.SchemeAdmin = true } - rtm, err := a.Srv().Store.Team().GetMember(team.Id, user.Id) + rtm, err := a.Srv().Store.Team().GetMember(context.Background(), team.Id, user.Id) if err != nil { // Membership appears to be missing. Lets try to add. tmr, nErr := a.Srv().Store.Team().SaveMember(tm, *a.Config().TeamSettings.MaxUsersPerTeam) @@ -998,7 +999,7 @@ func (a *App) GetTeamsForUser(userID string) ([]*model.Team, *model.AppError) { } func (a *App) GetTeamMember(teamID, userID string) (*model.TeamMember, *model.AppError) { - teamMember, err := a.Srv().Store.Team().GetMember(teamID, userID) + teamMember, err := a.Srv().Store.Team().GetMember(sqlstore.WithMaster(context.Background()), teamID, userID) if err != nil { var nfErr *store.ErrNotFound switch { diff --git a/app/web_hub_test.go b/app/web_hub_test.go index 94f69d503dc..39816fc2d3c 100644 --- a/app/web_hub_test.go +++ b/app/web_hub_test.go @@ -173,7 +173,7 @@ func TestHubSessionRevokeRace(t *testing.T) { mockSessionStore := mocks.SessionStore{} mockSessionStore.On("UpdateLastActivityAt", "id1", mock.Anything).Return(nil) mockSessionStore.On("Save", mock.AnythingOfType("*model.Session")).Return(sess1, nil) - mockSessionStore.On("Get", "id1").Return(sess1, nil) + mockSessionStore.On("Get", mock.Anything, "id1").Return(sess1, nil) mockSessionStore.On("Remove", "id1").Return(nil) mockStatusStore := mocks.StatusStore{} diff --git a/build/docker-compose-generator/main.go b/build/docker-compose-generator/main.go index 4533d1d2d1d..d1abd940866 100644 --- a/build/docker-compose-generator/main.go +++ b/build/docker-compose-generator/main.go @@ -25,16 +25,17 @@ type Container struct { func main() { validServices := map[string]int{ - "mysql": 3306, - "postgres": 5432, - "minio": 9000, - "inbucket": 10080, - "openldap": 389, - "elasticsearch": 9200, - "dejavu": 1358, - "keycloak": 8080, - "prometheus": 9090, - "grafana": 3000, + "mysql": 3306, + "postgres": 5432, + "minio": 9000, + "inbucket": 10080, + "openldap": 389, + "elasticsearch": 9200, + "dejavu": 1358, + "keycloak": 8080, + "prometheus": 9090, + "grafana": 3000, + "mysql-read-replica": 3306, // FIXME: not recorgnizing the successfully running service on port 3307. } command := []string{} for _, arg := range os.Args[1:] { diff --git a/build/docker-compose.common.yml b/build/docker-compose.common.yml index eeb9fd365ad..f9c17432bfc 100644 --- a/build/docker-compose.common.yml +++ b/build/docker-compose.common.yml @@ -16,6 +16,28 @@ services: interval: 5s timeout: 10s retries: 3 + volumes: + - ./docker/mysql.conf.d/source.cnf:/etc/mysql/conf.d/mysql.cnf + mysql-read-replica: + image: "mysql:5.7" + restart: always + networks: + - mm-test + ports: + - 3307:3306 + environment: + MYSQL_ROOT_HOST: "%" + MYSQL_ROOT_PASSWORD: mostest + MYSQL_PASSWORD: mostest + MYSQL_USER: mmuser + MYSQL_DATABASE: mattermost_test + healthcheck: + test: ["CMD", "mysqladmin" ,"ping", "-h", "localhost"] + interval: 5s + timeout: 10s + retries: 3 + volumes: + - ./docker/mysql.conf.d/replica.cnf:/etc/mysql/conf.d/mysql.cnf postgres: image: "postgres:10" restart: always diff --git a/build/docker/mysql.conf.d/replica.cnf b/build/docker/mysql.conf.d/replica.cnf new file mode 100644 index 00000000000..a0ee6e162a7 --- /dev/null +++ b/build/docker/mysql.conf.d/replica.cnf @@ -0,0 +1,3 @@ +[mysqld] + +server-id = 2 \ No newline at end of file diff --git a/build/docker/mysql.conf.d/source.cnf b/build/docker/mysql.conf.d/source.cnf new file mode 100644 index 00000000000..bb0a1c35df2 --- /dev/null +++ b/build/docker/mysql.conf.d/source.cnf @@ -0,0 +1,4 @@ +[mysqld] + +server-id = 1 +log-bin = mysql-bin \ No newline at end of file diff --git a/docker-compose.makefile.yml b/docker-compose.makefile.yml index c66a6867c4c..8e5310efd02 100644 --- a/docker-compose.makefile.yml +++ b/docker-compose.makefile.yml @@ -8,6 +8,14 @@ services: extends: file: build/docker-compose.common.yml service: mysql + mysql-read-replica: + restart: 'no' + container_name: mattermost-mysql-read-replica + ports: + - "3307:3306" + extends: + file: build/docker-compose.common.yml + service: mysql-read-replica postgres: restart: 'no' container_name: mattermost-postgres diff --git a/docker-compose.yaml b/docker-compose.yaml index cccf264c46a..3e95605b27c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -7,6 +7,13 @@ services: extends: file: build/docker-compose.common.yml service: mysql + mysql-read-replica: + container_name: mattermost-mysql-read-replica + ports: + - "3307:3306" + extends: + file: build/docker-compose.common.yml + service: mysql-read-replica postgres: container_name: mattermost-postgres ports: diff --git a/scripts/replica-lag-set.sh b/scripts/replica-lag-set.sh new file mode 100755 index 00000000000..4c202a73a50 --- /dev/null +++ b/scripts/replica-lag-set.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +stmt="STOP SLAVE SQL_THREAD FOR CHANNEL '';CHANGE MASTER TO MASTER_DELAY = $1;START SLAVE SQL_THREAD FOR CHANNEL '';SHOW SLAVE STATUS\G;" +docker exec mattermost-mysql-read-replica sh -c "export MYSQL_PWD=mostest; mysql -u root -e \"$stmt\"" | grep SQL_Delay \ No newline at end of file diff --git a/scripts/replica-mysql-config.sh b/scripts/replica-mysql-config.sh new file mode 100755 index 00000000000..6728e44e3a1 --- /dev/null +++ b/scripts/replica-mysql-config.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +until docker exec mattermost-mysql sh -c 'mysql -u root -pmostest -e ";"' +do + echo "Waiting for mattermost-mysql database connection..." + sleep 4 +done + +priv_stmt='GRANT REPLICATION SLAVE ON *.* TO "mmuser"@"%" IDENTIFIED BY "mostest"; FLUSH PRIVILEGES;' +docker exec mattermost-mysql sh -c "mysql -u root -pmostest -e '$priv_stmt'" + +until docker-compose -f docker-compose.makefile.yml exec mysql-read-replica sh -c 'mysql -u root -pmostest -e ";"' +do + echo "Waiting for mysql-read-replica database connection..." + sleep 4 +done + +docker-ip() { + docker inspect --format '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' "$@" +} + +MS_STATUS=`docker exec mattermost-mysql sh -c 'mysql -u root -pmostest -e "SHOW MASTER STATUS"'` +CURRENT_LOG=`echo $MS_STATUS | awk '{print $6}'` +CURRENT_POS=`echo $MS_STATUS | awk '{print $7}'` + +start_slave_stmt="CHANGE MASTER TO MASTER_HOST='$(docker-ip mattermost-mysql)',MASTER_USER='mmuser',MASTER_PASSWORD='mostest',MASTER_LOG_FILE='$CURRENT_LOG',MASTER_LOG_POS=$CURRENT_POS; START SLAVE;" +start_slave_cmd='mysql -u root -pmostest -e "' +start_slave_cmd+="$start_slave_stmt" +start_slave_cmd+='"' +docker exec mattermost-mysql-read-replica sh -c "$start_slave_cmd" + +docker exec mattermost-mysql-read-replica sh -c "mysql -u root -pmostest -e 'SHOW SLAVE STATUS \G'" diff --git a/services/searchengine/bleveengine/bleve_test.go b/services/searchengine/bleveengine/bleve_test.go index bc1f90c7c76..e86ba2e40d4 100644 --- a/services/searchengine/bleveengine/bleve_test.go +++ b/services/searchengine/bleveengine/bleve_test.go @@ -49,7 +49,7 @@ func (s *BleveEngineTestSuite) setupStore() { if driverName == "" { driverName = model.DATABASE_DRIVER_POSTGRES } - s.SQLSettings = storetest.MakeSqlSettings(driverName) + s.SQLSettings = storetest.MakeSqlSettings(driverName, false) s.SQLStore = sqlstore.New(*s.SQLSettings, nil) cfg := &model.Config{} diff --git a/store/localcachelayer/layer_test.go b/store/localcachelayer/layer_test.go index 6123e5f740e..7280b39c3a2 100644 --- a/store/localcachelayer/layer_test.go +++ b/store/localcachelayer/layer_test.go @@ -26,7 +26,7 @@ var storeTypes []*storeType func newStoreType(name, driver string) *storeType { return &storeType{ Name: name, - SqlSettings: storetest.MakeSqlSettings(driver), + SqlSettings: storetest.MakeSqlSettings(driver, false), } } diff --git a/store/opentracinglayer/opentracinglayer.go b/store/opentracinglayer/opentracinglayer.go index 4b74fa3e052..c74df984e6b 100644 --- a/store/opentracinglayer/opentracinglayer.go +++ b/store/opentracinglayer/opentracinglayer.go @@ -4934,7 +4934,7 @@ func (s *OpenTracingLayerPostStore) Delete(postID string, time int64, deleteByID return err } -func (s *OpenTracingLayerPostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { +func (s *OpenTracingLayerPostStore) Get(ctx context.Context, id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.Get") s.Root.Store.SetContext(newCtx) @@ -4943,7 +4943,7 @@ func (s *OpenTracingLayerPostStore) Get(id string, skipFetchThreads bool, collap }() defer span.Finish() - result, err := s.PostStore.Get(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) + result, err := s.PostStore.Get(ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -6218,7 +6218,7 @@ func (s *OpenTracingLayerSessionStore) Cleanup(expiryTime int64, batchSize int64 } -func (s *OpenTracingLayerSessionStore) Get(sessionIDOrToken string) (*model.Session, error) { +func (s *OpenTracingLayerSessionStore) Get(ctx context.Context, sessionIDOrToken string) (*model.Session, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.Get") s.Root.Store.SetContext(newCtx) @@ -6227,7 +6227,7 @@ func (s *OpenTracingLayerSessionStore) Get(sessionIDOrToken string) (*model.Sess }() defer span.Finish() - result, err := s.SessionStore.Get(sessionIDOrToken) + result, err := s.SessionStore.Get(ctx, sessionIDOrToken) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -7095,7 +7095,7 @@ func (s *OpenTracingLayerTeamStore) GetChannelUnreadsForTeam(teamID string, user return result, err } -func (s *OpenTracingLayerTeamStore) GetMember(teamID string, userId string) (*model.TeamMember, error) { +func (s *OpenTracingLayerTeamStore) GetMember(ctx context.Context, teamID string, userId string) (*model.TeamMember, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "TeamStore.GetMember") s.Root.Store.SetContext(newCtx) @@ -7104,7 +7104,7 @@ func (s *OpenTracingLayerTeamStore) GetMember(teamID string, userId string) (*mo }() defer span.Finish() - result, err := s.TeamStore.GetMember(teamID, userId) + result, err := s.TeamStore.GetMember(ctx, teamID, userId) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) diff --git a/store/retrylayer/retrylayer.go b/store/retrylayer/retrylayer.go index 405f07b39f0..b4d9ab95c0d 100644 --- a/store/retrylayer/retrylayer.go +++ b/store/retrylayer/retrylayer.go @@ -5342,11 +5342,11 @@ func (s *RetryLayerPostStore) Delete(postID string, time int64, deleteByID strin } -func (s *RetryLayerPostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { +func (s *RetryLayerPostStore) Get(ctx context.Context, id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { tries := 0 for { - result, err := s.PostStore.Get(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) + result, err := s.PostStore.Get(ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err == nil { return result, nil } @@ -6726,11 +6726,11 @@ func (s *RetryLayerSessionStore) Cleanup(expiryTime int64, batchSize int64) { } -func (s *RetryLayerSessionStore) Get(sessionIDOrToken string) (*model.Session, error) { +func (s *RetryLayerSessionStore) Get(ctx context.Context, sessionIDOrToken string) (*model.Session, error) { tries := 0 for { - result, err := s.SessionStore.Get(sessionIDOrToken) + result, err := s.SessionStore.Get(ctx, sessionIDOrToken) if err == nil { return result, nil } @@ -7692,11 +7692,11 @@ func (s *RetryLayerTeamStore) GetChannelUnreadsForTeam(teamID string, userId str } -func (s *RetryLayerTeamStore) GetMember(teamID string, userId string) (*model.TeamMember, error) { +func (s *RetryLayerTeamStore) GetMember(ctx context.Context, teamID string, userId string) (*model.TeamMember, error) { tries := 0 for { - result, err := s.TeamStore.GetMember(teamID, userId) + result, err := s.TeamStore.GetMember(ctx, teamID, userId) if err == nil { return result, nil } diff --git a/store/searchlayer/layer_test.go b/store/searchlayer/layer_test.go index 1a47b953b6c..adaa4579dc0 100644 --- a/store/searchlayer/layer_test.go +++ b/store/searchlayer/layer_test.go @@ -23,7 +23,7 @@ func TestUpdateConfigRace(t *testing.T) { if driverName == "" { driverName = model.DATABASE_DRIVER_POSTGRES } - settings := storetest.MakeSqlSettings(driverName) + settings := storetest.MakeSqlSettings(driverName, false) store := sqlstore.New(*settings, nil) cfg := &model.Config{} diff --git a/store/searchlayer/post_layer.go b/store/searchlayer/post_layer.go index 6f4d4dea759..bd3ed366bfb 100644 --- a/store/searchlayer/post_layer.go +++ b/store/searchlayer/post_layer.go @@ -4,6 +4,8 @@ package searchlayer import ( + "context" + "github.com/pkg/errors" "github.com/mattermost/mattermost-server/v5/model" @@ -108,7 +110,7 @@ func (s SearchPostStore) Delete(postId string, date int64, deletedByID string) e err := s.PostStore.Delete(postId, date, deletedByID) if err == nil { - postList, err2 := s.PostStore.Get(postId, true, false, false) + postList, err2 := s.PostStore.Get(context.Background(), postId, true, false, false) if postList != nil && len(postList.Order) > 0 { if err2 != nil { s.deletePostIndex(postList.Posts[postList.Order[0]]) diff --git a/store/sqlstore/group_store.go b/store/sqlstore/group_store.go index 3c28156199c..8e6d5910ef5 100644 --- a/store/sqlstore/group_store.go +++ b/store/sqlstore/group_store.go @@ -734,7 +734,7 @@ func (s *SqlGroupStore) TeamMembersToAdd(since int64, teamID *string) ([]*model. var teamMembers []*model.UserTeamIDPair - _, err = s.GetReplica().Select(&teamMembers, query, params...) + _, err = s.GetMaster().Select(&teamMembers, query, params...) if err != nil { return nil, errors.Wrap(err, "failed to find UserTeamIDPairs") } @@ -771,7 +771,7 @@ func (s *SqlGroupStore) ChannelMembersToAdd(since int64, channelID *string) ([]* var channelMembers []*model.UserChannelIDPair - _, err = s.GetReplica().Select(&channelMembers, query, params...) + _, err = s.GetMaster().Select(&channelMembers, query, params...) if err != nil { return nil, errors.Wrap(err, "failed to find UserChannelIDPairs") } @@ -1447,7 +1447,7 @@ func (s *SqlGroupStore) PermittedSyncableAdmins(syncableID string, syncableType } var userIDs []string - if _, err = s.GetReplica().Select(&userIDs, query, args...); err != nil { + if _, err = s.GetMaster().Select(&userIDs, query, args...); err != nil { return nil, errors.Wrapf(err, "failed to find User ids") } diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index a065c0ce030..4bd1a602e16 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -455,7 +455,7 @@ func (s *SqlPostStore) getPostWithCollapsedThreads(id string, extended bool) (*m return s.prepareThreadedResponse([]*postWithExtra{&post}, extended, false) } -func (s *SqlPostStore) Get(id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, error) { +func (s *SqlPostStore) Get(ctx context.Context, id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, error) { if collapsedThreads { return s.getPostWithCollapsedThreads(id, collapsedThreadsExtended) } @@ -467,7 +467,7 @@ func (s *SqlPostStore) Get(id string, skipFetchThreads, collapsedThreads, collap var post model.Post postFetchQuery := "SELECT p.*, (SELECT count(Posts.Id) FROM Posts WHERE Posts.RootId = (CASE WHEN p.RootId = '' THEN p.Id ELSE p.RootId END) AND Posts.DeleteAt = 0) as ReplyCount FROM Posts p WHERE p.Id = :Id AND p.DeleteAt = 0" - err := s.GetReplica().SelectOne(&post, postFetchQuery, map[string]interface{}{"Id": id}) + err := s.DBFromContext(ctx).SelectOne(&post, postFetchQuery, map[string]interface{}{"Id": id}) if err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Post", id) diff --git a/store/sqlstore/session_store.go b/store/sqlstore/session_store.go index 6d01d910674..ec96fcb168f 100644 --- a/store/sqlstore/session_store.go +++ b/store/sqlstore/session_store.go @@ -73,10 +73,10 @@ func (me SqlSessionStore) Save(session *model.Session) (*model.Session, error) { return session, nil } -func (me SqlSessionStore) Get(sessionIdOrToken string) (*model.Session, error) { +func (me SqlSessionStore) Get(ctx context.Context, sessionIdOrToken string) (*model.Session, error) { var sessions []*model.Session - if _, err := me.GetReplica().Select(&sessions, "SELECT * FROM Sessions WHERE Token = :Token OR Id = :Id LIMIT 1", map[string]interface{}{"Token": sessionIdOrToken, "Id": sessionIdOrToken}); err != nil { + if _, err := me.DBFromContext(ctx).Select(&sessions, "SELECT * FROM Sessions WHERE Token = :Token OR Id = :Id LIMIT 1", map[string]interface{}{"Token": sessionIdOrToken, "Id": sessionIdOrToken}); err != nil { return nil, errors.Wrapf(err, "failed to find Sessions with sessionIdOrToken=%s", sessionIdOrToken) } else if len(sessions) == 0 { return nil, store.NewErrNotFound("Session", fmt.Sprintf("sessionIdOrToken=%s", sessionIdOrToken)) @@ -249,7 +249,7 @@ func (me SqlSessionStore) UpdateDeviceId(id string, deviceId string, expiresAt i } func (me SqlSessionStore) UpdateProps(session *model.Session) error { - oldSession, err := me.Get(session.Id) + oldSession, err := me.Get(context.Background(), session.Id) if err != nil { return err } diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index 1566a8880e4..5f347cb6c14 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -132,7 +132,7 @@ type SqlStore struct { rrCounter int64 srCounter int64 master *gorp.DbMap - replicas []*gorp.DbMap + Replicas []*gorp.DbMap searchReplicas []*gorp.DbMap stores SqlStoreStores settings *model.SqlSettings @@ -318,9 +318,9 @@ func (ss *SqlStore) initConnection() { ss.master = setupConnection("master", *ss.settings.DataSource, ss.settings) if len(ss.settings.DataSourceReplicas) > 0 { - ss.replicas = make([]*gorp.DbMap, len(ss.settings.DataSourceReplicas)) + ss.Replicas = make([]*gorp.DbMap, len(ss.settings.DataSourceReplicas)) for i, replica := range ss.settings.DataSourceReplicas { - ss.replicas[i] = setupConnection(fmt.Sprintf("replica-%v", i), replica, ss.settings) + ss.Replicas[i] = setupConnection(fmt.Sprintf("replica-%v", i), replica, ss.settings) } } @@ -395,8 +395,8 @@ func (ss *SqlStore) GetReplica() *gorp.DbMap { return ss.GetMaster() } - rrNum := atomic.AddInt64(&ss.rrCounter, 1) % int64(len(ss.replicas)) - return ss.replicas[rrNum] + rrNum := atomic.AddInt64(&ss.rrCounter, 1) % int64(len(ss.Replicas)) + return ss.Replicas[rrNum] } func (ss *SqlStore) TotalMasterDbConnections() int { @@ -409,7 +409,7 @@ func (ss *SqlStore) TotalReadDbConnections() int { } count := 0 - for _, db := range ss.replicas { + for _, db := range ss.Replicas { count = count + db.Db.Stats().OpenConnections } @@ -1008,9 +1008,9 @@ func IsUniqueConstraintError(err error, indexName []string) bool { } func (ss *SqlStore) GetAllConns() []*gorp.DbMap { - all := make([]*gorp.DbMap, len(ss.replicas)+1) - copy(all, ss.replicas) - all[len(ss.replicas)] = ss.master + all := make([]*gorp.DbMap, len(ss.Replicas)+1) + copy(all, ss.Replicas) + all[len(ss.Replicas)] = ss.master return all } @@ -1033,7 +1033,7 @@ func (ss *SqlStore) RecycleDBConnections(d time.Duration) { func (ss *SqlStore) Close() { ss.master.Db.Close() - for _, replica := range ss.replicas { + for _, replica := range ss.Replicas { replica.Db.Close() } @@ -1210,6 +1210,10 @@ func (ss *SqlStore) UpdateLicense(license *model.License) { ss.license = license } +func (ss *SqlStore) GetLicense() *model.License { + return ss.license +} + func (ss *SqlStore) migrate(direction migrationDirection) error { var driver database.Driver var err error diff --git a/store/sqlstore/store_test.go b/store/sqlstore/store_test.go index 37e7b5bb111..cf697f51e27 100644 --- a/store/sqlstore/store_test.go +++ b/store/sqlstore/store_test.go @@ -36,7 +36,7 @@ var storeTypes []*storeType func newStoreType(name, driver string) *storeType { return &storeType{ Name: name, - SqlSettings: storetest.MakeSqlSettings(driver), + SqlSettings: storetest.MakeSqlSettings(driver, false), } } @@ -560,9 +560,9 @@ func TestVersionString(t *testing.T) { func makeSqlSettings(driver string) *model.SqlSettings { switch driver { case model.DATABASE_DRIVER_POSTGRES: - return storetest.MakeSqlSettings(driver) + return storetest.MakeSqlSettings(driver, false) case model.DATABASE_DRIVER_MYSQL: - return storetest.MakeSqlSettings(driver) + return storetest.MakeSqlSettings(driver, false) } return nil diff --git a/store/sqlstore/team_store.go b/store/sqlstore/team_store.go index fa2a433da26..ed5bbc95b8a 100644 --- a/store/sqlstore/team_store.go +++ b/store/sqlstore/team_store.go @@ -988,7 +988,7 @@ func (s SqlTeamStore) UpdateMember(member *model.TeamMember) (*model.TeamMember, } // GetMember returns a single member of the team that matches the teamId and userId provided as parameters. -func (s SqlTeamStore) GetMember(teamId string, userId string) (*model.TeamMember, error) { +func (s SqlTeamStore) GetMember(ctx context.Context, teamId string, userId string) (*model.TeamMember, error) { query := s.getTeamMembersWithSchemeSelectQuery(). Where(sq.Eq{"TeamMembers.TeamId": teamId}). Where(sq.Eq{"TeamMembers.UserId": userId}) @@ -999,7 +999,7 @@ func (s SqlTeamStore) GetMember(teamId string, userId string) (*model.TeamMember } var dbMember teamMemberWithSchemeRoles - err = s.GetReplica().SelectOne(&dbMember, queryString, args...) + err = s.DBFromContext(ctx).SelectOne(&dbMember, queryString, args...) if err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("TeamMember", fmt.Sprintf("teamId=%s, userId=%s", teamId, userId)) diff --git a/store/store.go b/store/store.go index 7df0890659d..c25d4c120fd 100644 --- a/store/store.go +++ b/store/store.go @@ -97,7 +97,7 @@ type TeamStore interface { SaveMember(member *model.TeamMember, maxUsersPerTeam int) (*model.TeamMember, error) UpdateMember(member *model.TeamMember) (*model.TeamMember, error) UpdateMultipleMembers(members []*model.TeamMember) ([]*model.TeamMember, error) - GetMember(teamID string, userId string) (*model.TeamMember, error) + GetMember(ctx context.Context, teamID string, userId string) (*model.TeamMember, error) GetMembers(teamID string, offset int, limit int, teamMembersGetOptions *model.TeamMembersGetOptions) ([]*model.TeamMember, error) GetMembersByIds(teamID string, userIds []string, restrictions *model.ViewUsersRestrictions) ([]*model.TeamMember, error) GetTotalMemberCount(teamID string, restrictions *model.ViewUsersRestrictions) (int64, error) @@ -274,7 +274,7 @@ type PostStore interface { SaveMultiple(posts []*model.Post) ([]*model.Post, int, error) Save(post *model.Post) (*model.Post, error) Update(newPost *model.Post, oldPost *model.Post) (*model.Post, error) - Get(id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, error) + Get(ctx context.Context, id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, error) GetSingle(id string) (*model.Post, error) Delete(postID string, time int64, deleteByID string) error PermanentDeleteByUser(userId string) error @@ -393,7 +393,7 @@ type BotStore interface { } type SessionStore interface { - Get(sessionIDOrToken string) (*model.Session, error) + Get(ctx context.Context, sessionIDOrToken string) (*model.Session, error) Save(session *model.Session) (*model.Session, error) GetSessions(userId string) ([]*model.Session, error) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index d640e8abe27..d4e22a9b6fd 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -5,6 +5,8 @@ package mocks import ( + context "context" + model "github.com/mattermost/mattermost-server/v5/model" mock "github.com/stretchr/testify/mock" ) @@ -100,13 +102,13 @@ func (_m *PostStore) Delete(postID string, time int64, deleteByID string) error return r0 } -// Get provides a mock function with given fields: id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended -func (_m *PostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { - ret := _m.Called(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) +// Get provides a mock function with given fields: ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended +func (_m *PostStore) Get(ctx context.Context, id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { + ret := _m.Called(ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) var r0 *model.PostList - if rf, ok := ret.Get(0).(func(string, bool, bool, bool) *model.PostList); ok { - r0 = rf(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) + if rf, ok := ret.Get(0).(func(context.Context, string, bool, bool, bool) *model.PostList); ok { + r0 = rf(ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.PostList) @@ -114,8 +116,8 @@ func (_m *PostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool } var r1 error - if rf, ok := ret.Get(1).(func(string, bool, bool, bool) error); ok { - r1 = rf(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) + if rf, ok := ret.Get(1).(func(context.Context, string, bool, bool, bool) error); ok { + r1 = rf(ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) } else { r1 = ret.Error(1) } diff --git a/store/storetest/mocks/SessionStore.go b/store/storetest/mocks/SessionStore.go index 14178228fcf..2113886f511 100644 --- a/store/storetest/mocks/SessionStore.go +++ b/store/storetest/mocks/SessionStore.go @@ -5,6 +5,8 @@ package mocks import ( + context "context" + model "github.com/mattermost/mattermost-server/v5/model" mock "github.com/stretchr/testify/mock" ) @@ -40,13 +42,13 @@ func (_m *SessionStore) Cleanup(expiryTime int64, batchSize int64) { _m.Called(expiryTime, batchSize) } -// Get provides a mock function with given fields: sessionIDOrToken -func (_m *SessionStore) Get(sessionIDOrToken string) (*model.Session, error) { - ret := _m.Called(sessionIDOrToken) +// Get provides a mock function with given fields: ctx, sessionIDOrToken +func (_m *SessionStore) Get(ctx context.Context, sessionIDOrToken string) (*model.Session, error) { + ret := _m.Called(ctx, sessionIDOrToken) var r0 *model.Session - if rf, ok := ret.Get(0).(func(string) *model.Session); ok { - r0 = rf(sessionIDOrToken) + if rf, ok := ret.Get(0).(func(context.Context, string) *model.Session); ok { + r0 = rf(ctx, sessionIDOrToken) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Session) @@ -54,8 +56,8 @@ func (_m *SessionStore) Get(sessionIDOrToken string) (*model.Session, error) { } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(sessionIDOrToken) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, sessionIDOrToken) } else { r1 = ret.Error(1) } diff --git a/store/storetest/mocks/TeamStore.go b/store/storetest/mocks/TeamStore.go index 4fd343db402..ac036aef26a 100644 --- a/store/storetest/mocks/TeamStore.go +++ b/store/storetest/mocks/TeamStore.go @@ -462,13 +462,13 @@ func (_m *TeamStore) GetChannelUnreadsForTeam(teamID string, userId string) ([]* return r0, r1 } -// GetMember provides a mock function with given fields: teamID, userId -func (_m *TeamStore) GetMember(teamID string, userId string) (*model.TeamMember, error) { - ret := _m.Called(teamID, userId) +// GetMember provides a mock function with given fields: ctx, teamID, userId +func (_m *TeamStore) GetMember(ctx context.Context, teamID string, userId string) (*model.TeamMember, error) { + ret := _m.Called(ctx, teamID, userId) var r0 *model.TeamMember - if rf, ok := ret.Get(0).(func(string, string) *model.TeamMember); ok { - r0 = rf(teamID, userId) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *model.TeamMember); ok { + r0 = rf(ctx, teamID, userId) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.TeamMember) @@ -476,8 +476,8 @@ func (_m *TeamStore) GetMember(teamID string, userId string) (*model.TeamMember, } var r1 error - if rf, ok := ret.Get(1).(func(string, string) error); ok { - r1 = rf(teamID, userId) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, teamID, userId) } else { r1 = ret.Error(1) } diff --git a/store/storetest/oauth_store.go b/store/storetest/oauth_store.go index a7f442650a8..b28ef202696 100644 --- a/store/storetest/oauth_store.go +++ b/store/storetest/oauth_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -384,7 +385,7 @@ func testOAuthStoreDeleteApp(t *testing.T, ss store.Store) { err = ss.OAuth().DeleteApp(a1.Id) require.NoError(t, err) - _, nErr = ss.Session().Get(s1.Token) + _, nErr = ss.Session().Get(context.Background(), s1.Token) require.Error(t, nErr, "should error - session should be deleted") _, err = ss.OAuth().GetAccessData(s1.Token) diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index b7fec35a4f7..341d18cfd55 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "fmt" "sort" "strings" @@ -416,14 +417,14 @@ func testPostStoreGet(t *testing.T, ss store.Store) { etag2 := ss.Post().GetEtag(o1.ChannelId, false, false) require.Equal(t, 0, strings.Index(etag2, fmt.Sprintf("%v.%v", model.CurrentVersion, o1.UpdateAt)), "Invalid Etag") - r1, err := ss.Post().Get(o1.Id, false, false, false) + r1, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) require.Equal(t, r1.Posts[o1.Id].CreateAt, o1.CreateAt, "invalid returned post") - _, err = ss.Post().Get("123", false, false, false) + _, err = ss.Post().Get(context.Background(), "123", false, false, false) require.Error(t, err, "Missing id should have failed") - _, err = ss.Post().Get("", false, false, false) + _, err = ss.Post().Get(context.Background(), "", false, false, false) require.Error(t, err, "should fail for blank post ids") } @@ -468,15 +469,15 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { o3, err = ss.Post().Save(o3) require.NoError(t, err) - r1, err := ss.Post().Get(o1.Id, false, false, false) + r1, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o1.Id, false, false, false) + r2, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false, false, false) + r3, err := ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err) ro3 := r3.Posts[o3.Id] @@ -487,7 +488,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o1a, ro1) require.NoError(t, err) - r1, err = ss.Post().Get(o1.Id, false, false, false) + r1, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro1a := r1.Posts[o1.Id] @@ -498,7 +499,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o2a, ro2) require.NoError(t, err) - r2, err = ss.Post().Get(o1.Id, false, false, false) + r2, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro2a := r2.Posts[o2.Id] @@ -509,7 +510,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o3a, ro3) require.NoError(t, err) - r3, err = ss.Post().Get(o3.Id, false, false, false) + r3, err = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err) ro3a := r3.Posts[o3.Id] @@ -525,7 +526,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { }) require.NoError(t, err) - r4, err := ss.Post().Get(o4.Id, false, false, false) + r4, err := ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, err) ro4 := r4.Posts[o4.Id] @@ -535,7 +536,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o4a, ro4) require.NoError(t, err) - r4, err = ss.Post().Get(o4.Id, false, false, false) + r4, err = ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, err) ro4a := r4.Posts[o4.Id] @@ -556,7 +557,7 @@ func testPostStoreDelete(t *testing.T, ss store.Store) { o1, err := ss.Post().Save(o1) require.NoError(t, err) - r1, err := ss.Post().Get(o1.Id, false, false, false) + r1, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) require.Equal(t, r1.Posts[o1.Id].CreateAt, o1.CreateAt, "invalid returned post") @@ -569,7 +570,7 @@ func testPostStoreDelete(t *testing.T, ss store.Store) { assert.Equal(t, deleteByID, actual, "Expected (*Post).Props[model.POST_PROPS_DELETE_BY] to be %v but got %v.", deleteByID, actual) - r3, err := ss.Post().Get(o1.Id, false, false, false) + r3, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.Error(t, err, "Missing id should have failed - PostList %v", r3) etag2 := ss.Post().GetEtag(o1.ChannelId, false, false) @@ -596,10 +597,10 @@ func testPostStoreDelete1Level(t *testing.T, ss store.Store) { err = ss.Post().Delete(o1.Id, model.GetMillis(), "") require.NoError(t, err) - _, err = ss.Post().Get(o1.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o2.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o2.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") } @@ -639,16 +640,16 @@ func testPostStoreDelete2Level(t *testing.T, ss store.Store) { err = ss.Post().Delete(o1.Id, model.GetMillis(), "") require.NoError(t, err) - _, err = ss.Post().Get(o1.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o2.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o2.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o3.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o4.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, err) } @@ -679,16 +680,16 @@ func testPostStorePermDelete1Level(t *testing.T, ss store.Store) { err2 := ss.Post().PermanentDeleteByUser(o2.UserId) require.NoError(t, err2) - _, err = ss.Post().Get(o1.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err, "Deleted id shouldn't have failed") - _, err = ss.Post().Get(o2.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o2.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") err = ss.Post().PermanentDeleteByChannel(o3.ChannelId) require.NoError(t, err) - _, err = ss.Post().Get(o3.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") } @@ -719,13 +720,13 @@ func testPostStorePermDelete1Level2(t *testing.T, ss store.Store) { err2 := ss.Post().PermanentDeleteByUser(o1.UserId) require.NoError(t, err2) - _, err = ss.Post().Get(o1.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o2.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o2.Id, false, false, false) require.Error(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o3.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err, "Deleted id should have failed") } @@ -755,7 +756,7 @@ func testPostStoreGetWithChildren(t *testing.T, ss store.Store) { o3, err = ss.Post().Save(o3) require.NoError(t, err) - pl, err := ss.Post().Get(o1.Id, false, false, false) + pl, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) require.Len(t, pl.Posts, 3, "invalid returned post") @@ -763,7 +764,7 @@ func testPostStoreGetWithChildren(t *testing.T, ss store.Store) { dErr := ss.Post().Delete(o3.Id, model.GetMillis(), "") require.NoError(t, dErr) - pl, err = ss.Post().Get(o1.Id, false, false, false) + pl, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) require.Len(t, pl.Posts, 2, "invalid returned post") @@ -771,7 +772,7 @@ func testPostStoreGetWithChildren(t *testing.T, ss store.Store) { dErr = ss.Post().Delete(o2.Id, model.GetMillis(), "") require.NoError(t, dErr) - pl, err = ss.Post().Get(o1.Id, false, false, false) + pl, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) require.Len(t, pl.Posts, 1, "invalid returned post") @@ -2241,23 +2242,23 @@ func testPostStoreOverwriteMultiple(t *testing.T, ss store.Store) { }) require.NoError(t, err) - r1, err := ss.Post().Get(o1.Id, false, false, false) + r1, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o2.Id, false, false, false) + r2, err := ss.Post().Get(context.Background(), o2.Id, false, false, false) require.NoError(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false, false, false) + r3, err := ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err) ro3 := r3.Posts[o3.Id] - r4, err := ss.Post().Get(o4.Id, false, false, false) + r4, err := ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, err) ro4 := r4.Posts[o4.Id] - r5, err := ss.Post().Get(o5.Id, false, false, false) + r5, err := ss.Post().Get(context.Background(), o5.Id, false, false, false) require.NoError(t, err) ro5 := r5.Posts[o5.Id] @@ -2283,15 +2284,15 @@ func testPostStoreOverwriteMultiple(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, -1, errIdx) - r1, nErr := ss.Post().Get(o1.Id, false, false, false) + r1, nErr := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, nErr) ro1a := r1.Posts[o1.Id] - r2, nErr = ss.Post().Get(o1.Id, false, false, false) + r2, nErr = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, nErr) ro2a := r2.Posts[o2.Id] - r3, nErr = ss.Post().Get(o3.Id, false, false, false) + r3, nErr = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, nErr) ro3a := r3.Posts[o3.Id] @@ -2313,11 +2314,11 @@ func testPostStoreOverwriteMultiple(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, -1, errIdx) - r4, nErr := ss.Post().Get(o4.Id, false, false, false) + r4, nErr := ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, nErr) ro4a := r4.Posts[o4.Id] - r5, nErr = ss.Post().Get(o5.Id, false, false, false) + r5, nErr = ss.Post().Get(context.Background(), o5.Id, false, false, false) require.NoError(t, nErr) ro5a := r5.Posts[o5.Id] @@ -2360,19 +2361,19 @@ func testPostStoreOverwrite(t *testing.T, ss store.Store) { }) require.NoError(t, err) - r1, err := ss.Post().Get(o1.Id, false, false, false) + r1, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o2.Id, false, false, false) + r2, err := ss.Post().Get(context.Background(), o2.Id, false, false, false) require.NoError(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false, false, false) + r3, err := ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err) ro3 := r3.Posts[o3.Id] - r4, err := ss.Post().Get(o4.Id, false, false, false) + r4, err := ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, err) ro4 := r4.Posts[o4.Id] @@ -2397,15 +2398,15 @@ func testPostStoreOverwrite(t *testing.T, ss store.Store) { _, err = ss.Post().Overwrite(o3a) require.NoError(t, err) - r1, err = ss.Post().Get(o1.Id, false, false, false) + r1, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro1a := r1.Posts[o1.Id] - r2, err = ss.Post().Get(o1.Id, false, false, false) + r2, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro2a := r2.Posts[o2.Id] - r3, err = ss.Post().Get(o3.Id, false, false, false) + r3, err = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err) ro3a := r3.Posts[o3.Id] @@ -2421,7 +2422,7 @@ func testPostStoreOverwrite(t *testing.T, ss store.Store) { _, err = ss.Post().Overwrite(o4a) require.NoError(t, err) - r4, err = ss.Post().Get(o4.Id, false, false, false) + r4, err = ss.Post().Get(context.Background(), o4.Id, false, false, false) require.NoError(t, err) ro4a := r4.Posts[o4.Id] @@ -2452,15 +2453,15 @@ func testPostStoreGetPostsByIds(t *testing.T, ss store.Store) { o3, err = ss.Post().Save(o3) require.NoError(t, err) - r1, err := ss.Post().Get(o1.Id, false, false, false) + r1, err := ss.Post().Get(context.Background(), o1.Id, false, false, false) require.NoError(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o2.Id, false, false, false) + r2, err := ss.Post().Get(context.Background(), o2.Id, false, false, false) require.NoError(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false, false, false) + r3, err := ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err) ro3 := r3.Posts[o3.Id] @@ -2567,13 +2568,13 @@ func testPostStorePermanentDeleteBatch(t *testing.T, ss store.Store) { _, err = ss.Post().PermanentDeleteBatch(2000, 1000) require.NoError(t, err) - _, err = ss.Post().Get(o1.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o1.Id, false, false, false) require.Error(t, err, "Should have not found post 1 after purge") - _, err = ss.Post().Get(o2.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o2.Id, false, false, false) require.Error(t, err, "Should have not found post 2 after purge") - _, err = ss.Post().Get(o3.Id, false, false, false) + _, err = ss.Post().Get(context.Background(), o3.Id, false, false, false) require.NoError(t, err, "Should have not found post 3 after purge") } diff --git a/store/storetest/reaction_store.go b/store/storetest/reaction_store.go index 65b9e14de89..9c12ea9d3bc 100644 --- a/store/storetest/reaction_store.go +++ b/store/storetest/reaction_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "sync" "testing" "time" @@ -52,7 +53,7 @@ func testReactionSave(t *testing.T, ss store.Store) { assert.Zero(t, saved.DeleteAt, "should've saved reaction delete_at with zero value and returned it") var secondUpdateAt int64 - postList, err := ss.Post().Get(reaction1.PostId, false, false, false) + postList, err := ss.Post().Get(context.Background(), reaction1.PostId, false, false, false) require.NoError(t, err) assert.True(t, postList.Posts[post.Id].HasReactions, "should've set HasReactions = true on post") @@ -76,7 +77,7 @@ func testReactionSave(t *testing.T, ss store.Store) { _, nErr = ss.Reaction().Save(reaction2) require.NoError(t, nErr) - postList, err = ss.Post().Get(reaction2.PostId, false, false, false) + postList, err = ss.Post().Get(context.Background(), reaction2.PostId, false, false, false) require.NoError(t, err) assert.NotEqual(t, postList.Posts[post.Id].UpdateAt, secondUpdateAt, "should've marked post as updated even if HasReactions doesn't change") @@ -126,7 +127,7 @@ func testReactionDelete(t *testing.T, ss store.Store) { _, nErr := ss.Reaction().Save(reaction) require.NoError(t, nErr) - result, err := ss.Post().Get(reaction.PostId, false, false, false) + result, err := ss.Post().Get(context.Background(), reaction.PostId, false, false, false) require.NoError(t, err) firstUpdateAt := result.Posts[post.Id].UpdateAt @@ -139,7 +140,7 @@ func testReactionDelete(t *testing.T, ss store.Store) { assert.Empty(t, reactions, "should've deleted reaction") - postList, err := ss.Post().Get(post.Id, false, false, false) + postList, err := ss.Post().Get(context.Background(), post.Id, false, false, false) require.NoError(t, err) assert.False(t, postList.Posts[post.Id].HasReactions, "should've set HasReactions = false on post") @@ -362,15 +363,15 @@ func testReactionDeleteAllWithEmojiName(t *testing.T, ss store.Store, s SqlStore assert.Empty(t, returned, "should've only removed reactions with emoji name") // check that the posts are updated - postList, err := ss.Post().Get(post.Id, false, false, false) + postList, err := ss.Post().Get(context.Background(), post.Id, false, false, false) require.NoError(t, err) assert.True(t, postList.Posts[post.Id].HasReactions, "post should still have reactions") - postList, err = ss.Post().Get(post2.Id, false, false, false) + postList, err = ss.Post().Get(context.Background(), post2.Id, false, false, false) require.NoError(t, err) assert.True(t, postList.Posts[post2.Id].HasReactions, "post should still have reactions") - postList, err = ss.Post().Get(post3.Id, false, false, false) + postList, err = ss.Post().Get(context.Background(), post3.Id, false, false, false) require.NoError(t, err) assert.False(t, postList.Posts[post3.Id].HasReactions, "post shouldn't have reactions any more") diff --git a/store/storetest/session_store.go b/store/storetest/session_store.go index bca6f5f07bf..62ab216b46e 100644 --- a/store/storetest/session_store.go +++ b/store/storetest/session_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -65,7 +66,7 @@ func testSessionGet(t *testing.T, ss store.Store) { s3, err = ss.Session().Save(s3) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.Equal(t, session.Id, s1.Id, "should match") @@ -110,14 +111,14 @@ func testSessionRemove(t *testing.T, ss store.Store) { s1, err := ss.Session().Save(s1) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.Equal(t, session.Id, s1.Id, "should match") removeErr := ss.Session().Remove(s1.Id) require.NoError(t, removeErr) - _, err = ss.Session().Get(s1.Id) + _, err = ss.Session().Get(context.Background(), s1.Id) require.Error(t, err, "should have been removed") } @@ -128,14 +129,14 @@ func testSessionRemoveAll(t *testing.T, ss store.Store) { s1, err := ss.Session().Save(s1) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.Equal(t, session.Id, s1.Id, "should match") removeErr := ss.Session().RemoveAllSessions() require.NoError(t, removeErr) - _, err = ss.Session().Get(s1.Id) + _, err = ss.Session().Get(context.Background(), s1.Id) require.Error(t, err, "should have been removed") } @@ -146,14 +147,14 @@ func testSessionRemoveByUser(t *testing.T, ss store.Store) { s1, err := ss.Session().Save(s1) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.Equal(t, session.Id, s1.Id, "should match") deleteErr := ss.Session().PermanentDeleteSessionsByUser(s1.UserId) require.NoError(t, deleteErr) - _, err = ss.Session().Get(s1.Id) + _, err = ss.Session().Get(context.Background(), s1.Id) require.Error(t, err, "should have been removed") } @@ -164,14 +165,14 @@ func testSessionRemoveToken(t *testing.T, ss store.Store) { s1, err := ss.Session().Save(s1) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.Equal(t, session.Id, s1.Id, "should match") removeErr := ss.Session().Remove(s1.Token) require.NoError(t, removeErr) - _, err = ss.Session().Get(s1.Id) + _, err = ss.Session().Get(context.Background(), s1.Id) require.Error(t, err, "should have been removed") data, err := ss.Session().GetSessions(s1.UserId) @@ -229,7 +230,7 @@ func testSessionStoreUpdateExpiresAt(t *testing.T, ss store.Store) { err = ss.Session().UpdateExpiresAt(s1.Id, 1234567890) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.EqualValues(t, session.ExpiresAt, 1234567890, "ExpiresAt not updated correctly") } @@ -244,7 +245,7 @@ func testSessionStoreUpdateLastActivityAt(t *testing.T, ss store.Store) { err = ss.Session().UpdateLastActivityAt(s1.Id, 1234567890) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.EqualValues(t, session.LastActivityAt, 1234567890, "LastActivityAt not updated correctly") } @@ -295,16 +296,16 @@ func testSessionCleanup(t *testing.T, ss store.Store) { ss.Session().Cleanup(now, 1) - _, err = ss.Session().Get(s1.Id) + _, err = ss.Session().Get(context.Background(), s1.Id) assert.NoError(t, err) - _, err = ss.Session().Get(s2.Id) + _, err = ss.Session().Get(context.Background(), s2.Id) assert.NoError(t, err) - _, err = ss.Session().Get(s3.Id) + _, err = ss.Session().Get(context.Background(), s3.Id) assert.Error(t, err) - _, err = ss.Session().Get(s4.Id) + _, err = ss.Session().Get(context.Background(), s4.Id) assert.Error(t, err) removeErr := ss.Session().Remove(s1.Id) @@ -377,19 +378,19 @@ func testUpdateExpiredNotify(t *testing.T, ss store.Store) { s1, err := ss.Session().Save(s1) require.NoError(t, err) - session, err := ss.Session().Get(s1.Id) + session, err := ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.False(t, session.ExpiredNotify) err = ss.Session().UpdateExpiredNotify(session.Id, true) require.NoError(t, err) - session, err = ss.Session().Get(s1.Id) + session, err = ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.True(t, session.ExpiredNotify) err = ss.Session().UpdateExpiredNotify(session.Id, false) require.NoError(t, err) - session, err = ss.Session().Get(s1.Id) + session, err = ss.Session().Get(context.Background(), s1.Id) require.NoError(t, err) require.False(t, session.ExpiredNotify) } diff --git a/store/storetest/settings.go b/store/storetest/settings.go index c0201c5f713..b800f1acdfb 100644 --- a/store/storetest/settings.go +++ b/store/storetest/settings.go @@ -20,9 +20,10 @@ import ( ) const ( - defaultMysqlDSN = "mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s&multiStatements=true" - defaultPostgresqlDSN = "postgres://mmuser:mostest@localhost:5432/mattermost_test?sslmode=disable&connect_timeout=10" - defaultMysqlRootPWD = "mostest" + defaultMysqlDSN = "mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s&multiStatements=true" + defaultPostgresqlDSN = "postgres://mmuser:mostest@localhost:5432/mattermost_test?sslmode=disable&connect_timeout=10" + defaultMysqlRootPWD = "mostest" + defaultMysqlReplicaDSN = "root:mostest@tcp(localhost:3307)/mattermost_test?charset=utf8mb4,utf8\u0026readTimeout=30s" ) func getEnv(name, defaultValue string) string { @@ -48,7 +49,7 @@ func log(message string) { // MySQLSettings returns the database settings to connect to the MySQL unittesting database. // The database name is generated randomly and must be created before use. -func MySQLSettings() *model.SqlSettings { +func MySQLSettings(withReplica bool) *model.SqlSettings { dsn := getEnv("TEST_DATABASE_MYSQL_DSN", defaultMysqlDSN) cfg, err := mysql.ParseDSN(dsn) if err != nil { @@ -57,7 +58,13 @@ func MySQLSettings() *model.SqlSettings { cfg.DBName = "db" + model.NewId() - return databaseSettings("mysql", cfg.FormatDSN()) + mySQLSettings := databaseSettings("mysql", cfg.FormatDSN()) + + if withReplica { + mySQLSettings.DataSourceReplicas = []string{getEnv("TEST_DATABASE_MYSQL_REPLICA_DSN", defaultMysqlReplicaDSN)} + } + + return mySQLSettings } // PostgresSQLSettings returns the database settings to connect to the PostgreSQL unittesting database. @@ -174,15 +181,29 @@ func execAsRoot(settings *model.SqlSettings, sqlCommand string) error { return nil } +func replaceMySQLDatabaseName(dsn, newDBName string) string { + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + panic("failed to parse dsn " + dsn + ": " + err.Error()) + } + cfg.DBName = newDBName + return cfg.FormatDSN() +} + // MakeSqlSettings creates a randomly named database and returns the corresponding sql settings -func MakeSqlSettings(driver string) *model.SqlSettings { +func MakeSqlSettings(driver string, withReplica bool) *model.SqlSettings { var settings *model.SqlSettings var dbName string switch driver { case model.DATABASE_DRIVER_MYSQL: - settings = MySQLSettings() + settings = MySQLSettings(withReplica) dbName = mySQLDSNDatabase(*settings.DataSource) + newDSRs := []string{} + for _, dataSource := range settings.DataSourceReplicas { + newDSRs = append(newDSRs, replaceMySQLDatabaseName(dataSource, dbName)) + } + settings.DataSourceReplicas = newDSRs case model.DATABASE_DRIVER_POSTGRES: settings = PostgreSQLSettings() dbName = postgreSQLDSNDatabase(*settings.DataSource) diff --git a/store/storetest/team_store.go b/store/storetest/team_store.go index 986d9216519..21a65f51240 100644 --- a/store/storetest/team_store.go +++ b/store/storetest/team_store.go @@ -2855,17 +2855,17 @@ func testGetTeamMember(t *testing.T, ss store.Store) { require.NoError(t, nErr) var rm1 *model.TeamMember - rm1, err := ss.Team().GetMember(m1.TeamId, m1.UserId) + rm1, err := ss.Team().GetMember(context.Background(), m1.TeamId, m1.UserId) require.NoError(t, err) require.Equal(t, rm1.TeamId, m1.TeamId, "bad team id") require.Equal(t, rm1.UserId, m1.UserId, "bad user id") - _, err = ss.Team().GetMember(m1.TeamId, "") + _, err = ss.Team().GetMember(context.Background(), m1.TeamId, "") require.Error(t, err, "empty user id - should have failed") - _, err = ss.Team().GetMember("", m1.UserId) + _, err = ss.Team().GetMember(context.Background(), "", m1.UserId) require.Error(t, err, "empty team id - should have failed") // Test with a custom team scheme. @@ -2895,7 +2895,7 @@ func testGetTeamMember(t *testing.T, ss store.Store) { _, nErr = ss.Team().SaveMember(m2, -1) require.NoError(t, nErr) - m3, err := ss.Team().GetMember(m2.TeamId, m2.UserId) + m3, err := ss.Team().GetMember(context.Background(), m2.TeamId, m2.UserId) require.NoError(t, err) t.Log(m3) @@ -2905,7 +2905,7 @@ func testGetTeamMember(t *testing.T, ss store.Store) { _, nErr = ss.Team().SaveMember(m4, -1) require.NoError(t, nErr) - m5, err := ss.Team().GetMember(m4.TeamId, m4.UserId) + m5, err := ss.Team().GetMember(context.Background(), m4.TeamId, m4.UserId) require.NoError(t, err) assert.Equal(t, s2.DefaultTeamGuestRole, m5.Roles) @@ -3216,19 +3216,19 @@ func testTeamStoreMigrateTeamMembers(t *testing.T, ss store.Store) { } } - tm1b, err := ss.Team().GetMember(tm1.TeamId, tm1.UserId) + tm1b, err := ss.Team().GetMember(context.Background(), tm1.TeamId, tm1.UserId) assert.NoError(t, err) assert.Equal(t, "", tm1b.ExplicitRoles) assert.True(t, tm1b.SchemeUser) assert.True(t, tm1b.SchemeAdmin) - tm2b, err := ss.Team().GetMember(tm2.TeamId, tm2.UserId) + tm2b, err := ss.Team().GetMember(context.Background(), tm2.TeamId, tm2.UserId) assert.NoError(t, err) assert.Equal(t, "", tm2b.ExplicitRoles) assert.True(t, tm2b.SchemeUser) assert.False(t, tm2b.SchemeAdmin) - tm3b, err := ss.Team().GetMember(tm3.TeamId, tm3.UserId) + tm3b, err := ss.Team().GetMember(context.Background(), tm3.TeamId, tm3.UserId) assert.NoError(t, err) assert.Equal(t, "something_else", tm3b.ExplicitRoles) assert.False(t, tm3b.SchemeUser) @@ -3309,19 +3309,19 @@ func testTeamStoreClearAllCustomRoleAssignments(t *testing.T, ss store.Store) { require.NoError(t, (ss.Team().ClearAllCustomRoleAssignments())) - r1, err := ss.Team().GetMember(m1.TeamId, m1.UserId) + r1, err := ss.Team().GetMember(context.Background(), m1.TeamId, m1.UserId) require.NoError(t, err) assert.Equal(t, m1.ExplicitRoles, r1.Roles) - r2, err := ss.Team().GetMember(m2.TeamId, m2.UserId) + r2, err := ss.Team().GetMember(context.Background(), m2.TeamId, m2.UserId) require.NoError(t, err) assert.Equal(t, "team_user team_admin", r2.Roles) - r3, err := ss.Team().GetMember(m3.TeamId, m3.UserId) + r3, err := ss.Team().GetMember(context.Background(), m3.TeamId, m3.UserId) require.NoError(t, err) assert.Equal(t, m3.ExplicitRoles, r3.Roles) - r4, err := ss.Team().GetMember(m4.TeamId, m4.UserId) + r4, err := ss.Team().GetMember(context.Background(), m4.TeamId, m4.UserId) require.NoError(t, err) assert.Equal(t, "", r4.Roles) } diff --git a/store/storetest/thread_store.go b/store/storetest/thread_store.go index 948d05ecd0a..91ec4154de2 100644 --- a/store/storetest/thread_store.go +++ b/store/storetest/thread_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "testing" "time" @@ -69,7 +70,7 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) { newPosts, errIdx, err3 := ss.Post().SaveMultiple([]*model.Post{&o2, &o3, &o4}) - olist, _ := ss.Post().Get(otmp.Id, true, false, false) + olist, _ := ss.Post().Get(context.Background(), otmp.Id, true, false, false) o1 := olist.Posts[olist.Order[0]] newPosts = append([]*model.Post{o1}, newPosts...) diff --git a/store/storetest/user_access_token_store.go b/store/storetest/user_access_token_store.go index ef02c3af7cf..70bc6a30f7c 100644 --- a/store/storetest/user_access_token_store.go +++ b/store/storetest/user_access_token_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -57,7 +58,7 @@ func testUserAccessTokenSaveGetDelete(t *testing.T, ss store.Store) { nErr = ss.UserAccessToken().Delete(uat.Id) require.NoError(t, nErr) - _, err = ss.Session().Get(s1.Token) + _, err = ss.Session().Get(context.Background(), s1.Token) require.Error(t, err, "should error - session should be deleted") _, nErr = ss.UserAccessToken().GetByToken(s1.Token) @@ -76,7 +77,7 @@ func testUserAccessTokenSaveGetDelete(t *testing.T, ss store.Store) { nErr = ss.UserAccessToken().DeleteAllForUser(uat.UserId) require.NoError(t, nErr) - _, err = ss.Session().Get(s2.Token) + _, err = ss.Session().Get(context.Background(), s2.Token) require.Error(t, err, "should error - session should be deleted") _, nErr = ss.UserAccessToken().GetByToken(s2.Token) @@ -103,7 +104,7 @@ func testUserAccessTokenDisableEnable(t *testing.T, ss store.Store) { nErr = ss.UserAccessToken().UpdateTokenDisable(uat.Id) require.NoError(t, nErr) - _, err = ss.Session().Get(s1.Token) + _, err = ss.Session().Get(context.Background(), s1.Token) require.Error(t, err, "should error - session should be deleted") s2 := &model.Session{} diff --git a/store/storetest/user_store.go b/store/storetest/user_store.go index a857ab5db72..c88e9449683 100644 --- a/store/storetest/user_store.go +++ b/store/storetest/user_store.go @@ -4841,7 +4841,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.Equal(t, "system_user", updatedUser.Roles) require.True(t, user.UpdateAt < updatedUser.UpdateAt) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.False(t, updatedTeamMember.SchemeGuest) require.True(t, updatedTeamMember.SchemeUser) @@ -4886,7 +4886,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_user system_admin", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.False(t, updatedTeamMember.SchemeGuest) require.True(t, updatedTeamMember.SchemeUser) @@ -4942,7 +4942,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_user", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.False(t, updatedTeamMember.SchemeGuest) require.True(t, updatedTeamMember.SchemeUser) @@ -4982,7 +4982,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_user", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.False(t, updatedTeamMember.SchemeGuest) require.True(t, updatedTeamMember.SchemeUser) @@ -5027,7 +5027,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_user custom_role", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.False(t, updatedTeamMember.SchemeGuest) require.True(t, updatedTeamMember.SchemeUser) @@ -5093,7 +5093,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_user", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId1, user1.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId1, user1.Id) require.NoError(t, nErr) require.False(t, updatedTeamMember.SchemeGuest) require.True(t, updatedTeamMember.SchemeUser) @@ -5107,7 +5107,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_guest", notUpdatedUser.Roles) - notUpdatedTeamMember, nErr := ss.Team().GetMember(teamId2, user2.Id) + notUpdatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId2, user2.Id) require.NoError(t, nErr) require.True(t, notUpdatedTeamMember.SchemeGuest) require.False(t, notUpdatedTeamMember.SchemeUser) @@ -5154,7 +5154,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.Equal(t, "system_guest", updatedUser.Roles) require.True(t, user.UpdateAt < updatedUser.UpdateAt) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, updatedUser.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, updatedUser.Id) require.NoError(t, nErr) require.True(t, updatedTeamMember.SchemeGuest) require.False(t, updatedTeamMember.SchemeUser) @@ -5197,7 +5197,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_guest", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.True(t, updatedTeamMember.SchemeGuest) require.False(t, updatedTeamMember.SchemeUser) @@ -5249,7 +5249,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_guest", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.True(t, updatedTeamMember.SchemeGuest) require.False(t, updatedTeamMember.SchemeUser) @@ -5287,7 +5287,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_guest", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.True(t, updatedTeamMember.SchemeGuest) require.False(t, updatedTeamMember.SchemeUser) @@ -5330,7 +5330,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_guest custom_role", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId, user.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId, user.Id) require.NoError(t, nErr) require.True(t, updatedTeamMember.SchemeGuest) require.False(t, updatedTeamMember.SchemeUser) @@ -5394,7 +5394,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_guest", updatedUser.Roles) - updatedTeamMember, nErr := ss.Team().GetMember(teamId1, user1.Id) + updatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId1, user1.Id) require.NoError(t, nErr) require.True(t, updatedTeamMember.SchemeGuest) require.False(t, updatedTeamMember.SchemeUser) @@ -5408,7 +5408,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) { require.NoError(t, err) require.Equal(t, "system_user", notUpdatedUser.Roles) - notUpdatedTeamMember, nErr := ss.Team().GetMember(teamId2, user2.Id) + notUpdatedTeamMember, nErr := ss.Team().GetMember(context.Background(), teamId2, user2.Id) require.NoError(t, nErr) require.False(t, notUpdatedTeamMember.SchemeGuest) require.True(t, notUpdatedTeamMember.SchemeUser) diff --git a/store/timerlayer/timerlayer.go b/store/timerlayer/timerlayer.go index a9140fd3327..aecfc7d80fc 100644 --- a/store/timerlayer/timerlayer.go +++ b/store/timerlayer/timerlayer.go @@ -4476,10 +4476,10 @@ func (s *TimerLayerPostStore) Delete(postID string, time int64, deleteByID strin return err } -func (s *TimerLayerPostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { +func (s *TimerLayerPostStore) Get(ctx context.Context, id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { start := timemodule.Now() - result, err := s.PostStore.Get(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) + result, err := s.PostStore.Get(ctx, id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -5626,10 +5626,10 @@ func (s *TimerLayerSessionStore) Cleanup(expiryTime int64, batchSize int64) { } } -func (s *TimerLayerSessionStore) Get(sessionIDOrToken string) (*model.Session, error) { +func (s *TimerLayerSessionStore) Get(ctx context.Context, sessionIDOrToken string) (*model.Session, error) { start := timemodule.Now() - result, err := s.SessionStore.Get(sessionIDOrToken) + result, err := s.SessionStore.Get(ctx, sessionIDOrToken) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -6409,10 +6409,10 @@ func (s *TimerLayerTeamStore) GetChannelUnreadsForTeam(teamID string, userId str return result, err } -func (s *TimerLayerTeamStore) GetMember(teamID string, userId string) (*model.TeamMember, error) { +func (s *TimerLayerTeamStore) GetMember(ctx context.Context, teamID string, userId string) (*model.TeamMember, error) { start := timemodule.Now() - result, err := s.TeamStore.GetMember(teamID, userId) + result, err := s.TeamStore.GetMember(ctx, teamID, userId) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { diff --git a/testlib/helper.go b/testlib/helper.go index 0a0ab03fade..3da6d203a14 100644 --- a/testlib/helper.go +++ b/testlib/helper.go @@ -33,11 +33,13 @@ type MainHelper struct { status int testResourcePath string + replicas []string } type HelperOptions struct { EnableStore bool EnableResources bool + WithReadReplica bool } func NewMainHelper() *MainHelper { @@ -65,7 +67,7 @@ func NewMainHelperWithOptions(options *HelperOptions) *MainHelper { if options != nil { if options.EnableStore && !testing.Short() { - mainHelper.setupStore() + mainHelper.setupStore(options.WithReadReplica) } if options.EnableResources { @@ -99,13 +101,14 @@ func (h *MainHelper) Main(m *testing.M) { h.status = m.Run() } -func (h *MainHelper) setupStore() { +func (h *MainHelper) setupStore(withReadReplica bool) { driverName := os.Getenv("MM_SQLSETTINGS_DRIVERNAME") if driverName == "" { driverName = model.DATABASE_DRIVER_POSTGRES } - h.Settings = storetest.MakeSqlSettings(driverName) + h.Settings = storetest.MakeSqlSettings(driverName, withReadReplica) + h.replicas = h.Settings.DataSourceReplicas config := &model.Config{} config.SetDefaults() @@ -118,6 +121,26 @@ func (h *MainHelper) setupStore() { }, h.SearchEngine, config) } +func (h *MainHelper) ToggleReplicasOff() { + if h.SQLStore.GetLicense() == nil { + panic("expecting a license to use this") + } + h.Settings.DataSourceReplicas = []string{} + lic := h.SQLStore.GetLicense() + h.SQLStore = sqlstore.New(*h.Settings, nil) + h.SQLStore.UpdateLicense(lic) +} + +func (h *MainHelper) ToggleReplicasOn() { + if h.SQLStore.GetLicense() == nil { + panic("expecting a license to use this") + } + h.Settings.DataSourceReplicas = h.replicas + lic := h.SQLStore.GetLicense() + h.SQLStore = sqlstore.New(*h.Settings, nil) + h.SQLStore.UpdateLicense(lic) +} + func (h *MainHelper) setupResources() { var err error h.testResourcePath, err = SetupTestResources() @@ -234,3 +257,36 @@ func (h *MainHelper) GetSearchEngine() *searchengine.Broker { return h.SearchEngine } + +func (h *MainHelper) SetReplicationLagForTesting(seconds int) error { + if dn := h.SQLStore.DriverName(); dn != model.DATABASE_DRIVER_MYSQL { + return fmt.Errorf("method not implemented for %q database driver, only %q is supported", dn, model.DATABASE_DRIVER_MYSQL) + } + + err := h.execOnEachReplica("STOP SLAVE SQL_THREAD FOR CHANNEL ''") + if err != nil { + return err + } + + err = h.execOnEachReplica(fmt.Sprintf("CHANGE MASTER TO MASTER_DELAY = %d", seconds)) + if err != nil { + return err + } + + err = h.execOnEachReplica("START SLAVE SQL_THREAD FOR CHANNEL ''") + if err != nil { + return err + } + + return nil +} + +func (h *MainHelper) execOnEachReplica(query string, args ...interface{}) error { + for _, replica := range h.SQLStore.Replicas { + _, err := replica.Exec(query, args...) + if err != nil { + return err + } + } + return nil +} diff --git a/tests/test-data.ldif b/tests/test-data.ldif index 44d2bb6f5ed..3f6cffdce97 100644 --- a/tests/test-data.ldif +++ b/tests/test-data.ldif @@ -129,6 +129,22 @@ cn: Board3 mail: success+boardthree@simulator.amazonses.com userPassword: Password1 +dn: uid=firstloginuser.one,ou=testusers,dc=mm,dc=test,dc=com +changetype: add +objectclass: iNetOrgPerson +sn: User +cn: FirstLogin1 +mail: success+firstloginuser.one@simulator.amazonses.com +userPassword: Password1 + +dn: uid=firstloginuser.two,ou=testusers,dc=mm,dc=test,dc=com +changetype: add +objectclass: iNetOrgPerson +sn: User +cn: FirstLogin2 +mail: success+firstloginuser.two@simulator.amazonses.com +userPassword: Password1 + dn: ou=testgroups,dc=mm,dc=test,dc=com changetype: add objectclass: organizationalunit @@ -209,4 +225,10 @@ changetype: add objectclass: groupOfUniqueNames uniqueMember: uid=dev-ops.one,ou=testusers,dc=mm,dc=test,dc=com uniqueMember: cn=team-one,ou=testgroups,dc=mm,dc=test,dc=com -uniqueMember: cn=team-two,ou=testgroups,dc=mm,dc=test,dc=com \ No newline at end of file +uniqueMember: cn=team-two,ou=testgroups,dc=mm,dc=test,dc=com + +dn: cn=firstlogingroup,ou=testgroups,dc=mm,dc=test,dc=com +changetype: add +objectclass: groupOfUniqueNames +uniqueMember: uid=firstloginuser.one,ou=testusers,dc=mm,dc=test,dc=com +uniqueMember: uid=firstloginuser.two,ou=testusers,dc=mm,dc=test,dc=com \ No newline at end of file