diff --git a/server/channels/api4/saml_test.go b/server/channels/api4/saml_test.go index ac6bf972f5d..a3a69804c80 100644 --- a/server/channels/api4/saml_test.go +++ b/server/channels/api4/saml_test.go @@ -60,7 +60,7 @@ func TestSamlResetId(t *testing.T) { th.App.Channels().Saml = &mocks.SamlInterface{} user := th.BasicUser - _, appErr := th.App.UpdateUserAuth(user.Id, &model.UserAuth{ + _, appErr := th.App.UpdateUserAuth(nil, user.Id, &model.UserAuth{ AuthData: model.NewString(model.NewId()), AuthService: model.UserAuthServiceSaml, }) diff --git a/server/channels/api4/user.go b/server/channels/api4/user.go index f0429d8de85..ecb33f47a00 100644 --- a/server/channels/api4/user.go +++ b/server/channels/api4/user.go @@ -1609,7 +1609,7 @@ func updateUserAuth(c *Context, w http.ResponseWriter, r *http.Request) { auditRec.AddEventPriorState(user) } - user, err := c.App.UpdateUserAuth(c.Params.UserId, &userAuth) + user, err := c.App.UpdateUserAuth(c.AppContext, c.Params.UserId, &userAuth) if err != nil { c.Err = err return diff --git a/server/channels/app/app_iface.go b/server/channels/app/app_iface.go index b39c4d60d7b..9c996bc3d16 100644 --- a/server/channels/app/app_iface.go +++ b/server/channels/app/app_iface.go @@ -1153,7 +1153,7 @@ type AppIface interface { UpdateUser(c request.CTX, user *model.User, sendNotifications bool) (*model.User, *model.AppError) UpdateUserActive(c request.CTX, userID string, active bool) *model.AppError UpdateUserAsUser(c request.CTX, user *model.User, asAdmin bool) (*model.User, *model.AppError) - UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) + UpdateUserAuth(c request.CTX, userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) UpdateUserRoles(c request.CTX, userID string, newRoles string, sendWebSocketEvent bool) (*model.User, *model.AppError) UpdateUserRolesWithUser(c request.CTX, user *model.User, newRoles string, sendWebSocketEvent bool) (*model.User, *model.AppError) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (*model.FileInfo, *model.AppError) diff --git a/server/channels/app/opentracing/opentracing_layer.go b/server/channels/app/opentracing/opentracing_layer.go index 4e7d7d0a21a..36ddc10a308 100644 --- a/server/channels/app/opentracing/opentracing_layer.go +++ b/server/channels/app/opentracing/opentracing_layer.go @@ -18157,7 +18157,7 @@ func (a *OpenTracingAppLayer) UpdateUserAsUser(c request.CTX, user *model.User, return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { +func (a *OpenTracingAppLayer) UpdateUserAuth(c request.CTX, userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateUserAuth") @@ -18169,7 +18169,7 @@ func (a *OpenTracingAppLayer) UpdateUserAuth(userID string, userAuth *model.User }() defer span.Finish() - resultVar0, resultVar1 := a.app.UpdateUserAuth(userID, userAuth) + resultVar0, resultVar1 := a.app.UpdateUserAuth(c, userID, userAuth) if resultVar1 != nil { span.LogFields(spanlog.Error(resultVar1)) diff --git a/server/channels/app/plugin_api.go b/server/channels/app/plugin_api.go index b5c6a467551..950e96aca65 100644 --- a/server/channels/app/plugin_api.go +++ b/server/channels/app/plugin_api.go @@ -324,6 +324,10 @@ func (api *PluginAPI) UpdateUser(user *model.User) (*model.User, *model.AppError return api.app.UpdateUser(api.ctx, user, true) } +func (api *PluginAPI) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { + return api.app.UpdateUserAuth(api.ctx, userID, userAuth) +} + func (api *PluginAPI) UpdateUserActive(userID string, active bool) *model.AppError { return api.app.UpdateUserActive(api.ctx, userID, active) } diff --git a/server/channels/app/plugin_api_tests/test_update_user_auth_plugin/main.go b/server/channels/app/plugin_api_tests/test_update_user_auth_plugin/main.go new file mode 100644 index 00000000000..44e5d6f8f31 --- /dev/null +++ b/server/channels/app/plugin_api_tests/test_update_user_auth_plugin/main.go @@ -0,0 +1,87 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package main + +import ( + "fmt" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/v8/channels/app/plugin_api_tests" +) + +type MyPlugin struct { + plugin.MattermostPlugin + configuration plugin_api_tests.BasicConfig +} + +func (p *MyPlugin) OnConfigurationChange() error { + if err := p.API.LoadPluginConfiguration(&p.configuration); err != nil { + return err + } + return nil +} + +func (p *MyPlugin) expectUserAuth(userID string, expectedUserAuth *model.UserAuth) error { + user, err := p.API.GetUser(p.configuration.BasicUserID) + if err != nil { + return err + } + if user.AuthService != expectedUserAuth.AuthService { + return fmt.Errorf("expected '%s' got '%s'", expectedUserAuth.AuthService, user.AuthService) + } + if user.AuthData == nil && expectedUserAuth.AuthData != nil { + return fmt.Errorf("expected '%s' got nil", *expectedUserAuth.AuthData) + } else if user.AuthData != nil && expectedUserAuth.AuthData == nil { + return fmt.Errorf("expected nil got '%s'", *user.AuthData) + } else if user.AuthData != nil && expectedUserAuth.AuthData != nil && *user.AuthData != *expectedUserAuth.AuthData { + return fmt.Errorf("expected '%s' got '%s'", *expectedUserAuth.AuthData, *user.AuthData) + } + + return nil +} + +func (p *MyPlugin) MessageWillBePosted(_ *plugin.Context, _ *model.Post) (*model.Post, string) { + // BasicUser2 should remain unchanged throughout + user, appErr := p.API.GetUser(p.configuration.BasicUser2Id) + if appErr != nil { + return nil, appErr.Error() + } + expectedUser2Auth := &model.UserAuth{ + AuthService: user.AuthService, + AuthData: user.AuthData, + } + + // Update BasicUser to SAML + expectedUserAuth := &model.UserAuth{ + AuthService: model.UserAuthServiceSaml, + AuthData: model.NewString("saml_auth_data"), + } + _, appErr = p.API.UpdateUserAuth(p.configuration.BasicUserID, expectedUserAuth) + if appErr != nil { + return nil, appErr.Error() + } + + p.expectUserAuth(p.configuration.BasicUserID, expectedUserAuth) + p.expectUserAuth(p.configuration.BasicUser2Id, expectedUser2Auth) + + // Update BasicUser to LDAP + expectedUserAuth = &model.UserAuth{ + AuthService: model.UserAuthServiceLdap, + AuthData: model.NewString("ldap_auth_data"), + } + _, err := p.API.UpdateUserAuth(p.configuration.BasicUserID, expectedUserAuth) + if err != nil { + return nil, err.Error() + } + + p.expectUserAuth(p.configuration.BasicUserID, expectedUserAuth) + p.expectUserAuth(p.configuration.BasicUser2Id, expectedUser2Auth) + + return nil, "OK" +} + +func main() { + plugin.ClientMain(&MyPlugin{}) +} diff --git a/server/channels/app/user.go b/server/channels/app/user.go index c6b1e81e432..a658bda8fa0 100644 --- a/server/channels/app/user.go +++ b/server/channels/app/user.go @@ -1109,7 +1109,7 @@ func (a *App) PatchUser(c request.CTX, userID string, patch *model.UserPatch, as return updatedUser, nil } -func (a *App) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { +func (a *App) UpdateUserAuth(c request.CTX, userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { if _, err := a.Srv().Store().User().UpdateAuthData(userID, userAuth.AuthService, userAuth.AuthData, "", false); err != nil { var invErr *store.ErrInvalidInput switch { @@ -1120,6 +1120,8 @@ func (a *App) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.Us } } + a.InvalidateCacheForUser(userID) + return userAuth, nil } diff --git a/server/cmd/mmctl/commands/user_e2e_test.go b/server/cmd/mmctl/commands/user_e2e_test.go index ec0fe7e8ae6..8fce57b9c64 100644 --- a/server/cmd/mmctl/commands/user_e2e_test.go +++ b/server/cmd/mmctl/commands/user_e2e_test.go @@ -1019,7 +1019,7 @@ func (s *MmctlE2ETestSuite) TestMigrateAuthCmd() { err := migrateAuthCmdF(c, cmd, []string{"ldap", "saml"}) s.Require().NoError(err) defer func() { - _, appErr := s.th.App.UpdateUserAuth(ldapUser.Id, &model.UserAuth{ + _, appErr := s.th.App.UpdateUserAuth(s.th.Context, ldapUser.Id, &model.UserAuth{ AuthData: model.NewString("test.user.1"), AuthService: model.UserAuthServiceLdap, }) @@ -1048,7 +1048,7 @@ func (s *MmctlE2ETestSuite) TestMigrateAuthCmd() { err := migrateAuthCmdF(c, cmd, []string{"saml", "ldap", "email"}) s.Require().NoError(err) defer func() { - _, appErr := s.th.App.UpdateUserAuth(samlUser.Id, &model.UserAuth{ + _, appErr := s.th.App.UpdateUserAuth(s.th.Context, samlUser.Id, &model.UserAuth{ AuthData: model.NewString("dev.one"), AuthService: model.UserAuthServiceSaml, }) diff --git a/server/public/plugin/api.go b/server/public/plugin/api.go index f7c96cb2bb6..a5e3533f071 100644 --- a/server/public/plugin/api.go +++ b/server/public/plugin/api.go @@ -1198,6 +1198,15 @@ type API interface { // // Minimum server version: 9.0 SendPushNotification(notification *model.PushNotification, userID string) *model.AppError + + // UpdateUserAuth updates a user's auth data. + // + // It is not currently possible to use this to set a user's auth to e-mail with a hashed + // password. It is meant to be used exclusively in setting a non-email auth service. + // + // @tag User + // Minimum server version: 9.3 + UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) } var handshake = plugin.HandshakeConfig{ diff --git a/server/public/plugin/api_timer_layer_generated.go b/server/public/plugin/api_timer_layer_generated.go index 4376fbbab4e..266ea5d20a7 100644 --- a/server/public/plugin/api_timer_layer_generated.go +++ b/server/public/plugin/api_timer_layer_generated.go @@ -1280,3 +1280,10 @@ func (api *apiTimerLayer) SendPushNotification(notification *model.PushNotificat api.recordTime(startTime, "SendPushNotification", _returnsA == nil) return _returnsA } + +func (api *apiTimerLayer) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { + startTime := timePkg.Now() + _returnsA, _returnsB := api.apiImpl.UpdateUserAuth(userID, userAuth) + api.recordTime(startTime, "UpdateUserAuth", _returnsB == nil) + return _returnsA, _returnsB +} diff --git a/server/public/plugin/client_rpc_generated.go b/server/public/plugin/client_rpc_generated.go index c91e2df258e..8f300fa74a5 100644 --- a/server/public/plugin/client_rpc_generated.go +++ b/server/public/plugin/client_rpc_generated.go @@ -5996,3 +5996,33 @@ func (s *apiRPCServer) SendPushNotification(args *Z_SendPushNotificationArgs, re } return nil } + +type Z_UpdateUserAuthArgs struct { + A string + B *model.UserAuth +} + +type Z_UpdateUserAuthReturns struct { + A *model.UserAuth + B *model.AppError +} + +func (g *apiRPCClient) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { + _args := &Z_UpdateUserAuthArgs{userID, userAuth} + _returns := &Z_UpdateUserAuthReturns{} + if err := g.client.Call("Plugin.UpdateUserAuth", _args, _returns); err != nil { + log.Printf("RPC call to UpdateUserAuth API failed: %s", err.Error()) + } + return _returns.A, _returns.B +} + +func (s *apiRPCServer) UpdateUserAuth(args *Z_UpdateUserAuthArgs, returns *Z_UpdateUserAuthReturns) error { + if hook, ok := s.impl.(interface { + UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) + }); ok { + returns.A, returns.B = hook.UpdateUserAuth(args.A, args.B) + } else { + return encodableError(fmt.Errorf("API UpdateUserAuth called but not implemented.")) + } + return nil +} diff --git a/server/public/plugin/plugintest/api.go b/server/public/plugin/plugintest/api.go index c820ee969b5..c09622a85c1 100644 --- a/server/public/plugin/plugintest/api.go +++ b/server/public/plugin/plugintest/api.go @@ -4099,6 +4099,34 @@ func (_m *API) UpdateUserActive(userID string, active bool) *model.AppError { return r0 } +// UpdateUserAuth provides a mock function with given fields: userID, userAuth +func (_m *API) UpdateUserAuth(userID string, userAuth *model.UserAuth) (*model.UserAuth, *model.AppError) { + ret := _m.Called(userID, userAuth) + + var r0 *model.UserAuth + var r1 *model.AppError + if rf, ok := ret.Get(0).(func(string, *model.UserAuth) (*model.UserAuth, *model.AppError)); ok { + return rf(userID, userAuth) + } + if rf, ok := ret.Get(0).(func(string, *model.UserAuth) *model.UserAuth); ok { + r0 = rf(userID, userAuth) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.UserAuth) + } + } + + if rf, ok := ret.Get(1).(func(string, *model.UserAuth) *model.AppError); ok { + r1 = rf(userID, userAuth) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + // UpdateUserCustomStatus provides a mock function with given fields: userID, customStatus func (_m *API) UpdateUserCustomStatus(userID string, customStatus *model.CustomStatus) *model.AppError { ret := _m.Called(userID, customStatus)