diff --git a/server/channels/api4/system.go b/server/channels/api4/system.go index 9b85abcfb83..c00a0c0a480 100644 --- a/server/channels/api4/system.go +++ b/server/channels/api4/system.go @@ -319,7 +319,11 @@ func invalidateCaches(c *Context, w http.ResponseWriter, r *http.Request) { return } - c.App.Srv().InvalidateAllCaches() + appErr := c.App.Srv().InvalidateAllCaches() + if appErr != nil { + c.Err = appErr + return + } auditRec.Success() diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index 6bc2c214b01..e2e65cf7acc 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -3909,7 +3909,8 @@ func TestLoginWithLag(t *testing.T) { _, _, err := th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) require.NoError(t, err) - th.App.Srv().InvalidateAllCaches() + appErr = th.App.Srv().InvalidateAllCaches() + require.Nil(t, appErr) session, appErr := th.App.GetSession(th.Client.AuthToken) require.Nil(t, appErr) diff --git a/server/channels/app/admin.go b/server/channels/app/admin.go index b5ab1a9010f..95afdff8a6a 100644 --- a/server/channels/app/admin.go +++ b/server/channels/app/admin.go @@ -137,8 +137,8 @@ func (a *App) GetClusterStatus() []*model.ClusterInfo { return infos } -func (s *Server) InvalidateAllCaches() { - s.platform.InvalidateAllCaches() +func (s *Server) InvalidateAllCaches() *model.AppError { + return s.platform.InvalidateAllCaches() } func (s *Server) InvalidateAllCachesSkipSend() { diff --git a/server/channels/app/license.go b/server/channels/app/license.go index 33efbdd5e3c..c2f879e4758 100644 --- a/server/channels/app/license.go +++ b/server/channels/app/license.go @@ -131,6 +131,10 @@ func (s *Server) License() *model.License { return s.platform.License() } +func (s *Server) LoadLicense() { + s.platform.LoadLicense() +} + func (s *Server) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) { return s.platform.SaveLicense(licenseBytes) } diff --git a/server/channels/app/license_test.go b/server/channels/app/license_test.go new file mode 100644 index 00000000000..26b4b58b70c --- /dev/null +++ b/server/channels/app/license_test.go @@ -0,0 +1,109 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" +) + +func TestLoadLicense(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + th.App.Srv().LoadLicense() + require.Nil(t, th.App.Srv().License(), "shouldn't have a valid license") +} + +func TestSaveLicense(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + b1 := []byte("junk") + + _, err := th.App.Srv().SaveLicense(b1) + require.NotNil(t, err, "shouldn't have saved license") +} + +func TestRemoveLicense(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + err := th.App.Srv().RemoveLicense() + require.Nil(t, err, "should have removed license") +} + +func TestSetLicense(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + l1 := &model.License{} + l1.Features = &model.Features{} + l1.Customer = &model.Customer{} + l1.StartsAt = model.GetMillis() - 1000 + l1.ExpiresAt = model.GetMillis() + 100000 + ok := th.App.Srv().SetLicense(l1) + require.True(t, ok, "license should have worked") + + l3 := &model.License{} + l3.Features = &model.Features{} + l3.Customer = &model.Customer{} + l3.StartsAt = model.GetMillis() + 10000 + l3.ExpiresAt = model.GetMillis() + 100000 + ok = th.App.Srv().SetLicense(l3) + require.True(t, ok, "license should have passed") +} + +func TestGetSanitizedClientLicense(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + setLicense(th, nil) + + m := th.App.Srv().GetSanitizedClientLicense() + + _, ok := m["Name"] + assert.False(t, ok) + _, ok = m["SkuName"] + assert.False(t, ok) +} + +func TestGenerateRenewalToken(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + t.Run("renewal token generated correctly", func(t *testing.T) { + setLicense(th, nil) + token, appErr := th.App.Srv().GenerateRenewalToken(JWTDefaultTokenExpiration) + require.Nil(t, appErr) + require.NotEmpty(t, token) + }) + + t.Run("return error if there is no active license", func(t *testing.T) { + th.App.Srv().SetLicense(nil) + _, appErr := th.App.Srv().GenerateRenewalToken(JWTDefaultTokenExpiration) + require.NotNil(t, appErr) + }) +} + +func setLicense(th *TestHelper, customer *model.Customer) { + l1 := &model.License{} + l1.Features = &model.Features{} + if customer != nil { + l1.Customer = customer + } else { + l1.Customer = &model.Customer{} + l1.Customer.Name = "TestName" + l1.Customer.Email = "test@example.com" + } + l1.SkuName = "SKU NAME" + l1.SkuShortName = "SKU SHORT NAME" + l1.StartsAt = model.GetMillis() - 1000 + l1.ExpiresAt = model.GetMillis() + 100000 + th.App.Srv().SetLicense(l1) +} diff --git a/server/channels/app/notification_test.go b/server/channels/app/notification_test.go index 9dbf635996a..6b4f638f6dd 100644 --- a/server/channels/app/notification_test.go +++ b/server/channels/app/notification_test.go @@ -99,7 +99,8 @@ func TestSendNotifications(t *testing.T) { _, appErr = th.App.UpdateActive(th.Context, th.BasicUser2, false) require.Nil(t, appErr) - th.App.Srv().InvalidateAllCaches() + appErr = th.App.Srv().InvalidateAllCaches() + require.Nil(t, appErr) post3, appErr := th.App.CreatePostMissingChannel(th.Context, &model.Post{ UserId: th.BasicUser.Id, diff --git a/server/channels/app/oauth.go b/server/channels/app/oauth.go index 4c287cd9b2b..1be59084079 100644 --- a/server/channels/app/oauth.go +++ b/server/channels/app/oauth.go @@ -113,7 +113,9 @@ func (a *App) DeleteOAuthApp(appID string) *model.AppError { return model.NewAppError("DeleteOAuthApp", "app.oauth.delete_app.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } - a.Srv().InvalidateAllCaches() + if err := a.Srv().InvalidateAllCaches(); err != nil { + mlog.Warn("error in invalidating cache", mlog.Err(err)) + } return nil } diff --git a/server/channels/app/platform/cluster_handlers.go b/server/channels/app/platform/cluster_handlers.go index 2d50fd15165..a45809085b6 100644 --- a/server/channels/app/platform/cluster_handlers.go +++ b/server/channels/app/platform/cluster_handlers.go @@ -17,7 +17,6 @@ func (ps *PlatformService) RegisterClusterHandlers() { ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventPublish, ps.ClusterPublishHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventUpdateStatus, ps.ClusterUpdateStatusHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateAllCaches, ps.ClusterInvalidateAllCachesHandler) - ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventLoadLicense, ps.LoadLicenseClusterHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelMembersNotifyProps, ps.clusterInvalidateCacheForChannelMembersNotifyPropHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelByName, ps.clusterInvalidateCacheForChannelByNameHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForUser, ps.clusterInvalidateCacheForUserHandler) @@ -155,27 +154,10 @@ func (ps *PlatformService) InvalidateAllCachesSkipSend() { ps.Store.Webhook().ClearCaches() linkCache.Purge() + ps.LoadLicense() } -func (ps *PlatformService) LoadLicenseClusterHandler(_ *model.ClusterMessage) { - ps.loadLicense() -} - -func (ps *PlatformService) TriggerLoadLicense() { - ps.loadLicense() - - if ps.clusterIFace != nil { - msg := &model.ClusterMessage{ - Event: model.ClusterEventLoadLicense, - SendType: model.ClusterSendReliable, - WaitForAllToSend: true, - } - - ps.clusterIFace.SendClusterMessage(msg) - } -} - -func (ps *PlatformService) InvalidateAllCaches() { +func (ps *PlatformService) InvalidateAllCaches() *model.AppError { ps.InvalidateAllCachesSkipSend() if ps.clusterIFace != nil { @@ -188,4 +170,6 @@ func (ps *PlatformService) InvalidateAllCaches() { ps.clusterIFace.SendClusterMessage(msg) } + + return nil } diff --git a/server/channels/app/platform/license.go b/server/channels/app/platform/license.go index dbe7afb27ff..d3a1b0ecb5c 100644 --- a/server/channels/app/platform/license.go +++ b/server/channels/app/platform/license.go @@ -46,7 +46,7 @@ func (ps *PlatformService) License() *model.License { return ps.licenseValue.Load() } -func (ps *PlatformService) loadLicense() { +func (ps *PlatformService) LoadLicense() { // ENV var overrides all other sources of license. licenseStr := os.Getenv(LicenseEnv) if licenseStr != "" { @@ -326,6 +326,9 @@ func (ps *PlatformService) RequestTrialLicense(trialRequest *model.TrialLicenseR return err } + ps.ReloadConfig() + ps.InvalidateAllCaches() + return nil } diff --git a/server/channels/app/platform/license_test.go b/server/channels/app/platform/license_test.go index cbc15b507a2..3f1af98ad1d 100644 --- a/server/channels/app/platform/license_test.go +++ b/server/channels/app/platform/license_test.go @@ -12,6 +12,14 @@ import ( "github.com/mattermost/mattermost/server/public/model" ) +func TestLoadLicense(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + th.Service.LoadLicense() + require.Nil(t, th.Service.License(), "shouldn't have a valid license") +} + func TestSaveLicense(t *testing.T) { th := Setup(t) defer th.TearDown() diff --git a/server/channels/app/platform/service.go b/server/channels/app/platform/service.go index 9074e5258cf..8869f5ea408 100644 --- a/server/channels/app/platform/service.go +++ b/server/channels/app/platform/service.go @@ -292,7 +292,7 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) { // Step 7: Init License if model.BuildEnterpriseReady == "true" { - ps.TriggerLoadLicense() + ps.LoadLicense() } // Step 8: Init Metrics Server depends on step 6 (store) and 7 (license) @@ -353,7 +353,9 @@ func (ps *PlatformService) Start() error { message := model.NewWebSocketEvent(model.WebsocketEventConfigChanged, "", "", "", nil, "") message.Add("config", ps.ClientConfigWithComputed()) - ps.Publish(message) + ps.Go(func() { + ps.Publish(message) + }) if err := ps.ReconfigureLogger(); err != nil { mlog.Error("Error re-configuring logging after config change", mlog.Err(err)) @@ -366,7 +368,9 @@ func (ps *PlatformService) Start() error { message := model.NewWebSocketEvent(model.WebsocketEventLicenseChanged, "", "", "", nil, "") message.Add("license", ps.GetSanitizedClientLicense()) - ps.Publish(message) + ps.Go(func() { + ps.Publish(message) + }) }) return nil diff --git a/server/channels/app/server.go b/server/channels/app/server.go index d59447997f6..a93b79bbcc0 100644 --- a/server/channels/app/server.go +++ b/server/channels/app/server.go @@ -209,6 +209,7 @@ func NewServer(options ...Option) (*Server, error) { // Depends on step 1 (s.Platform must be non-nil) s.initEnterprise() + // Needed to run before loading license. s.userService, err = users.New(users.ServiceConfig{ UserStore: s.Store().User(), SessionStore: s.Store().Session(), @@ -222,6 +223,11 @@ func NewServer(options ...Option) (*Server, error) { return nil, errors.Wrapf(err, "unable to create users service") } + if model.BuildEnterpriseReady == "true" { + // Dependent on user service + s.LoadLicense() + } + s.licenseWrapper = &licenseWrapper{ srv: s, } @@ -1378,6 +1384,8 @@ func (s *Server) sendLicenseUpForRenewalEmail(users map[string]*model.User, lice } func (s *Server) doLicenseExpirationCheck() { + s.LoadLicense() + // This takes care of a rare edge case reported here https://mattermost.atlassian.net/browse/MM-40962 // To reproduce that case locally, attach a license to a server that was started with enterprise enabled // Then restart using BUILD_ENTERPRISE=false make restart-server to enter Team Edition @@ -1387,6 +1395,7 @@ func (s *Server) doLicenseExpirationCheck() { } license := s.License() + if license == nil { mlog.Debug("License cannot be found.") return diff --git a/server/cmd/mattermost/commands/init.go b/server/cmd/mattermost/commands/init.go index 5101aa99263..77f40c63868 100644 --- a/server/cmd/mattermost/commands/init.go +++ b/server/cmd/mattermost/commands/init.go @@ -49,6 +49,10 @@ func initDBCommandContext(configDSN string, readOnlyConfigStore bool, options .. a := app.New(app.ServerConnector(s.Channels())) + if model.BuildEnterpriseReady == "true" { + a.Srv().LoadLicense() + } + return a, nil } diff --git a/server/cmd/mattermost/commands/jobserver.go b/server/cmd/mattermost/commands/jobserver.go index 32895a1492c..18a544b4651 100644 --- a/server/cmd/mattermost/commands/jobserver.go +++ b/server/cmd/mattermost/commands/jobserver.go @@ -41,6 +41,8 @@ func jobserverCmdF(command *cobra.Command, args []string) error { } defer a.Srv().Shutdown() + a.Srv().LoadLicense() + // Run jobs mlog.Info("Starting Mattermost job server") defer mlog.Info("Stopped Mattermost job server") diff --git a/server/public/model/cluster_message.go b/server/public/model/cluster_message.go index 19f5e4079ed..6ff912f9ac8 100644 --- a/server/public/model/cluster_message.go +++ b/server/public/model/cluster_message.go @@ -9,7 +9,6 @@ const ( ClusterEventPublish ClusterEvent = "publish" ClusterEventUpdateStatus ClusterEvent = "update_status" ClusterEventInvalidateAllCaches ClusterEvent = "inv_all_caches" - ClusterEventLoadLicense ClusterEvent = "load_license" ClusterEventInvalidateCacheForReactions ClusterEvent = "inv_reactions" ClusterEventInvalidateCacheForChannelMembersNotifyProps ClusterEvent = "inv_channel_members_notify_props" ClusterEventInvalidateCacheForChannelByName ClusterEvent = "inv_channel_name"