From 3ac6edb4064d83d683b93b8a334e3967948b1bb0 Mon Sep 17 00:00:00 2001 From: Utsav Ladani <55320836+Utsav-Ladani@users.noreply.github.com> Date: Mon, 29 Jan 2024 20:25:34 +0530 Subject: [PATCH] [GH-25484] Fix draft removal on post deletion (#25715) * [GH-25484] Fix draft removal on post deletion * [GH-25484] Add batch migration to remove orphan drafts * [GH-25484] Fix tests of migration and draft store * [GH-25484] Remove translation file changes. * [GH-25484] Remove translation file changes. --------- Co-authored-by: Devin Binnie <52460000+devinbinnie@users.noreply.github.com> Co-authored-by: Mattermost Build Co-authored-by: Harrison Healey --- server/channels/app/migrations.go | 22 + server/channels/app/post.go | 12 + server/channels/app/server.go | 6 + .../delete_orphan_drafts_migration.go | 85 ++++ .../delete_orphan_drafts_migration_test.go | 180 +++++++ .../opentracinglayer/opentracinglayer.go | 36 ++ .../channels/store/retrylayer/retrylayer.go | 42 ++ server/channels/store/sqlstore/draft_store.go | 79 +++ server/channels/store/store.go | 2 + .../channels/store/storetest/draft_store.go | 456 ++++++++++++++++++ .../store/storetest/mocks/DraftStore.go | 28 ++ .../channels/store/timerlayer/timerlayer.go | 32 ++ server/channels/testlib/store.go | 1 + server/public/model/job.go | 1 + server/public/model/migration.go | 1 + webapp/channels/src/actions/storage.ts | 11 +- .../channels/src/actions/views/drafts.test.ts | 4 +- webapp/channels/src/actions/views/drafts.ts | 40 +- .../src/actions/websocket_actions.jsx | 23 +- 19 files changed, 1046 insertions(+), 15 deletions(-) create mode 100644 server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration.go create mode 100644 server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration_test.go diff --git a/server/channels/app/migrations.go b/server/channels/app/migrations.go index edb19f2cbc8..213cac6f699 100644 --- a/server/channels/app/migrations.go +++ b/server/channels/app/migrations.go @@ -628,6 +628,27 @@ func (s *Server) doDeleteEmptyDraftsMigration(c request.CTX) { } } +func (s *Server) doDeleteOrphanDraftsMigration(c request.CTX) { + // If the migration is already marked as completed, don't do it again. + if _, err := s.Store().System().GetByName(model.MigrationKeyDeleteOrphanDrafts); err == nil { + return + } + + jobs, err := s.Store().Job().GetAllByTypeAndStatus(c, model.JobTypeDeleteOrphanDraftsMigration, model.JobStatusPending) + if err != nil { + mlog.Fatal("failed to get jobs by type and status", mlog.Err(err)) + return + } + if len(jobs) > 0 { + return + } + + if _, appErr := s.Jobs.CreateJobOnce(c, model.JobTypeDeleteOrphanDraftsMigration, nil); appErr != nil { + mlog.Fatal("failed to start job for deleting orphan drafts", mlog.Err(appErr)) + return + } +} + func (a *App) DoAppMigrations() { a.Srv().doAppMigrations() } @@ -654,4 +675,5 @@ func (s *Server) doAppMigrations() { s.doElasticsearchFixChannelIndex(c) s.doCloudS3PathMigrations(c) s.doDeleteEmptyDraftsMigration(c) + s.doDeleteOrphanDraftsMigration(c) } diff --git a/server/channels/app/post.go b/server/channels/app/post.go index 420e305b6c6..6353c590723 100644 --- a/server/channels/app/post.go +++ b/server/channels/app/post.go @@ -1417,11 +1417,23 @@ func (a *App) DeletePost(c request.CTX, postID, deleteByID string) (*model.Post, } }) + // delete drafts associated with the post when deleting the post + a.Srv().Go(func() { + a.deleteDraftsAssociatedWithPost(c, channel, post) + }) + a.invalidateCacheForChannelPosts(post.ChannelId) return post, nil } +func (a *App) deleteDraftsAssociatedWithPost(c request.CTX, channel *model.Channel, post *model.Post) { + if err := a.Srv().Store().Draft().DeleteDraftsAssociatedWithPost(channel.Id, post.Id); err != nil { + c.Logger().Error("Failed to delete drafts associated with post when deleting post", mlog.Err(err)) + return + } +} + func (a *App) deleteFlaggedPosts(c request.CTX, postID string) { if err := a.Srv().Store().Preference().DeleteCategoryAndName(model.PreferenceCategoryFlaggedPost, postID); err != nil { c.Logger().Warn("Unable to delete flagged post preference when deleting post.", mlog.Err(err)) diff --git a/server/channels/app/server.go b/server/channels/app/server.go index be4dd2c9a6c..bae76a4f24b 100644 --- a/server/channels/app/server.go +++ b/server/channels/app/server.go @@ -41,6 +41,7 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/jobs/active_users" "github.com/mattermost/mattermost/server/v8/channels/jobs/cleanup_desktop_tokens" "github.com/mattermost/mattermost/server/v8/channels/jobs/delete_empty_drafts_migration" + "github.com/mattermost/mattermost/server/v8/channels/jobs/delete_orphan_drafts_migration" "github.com/mattermost/mattermost/server/v8/channels/jobs/expirynotify" "github.com/mattermost/mattermost/server/v8/channels/jobs/export_delete" "github.com/mattermost/mattermost/server/v8/channels/jobs/export_process" @@ -1589,6 +1590,11 @@ func (s *Server) initJobs() { delete_empty_drafts_migration.MakeWorker(s.Jobs, s.Store(), New(ServerConnector(s.Channels()))), nil) + s.Jobs.RegisterJobType( + model.JobTypeDeleteOrphanDraftsMigration, + delete_orphan_drafts_migration.MakeWorker(s.Jobs, s.Store(), New(ServerConnector(s.Channels()))), + nil) + s.Jobs.RegisterJobType( model.JobTypeExportDelete, export_delete.MakeWorker(s.Jobs, New(ServerConnector(s.Channels()))), diff --git a/server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration.go b/server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration.go new file mode 100644 index 00000000000..6c4c20820c7 --- /dev/null +++ b/server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration.go @@ -0,0 +1,85 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package delete_orphan_drafts_migration + +import ( + "strconv" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/jobs" + "github.com/mattermost/mattermost/server/v8/channels/store" + "github.com/pkg/errors" +) + +const ( + timeBetweenBatches = 1 * time.Second +) + +// MakeWorker creates a batch migration worker to delete empty drafts. +func MakeWorker(jobServer *jobs.JobServer, store store.Store, app jobs.BatchMigrationWorkerAppIFace) model.Worker { + return jobs.MakeBatchMigrationWorker( + jobServer, + store, + app, + model.MigrationKeyDeleteOrphanDrafts, + timeBetweenBatches, + doDeleteOrphanDraftsMigrationBatch, + ) +} + +// parseJobMetadata parses the opaque job metadata to return the information needed to decide which +// batch to process next. +func parseJobMetadata(data model.StringMap) (int64, string, error) { + createAt := int64(0) + if data["create_at"] != "" { + parsedCreateAt, parseErr := strconv.ParseInt(data["create_at"], 10, 64) + if parseErr != nil { + return 0, "", errors.Wrap(parseErr, "failed to parse create_at") + } + createAt = parsedCreateAt + } + + userID := data["user_id"] + + return createAt, userID, nil +} + +// makeJobMetadata encodes the information needed to decide which batch to process next back into +// the opaque job metadata. +func makeJobMetadata(createAt int64, userID string) model.StringMap { + data := make(model.StringMap) + data["create_at"] = strconv.FormatInt(createAt, 10) + data["user_id"] = userID + + return data +} + +// doDeleteOrphanDraftsMigrationBatch iterates through all drafts, deleting orphan drafts within each +// batch keyed by the compound primary key (createAt, userID) +func doDeleteOrphanDraftsMigrationBatch(data model.StringMap, store store.Store) (model.StringMap, bool, error) { + createAt, userID, err := parseJobMetadata(data) + if err != nil { + return nil, false, errors.Wrap(err, "failed to parse job metadata") + } + + // Determine the /next/ (createAt, userId) by finding the last record in the batch we're + // about to delete. + nextCreateAt, nextUserID, err := store.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userID) + if err != nil { + return nil, false, errors.Wrapf(err, "failed to get the next batch (create_at=%v, user_id=%v)", createAt, userID) + } + + // If we get the nil values, it means the batch was empty and we're done. + if nextCreateAt == 0 && nextUserID == "" { + return nil, true, nil + } + + err = store.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userID) + if err != nil { + return nil, false, errors.Wrapf(err, "failed to delete orphan drafts (create_at=%v, user_id=%v)", createAt, userID) + } + + return makeJobMetadata(nextCreateAt, nextUserID), false, nil +} diff --git a/server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration_test.go b/server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration_test.go new file mode 100644 index 00000000000..e32475b4d62 --- /dev/null +++ b/server/channels/jobs/delete_orphan_drafts_migration/delete_orphan_drafts_migration_test.go @@ -0,0 +1,180 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package delete_orphan_drafts_migration + +import ( + "errors" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/store/storetest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJobMetadata(t *testing.T) { + t.Run("parse nil data", func(t *testing.T) { + var data model.StringMap + createAt, userID, err := parseJobMetadata(data) + require.NoError(t, err) + assert.Empty(t, createAt) + assert.Empty(t, userID) + }) + + t.Run("parse invalid create_at", func(t *testing.T) { + data := make(model.StringMap) + data["user_id"] = "user_id" + data["create_at"] = "invalid" + _, _, err := parseJobMetadata(data) + require.Error(t, err) + }) + + t.Run("parse valid", func(t *testing.T) { + data := make(model.StringMap) + data["user_id"] = "user_id" + data["create_at"] = "1695918431" + + createAt, userID, err := parseJobMetadata(data) + require.NoError(t, err) + assert.EqualValues(t, 1695918431, createAt) + assert.Equal(t, "user_id", userID) + }) + + t.Run("parse/make", func(t *testing.T) { + data := makeJobMetadata(1695918431, "user_id") + assert.Equal(t, "1695918431", data["create_at"]) + assert.Equal(t, "user_id", data["user_id"]) + + createAt, userID, err := parseJobMetadata(data) + require.NoError(t, err) + assert.EqualValues(t, 1695918431, createAt) + assert.Equal(t, "user_id", userID) + }) +} + +func TestDoDeleteOrphanDraftsMigrationBatch(t *testing.T) { + t.Run("invalid job metadata", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + data := make(model.StringMap) + data["user_id"] = "user_id" + data["create_at"] = "invalid" + data, done, err := doDeleteOrphanDraftsMigrationBatch(data, mockStore) + require.Error(t, err) + assert.False(t, done) + assert.Nil(t, data) + }) + + t.Run("failure getting next offset", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + createAt, userID := int64(1695920000), "user_id_1" + nextCreateAt, nextUserID := int64(0), "" + + mockStore.DraftStore.On("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", createAt, userID).Return(nextCreateAt, nextUserID, errors.New("failure")) + + data, done, err := doDeleteOrphanDraftsMigrationBatch(makeJobMetadata(createAt, userID), mockStore) + require.EqualError(t, err, "failed to get the next batch (create_at=1695920000, user_id=user_id_1): failure") + assert.False(t, done) + assert.Nil(t, data) + }) + + t.Run("failure deleting batch", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + createAt, userID := int64(1695920000), "user_id_1" + nextCreateAt, nextUserID := int64(1695922034), "user_id_2" + + mockStore.DraftStore.On("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", createAt, userID).Return(nextCreateAt, nextUserID, nil) + mockStore.DraftStore.On("DeleteOrphanDraftsByCreateAtAndUserId", createAt, userID).Return(errors.New("failure")) + + data, done, err := doDeleteOrphanDraftsMigrationBatch(makeJobMetadata(createAt, userID), mockStore) + require.EqualError(t, err, "failed to delete orphan drafts (create_at=1695920000, user_id=user_id_1): failure") + assert.False(t, done) + assert.Nil(t, data) + }) + + t.Run("do first batch (nil job metadata)", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + createAt, userID := int64(0), "" + nextCreateAt, nextUserID := int64(1695922034), "user_id_2" + + mockStore.DraftStore.On("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", createAt, userID).Return(nextCreateAt, nextUserID, nil) + mockStore.DraftStore.On("DeleteOrphanDraftsByCreateAtAndUserId", createAt, userID).Return(nil) + + data, done, err := doDeleteOrphanDraftsMigrationBatch(nil, mockStore) + require.NoError(t, err) + assert.False(t, done) + assert.Equal(t, model.StringMap{ + "create_at": "1695922034", + "user_id": "user_id_2", + }, data) + }) + + t.Run("do first batch (empty job metadata)", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + createAt, userID := int64(0), "" + nextCreateAt, nextUserID := int64(1695922034), "user_id_2" + + mockStore.DraftStore.On("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", createAt, userID).Return(nextCreateAt, nextUserID, nil) + mockStore.DraftStore.On("DeleteOrphanDraftsByCreateAtAndUserId", createAt, userID).Return(nil) + + data, done, err := doDeleteOrphanDraftsMigrationBatch(model.StringMap{}, mockStore) + require.NoError(t, err) + assert.False(t, done) + assert.Equal(t, makeJobMetadata(nextCreateAt, nextUserID), data) + }) + + t.Run("do batch", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + createAt, userID := int64(1695922000), "user_id_1" + nextCreateAt, nextUserID := int64(1695922034), "user_id_2" + + mockStore.DraftStore.On("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", createAt, userID).Return(nextCreateAt, nextUserID, nil) + mockStore.DraftStore.On("DeleteOrphanDraftsByCreateAtAndUserId", createAt, userID).Return(nil) + + data, done, err := doDeleteOrphanDraftsMigrationBatch(makeJobMetadata(createAt, userID), mockStore) + require.NoError(t, err) + assert.False(t, done) + assert.Equal(t, makeJobMetadata(nextCreateAt, nextUserID), data) + }) + + t.Run("done batches", func(t *testing.T) { + mockStore := &storetest.Store{} + t.Cleanup(func() { + mockStore.AssertExpectations(t) + }) + + createAt, userID := int64(1695922000), "user_id_1" + nextCreateAt, nextUserID := int64(0), "" + + mockStore.DraftStore.On("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", createAt, userID).Return(nextCreateAt, nextUserID, nil) + + data, done, err := doDeleteOrphanDraftsMigrationBatch(makeJobMetadata(createAt, userID), mockStore) + require.NoError(t, err) + assert.True(t, done) + assert.Nil(t, data) + }) +} diff --git a/server/channels/store/opentracinglayer/opentracinglayer.go b/server/channels/store/opentracinglayer/opentracinglayer.go index d8c3b82b73f..b0d9bdda2b9 100644 --- a/server/channels/store/opentracinglayer/opentracinglayer.go +++ b/server/channels/store/opentracinglayer/opentracinglayer.go @@ -3381,6 +3381,24 @@ func (s *OpenTracingLayerDraftStore) Delete(userID string, channelID string, roo return err } +func (s *OpenTracingLayerDraftStore) DeleteDraftsAssociatedWithPost(channelID string, rootID string) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.DeleteDraftsAssociatedWithPost") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.DraftStore.DeleteDraftsAssociatedWithPost(channelID, rootID) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + func (s *OpenTracingLayerDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, userId string) error { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.DeleteEmptyDraftsByCreateAtAndUserId") @@ -3399,6 +3417,24 @@ func (s *OpenTracingLayerDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(create return err } +func (s *OpenTracingLayerDraftStore) DeleteOrphanDraftsByCreateAtAndUserId(createAt int64, userId string) error { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.DeleteOrphanDraftsByCreateAtAndUserId") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + err := s.DraftStore.DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return err +} + func (s *OpenTracingLayerDraftStore) Get(userID string, channelID string, rootID string, includeDeleted bool) (*model.Draft, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.Get") diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 498de578c8a..572d34cdfc0 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -3776,6 +3776,27 @@ func (s *RetryLayerDraftStore) Delete(userID string, channelID string, rootID st } +func (s *RetryLayerDraftStore) DeleteDraftsAssociatedWithPost(channelID string, rootID string) error { + + tries := 0 + for { + err := s.DraftStore.DeleteDraftsAssociatedWithPost(channelID, rootID) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, userId string) error { tries := 0 @@ -3797,6 +3818,27 @@ func (s *RetryLayerDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int } +func (s *RetryLayerDraftStore) DeleteOrphanDraftsByCreateAtAndUserId(createAt int64, userId string) error { + + tries := 0 + for { + err := s.DraftStore.DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerDraftStore) Get(userID string, channelID string, rootID string, includeDeleted bool) (*model.Draft, error) { tries := 0 diff --git a/server/channels/store/sqlstore/draft_store.go b/server/channels/store/sqlstore/draft_store.go index e5ae58c41e0..1ef09f3a245 100644 --- a/server/channels/store/sqlstore/draft_store.go +++ b/server/channels/store/sqlstore/draft_store.go @@ -181,6 +181,29 @@ func (s *SqlDraftStore) Delete(userID, channelID, rootID string) error { return nil } +// DeleteDraftsAssociatedWithPost deletes all drafts associated with a post. +func (s *SqlDraftStore) DeleteDraftsAssociatedWithPost(channelID, rootID string) error { + query := s.getQueryBuilder(). + Delete("Drafts"). + Where(sq.Eq{ + "ChannelId": channelID, + "RootId": rootID, + }) + + sql, args, err := query.ToSql() + if err != nil { + return errors.Wrapf(err, "failed to convert to sql") + } + + _, err = s.GetMasterX().Exec(sql, args...) + + if err != nil { + return errors.Wrap(err, "failed to delete Draft") + } + + return nil +} + // GetMaxDraftSize returns the maximum number of runes that may be stored in a post. func (s *SqlDraftStore) GetMaxDraftSize() int { s.maxDraftSizeOnce.Do(func() { @@ -320,3 +343,59 @@ func (s *SqlDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, use return nil } + +func (s *SqlDraftStore) DeleteOrphanDraftsByCreateAtAndUserId(createAt int64, userId string) error { + var builder Builder + if s.DriverName() == model.DatabaseDriverPostgres { + builder = s.getQueryBuilder(). + Delete("Drafts d"). + PrefixExpr(s.getQueryBuilder().Select(). + Prefix("WITH dd AS ("). + Columns("UserId", "ChannelId", "RootId"). + From("Drafts"). + Where(sq.Or{ + sq.Gt{"CreateAt": createAt}, + sq.And{ + sq.Eq{"CreateAt": createAt}, + sq.Gt{"UserId": userId}, + }, + }). + OrderBy("CreateAt", "UserId"). + Limit(100). + Suffix(")"), + ). + Using("dd"). + Where("d.UserId = dd.UserId"). + Where("d.ChannelId = dd.ChannelId"). + Where("d.RootId = dd.RootId"). + Suffix("AND (d.RootId IN (SELECT Id FROM Posts WHERE DeleteAt <> 0) OR NOT EXISTS (SELECT 1 FROM Posts WHERE Posts.Id = d.RootId))") + } else if s.DriverName() == model.DatabaseDriverMysql { + builder = s.getQueryBuilder(). + Delete("Drafts d"). + What("d.*"). + JoinClause(s.getQueryBuilder().Select(). + Prefix("INNER JOIN ("). + Columns("UserId, ChannelId, RootId"). + From("Drafts"). + Where(sq.And{ + sq.Or{ + sq.Gt{"CreateAt": createAt}, + sq.And{ + sq.Eq{"CreateAt": createAt}, + sq.Gt{"UserId": userId}, + }, + }, + }). + OrderBy("CreateAt", "UserId"). + Limit(100). + Suffix(") dj ON (d.UserId = dj.UserId AND d.ChannelId = dj.ChannelId AND d.RootId = dj.RootId)"), + ). + Suffix("AND (d.RootId IN (SELECT Id FROM Posts WHERE DeleteAt <> 0) OR NOT EXISTS (SELECT 1 FROM Posts WHERE Posts.Id = d.RootId))") + } + + if _, err := s.GetMasterX().ExecBuilder(builder); err != nil { + return errors.Wrapf(err, "failed to delete orphan drafts") + } + + return nil +} diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 2749bf714ae..dffce837eea 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -1005,9 +1005,11 @@ type DraftStore interface { Upsert(d *model.Draft) (*model.Draft, error) Get(userID, channelID, rootID string, includeDeleted bool) (*model.Draft, error) Delete(userID, channelID, rootID string) error + DeleteDraftsAssociatedWithPost(channelID, rootID string) error GetDraftsForUser(userID, teamID string) ([]*model.Draft, error) GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt int64, userId string) (int64, string, error) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, userId string) error + DeleteOrphanDraftsByCreateAtAndUserId(createAt int64, userId string) error } type PostAcknowledgementStore interface { diff --git a/server/channels/store/storetest/draft_store.go b/server/channels/store/storetest/draft_store.go index 3b38d9a34c0..83ff14b7e9d 100644 --- a/server/channels/store/storetest/draft_store.go +++ b/server/channels/store/storetest/draft_store.go @@ -19,10 +19,12 @@ func TestDraftStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) t.Run("SaveDraft", func(t *testing.T) { testSaveDraft(t, rctx, ss) }) t.Run("UpdateDraft", func(t *testing.T) { testUpdateDraft(t, rctx, ss) }) t.Run("DeleteDraft", func(t *testing.T) { testDeleteDraft(t, rctx, ss) }) + t.Run("DeleteDraftsAssociatedWithPost", func(t *testing.T) { testDeleteDraftsAssociatedWithPost(t, rctx, ss) }) t.Run("GetDraft", func(t *testing.T) { testGetDraft(t, rctx, ss) }) t.Run("GetDraftsForUser", func(t *testing.T) { testGetDraftsForUser(t, rctx, ss) }) t.Run("GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration", func(t *testing.T) { testGetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(t, rctx, ss) }) t.Run("DeleteEmptyDraftsByCreateAtAndUserId", func(t *testing.T) { testDeleteEmptyDraftsByCreateAtAndUserId(t, rctx, ss) }) + t.Run("DeleteOrphanDraftsByCreateAtAndUserId", func(t *testing.T) { testDeleteOrphanDraftsByCreateAtAndUserId(t, rctx, ss) }) } func testSaveDraft(t *testing.T, rctx request.CTX, ss store.Store) { @@ -371,6 +373,16 @@ func makeDrafts(t *testing.T, ss store.Store, count int, message string) { } } +func countDrafts(t *testing.T, rctx request.CTX, ss store.Store) int { + t.Helper() + + var count int + err := ss.GetInternalMasterDB().QueryRow("SELECT COUNT(*) FROM Drafts").Scan(&count) + require.NoError(t, err) + + return count +} + func countDraftPages(t *testing.T, rctx request.CTX, ss store.Store) int { t.Helper() @@ -401,6 +413,76 @@ func countDraftPages(t *testing.T, rctx request.CTX, ss store.Store) int { return pages } +func clearPosts(t *testing.T, rctx request.CTX, ss store.Store) { + t.Helper() + + _, err := ss.GetInternalMasterDB().Exec("DELETE FROM Posts") + require.NoError(t, err) +} + +func makeDraftsWithNonDeletedPosts(t *testing.T, ss store.Store, count int, message string) { + t.Helper() + + for i := 1; i <= count; i++ { + post, err := ss.Post().Save(&model.Post{ + CreateAt: model.GetMillis(), + UpdateAt: model.GetMillis(), + UserId: model.NewId(), + ChannelId: model.NewId(), + Message: message, + }) + require.NoError(t, err) + + _, err = ss.Draft().Upsert(&model.Draft{ + CreateAt: model.GetMillis(), + UpdateAt: model.GetMillis(), + UserId: post.UserId, + ChannelId: post.ChannelId, + RootId: post.Id, + Message: message, + }) + require.NoError(t, err) + + if i%100 == 0 { + time.Sleep(5 * time.Millisecond) + } + } + + time.Sleep(5 * time.Millisecond) +} + +func makeDraftsWithDeletedPosts(t *testing.T, ss store.Store, count int, message string) { + t.Helper() + + for i := 1; i <= count; i++ { + post, err := ss.Post().Save(&model.Post{ + CreateAt: model.GetMillis(), + UpdateAt: model.GetMillis(), + DeleteAt: model.GetMillis(), + UserId: model.NewId(), + ChannelId: model.NewId(), + Message: message, + }) + require.NoError(t, err) + + _, err = ss.Draft().Upsert(&model.Draft{ + CreateAt: model.GetMillis(), + UpdateAt: model.GetMillis(), + UserId: post.UserId, + ChannelId: post.ChannelId, + RootId: post.Id, + Message: message, + }) + require.NoError(t, err) + + if i%100 == 0 { + time.Sleep(5 * time.Millisecond) + } + } + + time.Sleep(5 * time.Millisecond) +} + func testGetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("no drafts", func(t *testing.T) { clearDrafts(t, rctx, ss) @@ -560,3 +642,377 @@ func testDeleteEmptyDraftsByCreateAtAndUserId(t *testing.T, rctx request.CTX, ss assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") }) } + +func testDeleteOrphanDraftsByCreateAtAndUserId(t *testing.T, rctx request.CTX, ss store.Store) { + t.Run("nil parameters", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + err := ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(0, "") + require.NoError(t, err) + }) + + t.Run("delete single page, drafts with no post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + makeDrafts(t, ss, 100, "Okay") + + createAt, userId := int64(0), "" + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 0, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) + + t.Run("delete multiple pages, drafts with no post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + makeDrafts(t, ss, 300, "Okay") + + createAt, userId := int64(0), "" + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 2, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 1, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 0, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) + + t.Run("delete single page, drafts with deleted post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + makeDraftsWithDeletedPosts(t, ss, 100, "Okay") + + createAt, userId := int64(0), "" + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 0, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) + + t.Run("delete multiple pages, drafts with deleted post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + makeDraftsWithDeletedPosts(t, ss, 300, "Okay") + + createAt, userId := int64(0), "" + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 2, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 1, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 0, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) + + t.Run("delete single page, drafts with non deleted post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + makeDraftsWithNonDeletedPosts(t, ss, 100, "Okay") + + createAt, userId := int64(0), "" + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 100, countDrafts(t, rctx, ss), "incorrect number of drafts") + assert.Equal(t, 1, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) + + t.Run("delete multiple pages, drafts with non deleted post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + makeDraftsWithNonDeletedPosts(t, ss, 300, "Okay") + + createAt, userId := int64(0), "" + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 300, countDrafts(t, rctx, ss), "incorrect number of drafts") + assert.Equal(t, 3, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 300, countDrafts(t, rctx, ss), "incorrect number of drafts") + assert.Equal(t, 3, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + assert.Equal(t, 300, countDrafts(t, rctx, ss), "incorrect number of drafts") + assert.Equal(t, 3, countDraftPages(t, rctx, ss), "incorrect number of pages") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) + + // This test is a bit more complicated, but it's the most realistic scenario and covers all the remaining cases + t.Run("delete multiple pages, some drafts with deleted post, some with non deleted post, and some with no post", func(t *testing.T) { + clearDrafts(t, rctx, ss) + clearPosts(t, rctx, ss) + + // 50 drafts will be deleted from this page + makeDrafts(t, ss, 50, "Yup") + makeDraftsWithNonDeletedPosts(t, ss, 50, "Okay") + + // 100 drafts will be deleted from this page + makeDrafts(t, ss, 50, "Yup") + makeDraftsWithDeletedPosts(t, ss, 50, "Okay") + + // 50 drafts will be deleted from this page + makeDraftsWithDeletedPosts(t, ss, 50, "Okay") + makeDraftsWithNonDeletedPosts(t, ss, 50, "Okay") + + // 70 drafts will be deleted from this page + makeDrafts(t, ss, 40, "Yup") + makeDraftsWithDeletedPosts(t, ss, 30, "Okay") + makeDraftsWithNonDeletedPosts(t, ss, 30, "Okay") + + // No drafts will be deleted from this page + makeDraftsWithNonDeletedPosts(t, ss, 100, "Okay") + + // Verify initially 5 pages with 500 drafts + assert.Equal(t, 5, countDraftPages(t, rctx, ss), "incorrect number of pages") + assert.Equal(t, 500, countDrafts(t, rctx, ss), "incorrect number of drafts") + + createAt, userId := int64(0), "" + + nextCreateAt, nextUserId, err := ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + // Only deleted 50, so still 5 pages + assert.Equal(t, 5, countDraftPages(t, rctx, ss), "incorrect number of pages") + assert.Equal(t, 450, countDrafts(t, rctx, ss), "incorrect number of drafts") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + // Now deleted 150, so down to 4 pages + assert.Equal(t, 4, countDraftPages(t, rctx, ss), "incorrect number of pages") + assert.Equal(t, 350, countDrafts(t, rctx, ss), "incorrect number of drafts") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + // Now deleted 200 now, so down to 3 pages + assert.Equal(t, 3, countDraftPages(t, rctx, ss), "incorrect number of pages") + assert.Equal(t, 300, countDrafts(t, rctx, ss), "incorrect number of drafts") + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + // Now deleted 270 empty messages, so still 3 pages + assert.Equal(t, 3, countDraftPages(t, rctx, ss), "incorrect number of pages") + assert.Equal(t, 230, countDrafts(t, rctx, ss), "incorrect number of drafts") + + // Keep going through all pages to verify nothing else gets deleted. + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + err = ss.Draft().DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + require.NoError(t, err) + createAt, userId = nextCreateAt, nextUserId + + // Verify we're done iterating + + nextCreateAt, nextUserId, err = ss.Draft().GetLastCreateAtAndUserIdValuesForEmptyDraftsMigration(createAt, userId) + require.NoError(t, err) + assert.EqualValues(t, 0, nextCreateAt, "should have finished iterating through drafts") + assert.Equal(t, "", nextUserId, "should have finished iterating through drafts") + }) +} + +func testDeleteDraftsAssociatedWithPost(t *testing.T, rctx request.CTX, ss store.Store) { + user1 := &model.User{ + Id: model.NewId(), + } + + user2 := &model.User{ + Id: model.NewId(), + } + + channel1 := &model.Channel{ + Id: model.NewId(), + } + + channel2 := &model.Channel{ + Id: model.NewId(), + } + + _, err := ss.Channel().SaveMember(&model.ChannelMember{ + ChannelId: channel1.Id, + UserId: user1.Id, + NotifyProps: model.GetDefaultChannelNotifyProps(), + }) + require.NoError(t, err) + + _, err = ss.Channel().SaveMember(&model.ChannelMember{ + ChannelId: channel2.Id, + UserId: user2.Id, + NotifyProps: model.GetDefaultChannelNotifyProps(), + }) + require.NoError(t, err) + + post1, err := ss.Post().Save(&model.Post{ + UserId: user1.Id, + ChannelId: channel1.Id, + Message: "post1", + }) + require.NoError(t, err) + + post2, err := ss.Post().Save(&model.Post{ + UserId: user2.Id, + ChannelId: channel2.Id, + Message: "post2", + }) + require.NoError(t, err) + + _, err = ss.Draft().Upsert(&model.Draft{ + UserId: user1.Id, + ChannelId: channel1.Id, + RootId: post1.Id, + Message: "draft1", + }) + require.NoError(t, err) + + _, err = ss.Draft().Upsert(&model.Draft{ + UserId: user2.Id, + ChannelId: channel1.Id, + RootId: post1.Id, + Message: "draft2", + }) + require.NoError(t, err) + + draft3, err := ss.Draft().Upsert(&model.Draft{ + UserId: user1.Id, + ChannelId: channel2.Id, + RootId: post2.Id, + Message: "draft3", + }) + require.NoError(t, err) + + draft4, err := ss.Draft().Upsert(&model.Draft{ + UserId: user2.Id, + ChannelId: channel2.Id, + RootId: post2.Id, + Message: "draft4", + }) + require.NoError(t, err) + + t.Run("delete drafts associated with post", func(t *testing.T) { + err = ss.Draft().DeleteDraftsAssociatedWithPost(channel1.Id, post1.Id) + require.NoError(t, err) + + _, err = ss.Draft().Get(user1.Id, channel1.Id, post1.Id, false) + require.Error(t, err) + assert.IsType(t, &store.ErrNotFound{}, err) + + _, err = ss.Draft().Get(user2.Id, channel1.Id, post1.Id, false) + require.Error(t, err) + assert.IsType(t, &store.ErrNotFound{}, err) + + draft, err := ss.Draft().Get(user1.Id, channel2.Id, post2.Id, false) + require.NoError(t, err) + assert.Equal(t, draft3.Message, draft.Message) + + draft, err = ss.Draft().Get(user2.Id, channel2.Id, post2.Id, false) + require.NoError(t, err) + assert.Equal(t, draft4.Message, draft.Message) + }) +} diff --git a/server/channels/store/storetest/mocks/DraftStore.go b/server/channels/store/storetest/mocks/DraftStore.go index cd9fda065c8..e6890f88294 100644 --- a/server/channels/store/storetest/mocks/DraftStore.go +++ b/server/channels/store/storetest/mocks/DraftStore.go @@ -28,6 +28,20 @@ func (_m *DraftStore) Delete(userID string, channelID string, rootID string) err return r0 } +// DeleteDraftsAssociatedWithPost provides a mock function with given fields: channelID, rootID +func (_m *DraftStore) DeleteDraftsAssociatedWithPost(channelID string, rootID string) error { + ret := _m.Called(channelID, rootID) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(channelID, rootID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteEmptyDraftsByCreateAtAndUserId provides a mock function with given fields: createAt, userId func (_m *DraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, userId string) error { ret := _m.Called(createAt, userId) @@ -42,6 +56,20 @@ func (_m *DraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, userI return r0 } +// DeleteOrphanDraftsByCreateAtAndUserId provides a mock function with given fields: createAt, userId +func (_m *DraftStore) DeleteOrphanDraftsByCreateAtAndUserId(createAt int64, userId string) error { + ret := _m.Called(createAt, userId) + + var r0 error + if rf, ok := ret.Get(0).(func(int64, string) error); ok { + r0 = rf(createAt, userId) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Get provides a mock function with given fields: userID, channelID, rootID, includeDeleted func (_m *DraftStore) Get(userID string, channelID string, rootID string, includeDeleted bool) (*model.Draft, error) { ret := _m.Called(userID, channelID, rootID, includeDeleted) diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index fa069c99ba6..10bc79bcf1b 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -3101,6 +3101,22 @@ func (s *TimerLayerDraftStore) Delete(userID string, channelID string, rootID st return err } +func (s *TimerLayerDraftStore) DeleteDraftsAssociatedWithPost(channelID string, rootID string) error { + start := time.Now() + + err := s.DraftStore.DeleteDraftsAssociatedWithPost(channelID, rootID) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("DraftStore.DeleteDraftsAssociatedWithPost", success, elapsed) + } + return err +} + func (s *TimerLayerDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int64, userId string) error { start := time.Now() @@ -3117,6 +3133,22 @@ func (s *TimerLayerDraftStore) DeleteEmptyDraftsByCreateAtAndUserId(createAt int return err } +func (s *TimerLayerDraftStore) DeleteOrphanDraftsByCreateAtAndUserId(createAt int64, userId string) error { + start := time.Now() + + err := s.DraftStore.DeleteOrphanDraftsByCreateAtAndUserId(createAt, userId) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("DraftStore.DeleteOrphanDraftsByCreateAtAndUserId", success, elapsed) + } + return err +} + func (s *TimerLayerDraftStore) Get(userID string, channelID string, rootID string, includeDeleted bool) (*model.Draft, error) { start := time.Now() diff --git a/server/channels/testlib/store.go b/server/channels/testlib/store.go index deea701b167..30766a67cb7 100644 --- a/server/channels/testlib/store.go +++ b/server/channels/testlib/store.go @@ -72,6 +72,7 @@ func GetMockStoreForSetupFunctions() *mocks.Store { systemStore.On("GetByName", model.MigrationKeyAddCustomUserGroupsPermissionRestore).Return(&model.System{Name: model.MigrationKeyAddCustomUserGroupsPermissionRestore, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyAddReadChannelContentPermissions).Return(&model.System{Name: model.MigrationKeyAddReadChannelContentPermissions, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyDeleteEmptyDrafts).Return(&model.System{Name: model.MigrationKeyDeleteEmptyDrafts, Value: "true"}, nil) + systemStore.On("GetByName", model.MigrationKeyDeleteOrphanDrafts).Return(&model.System{Name: model.MigrationKeyDeleteOrphanDrafts, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyAddIPFilteringPermissions).Return(&model.System{Name: model.MigrationKeyAddIPFilteringPermissions, Value: "true"}, nil) systemStore.On("GetByName", "CustomGroupAdminRoleCreationMigrationComplete").Return(&model.System{Name: model.MigrationKeyAddPlayboosksManageRolesPermissions, Value: "true"}, nil) systemStore.On("GetByName", "products_boards").Return(&model.System{Name: "products_boards", Value: "true"}, nil) diff --git a/server/public/model/job.go b/server/public/model/job.go index 9c91cd8d808..ca8e7e367e6 100644 --- a/server/public/model/job.go +++ b/server/public/model/job.go @@ -38,6 +38,7 @@ const ( JobTypeCleanupDesktopTokens = "cleanup_desktop_tokens" JobTypeDeleteEmptyDraftsMigration = "delete_empty_drafts_migration" JobTypeRefreshPostStats = "refresh_post_stats" + JobTypeDeleteOrphanDraftsMigration = "delete_orphan_drafts_migration" JobTypeExportUsersToCSV = "export_users_to_csv" JobStatusPending = "pending" diff --git a/server/public/model/migration.go b/server/public/model/migration.go index 12692e6f83d..817e49040fa 100644 --- a/server/public/model/migration.go +++ b/server/public/model/migration.go @@ -44,5 +44,6 @@ const ( MigrationKeyElasticsearchFixChannelIndex = "elasticsearch_fix_channel_index_migration" MigrationKeyS3Path = "s3_path_migration" MigrationKeyDeleteEmptyDrafts = "delete_empty_drafts_migration" + MigrationKeyDeleteOrphanDrafts = "delete_orphan_drafts_migration" MigrationKeyAddIPFilteringPermissions = "add_ip_filtering_permissions" ) diff --git a/webapp/channels/src/actions/storage.ts b/webapp/channels/src/actions/storage.ts index 34b91fbb9db..a72a082f1be 100644 --- a/webapp/channels/src/actions/storage.ts +++ b/webapp/channels/src/actions/storage.ts @@ -37,13 +37,10 @@ export function setGlobalItem(name: string, value: any) { }; } -export function removeGlobalItem(name: string): NewActionFunc { - return (dispatch) => { - dispatch({ - type: StorageTypes.REMOVE_GLOBAL_ITEM, - data: {name}, - }); - return {data: true}; +export function removeGlobalItem(name: string) { + return { + type: StorageTypes.REMOVE_GLOBAL_ITEM, + data: {name}, }; } diff --git a/webapp/channels/src/actions/views/drafts.test.ts b/webapp/channels/src/actions/views/drafts.test.ts index da03b540cf5..361e92799bc 100644 --- a/webapp/channels/src/actions/views/drafts.test.ts +++ b/webapp/channels/src/actions/views/drafts.test.ts @@ -5,7 +5,7 @@ import {Client4} from 'mattermost-redux/client'; import {Posts, Preferences} from 'mattermost-redux/constants'; import {getPreferenceKey} from 'mattermost-redux/utils/preference_utils'; -import {setGlobalItem} from 'actions/storage'; +import {removeGlobalItem, setGlobalItem} from 'actions/storage'; import mockStore from 'tests/test_store'; import {StoragePrefixes} from 'utils/constants'; @@ -174,6 +174,8 @@ describe('draft actions', () => { uploadsInProgress: [], })); + testStore.dispatch(removeGlobalItem(StoragePrefixes.DRAFT + channelId)); + expect(store.getActions()).toEqual(testStore.getActions()); }); diff --git a/webapp/channels/src/actions/views/drafts.ts b/webapp/channels/src/actions/views/drafts.ts index 6af9fd5e5d7..e730d6ecf93 100644 --- a/webapp/channels/src/actions/views/drafts.ts +++ b/webapp/channels/src/actions/views/drafts.ts @@ -9,6 +9,7 @@ import type {PostMetadata, PostPriorityMetadata} from '@mattermost/types/posts'; import type {PreferenceType} from '@mattermost/types/preferences'; import type {UserProfile} from '@mattermost/types/users'; +import {getPost} from 'mattermost-redux/actions/posts'; import {savePreferences} from 'mattermost-redux/actions/preferences'; import {Client4} from 'mattermost-redux/client'; import Preferences from 'mattermost-redux/constants/preferences'; @@ -16,7 +17,7 @@ import {syncedDraftsAreAllowedAndEnabled} from 'mattermost-redux/selectors/entit import {getCurrentUserId} from 'mattermost-redux/selectors/entities/users'; import type {NewActionFunc, NewActionFuncAsync} from 'mattermost-redux/types/actions'; -import {setGlobalItem} from 'actions/storage'; +import {removeGlobalItem, setGlobalItem} from 'actions/storage'; import {makeGetDrafts} from 'selectors/drafts'; import {getConnectionId} from 'selectors/general'; import {getGlobalItem} from 'selectors/storage'; @@ -44,13 +45,40 @@ export function getDrafts(teamId: string): NewActionFuncAsync transformServerDraft(draft)); + const response = await Client4.getUserDrafts(teamId); + + // check if response is an array + if (Array.isArray(response)) { + serverDrafts = response.map((draft) => transformServerDraft(draft)); + } } catch (error) { return {data: false, error}; } + const drafts = [...serverDrafts]; const localDrafts = getLocalDrafts(state); - const drafts = [...serverDrafts, ...localDrafts]; + + // drafts that are not on server, but on local storage + const localOnlyDrafts = localDrafts.filter((localDraft) => { + return !serverDrafts.find((serverDraft) => serverDraft.key === localDraft.key); + }); + + // check if drafts are still valid + await Promise.all(localOnlyDrafts.map(async (draft) => { + if (draft.value.rootId) { + // get post from server to check if it exists + const {error} = await dispatch(getPost(draft.value.rootId)); + + // remove locally stored draft if post does not exist + if (error.status_code === 404) { + await dispatch(setGlobalItem(draft.key, {message: '', fileInfos: [], uploadsInProgress: []})); + await dispatch(removeGlobalItem(draft.key)); + return; + } + } + + drafts.push(draft); + })); // Reconcile drafts and only keep the latest version of a draft. const draftsMap = new Map(drafts.map((draft) => [draft.key, draft])); @@ -74,7 +102,11 @@ export function removeDraft(key: string, channelId: string, rootId = ''): NewAct return async (dispatch, getState) => { const state = getState(); - dispatch(setGlobalItem(key, {message: '', fileInfos: [], uploadsInProgress: []})); + // set draft to empty to re-render the component + await dispatch(setGlobalItem(key, {message: '', fileInfos: [], uploadsInProgress: []})); + + // remove draft from storage + await dispatch(removeGlobalItem(key)); if (syncedDraftsAreAllowedAndEnabled(state)) { const connectionId = getConnectionId(getState()); diff --git a/webapp/channels/src/actions/websocket_actions.jsx b/webapp/channels/src/actions/websocket_actions.jsx index de1e1b701b7..9abcf1d47c0 100644 --- a/webapp/channels/src/actions/websocket_actions.jsx +++ b/webapp/channels/src/actions/websocket_actions.jsx @@ -100,7 +100,7 @@ import {redirectUserToDefaultTeam} from 'actions/global_actions'; import {sendDesktopNotification} from 'actions/notification_actions.jsx'; import {handleNewPost} from 'actions/post_actions'; import * as StatusActions from 'actions/status_actions'; -import {setGlobalItem} from 'actions/storage'; +import {removeGlobalItem, setGlobalItem} from 'actions/storage'; import {loadProfilesForDM, loadProfilesForGM} from 'actions/user_actions'; import {syncPostsInChannel} from 'actions/views/channel'; import {setGlobalDraft, transformServerDraft} from 'actions/views/drafts'; @@ -118,7 +118,7 @@ import RemovedFromChannelModal from 'components/removed_from_channel_modal'; import WebSocketClient from 'client/web_websocket_client'; import {loadPlugin, loadPluginsIfNecessary, removePlugin} from 'plugins'; import {getHistory} from 'utils/browser_history'; -import {ActionTypes, Constants, AnnouncementBarMessages, SocketEvents, UserStatuses, ModalIdentifiers, WarnMetricTypes, PageLoadContext} from 'utils/constants'; +import {ActionTypes, Constants, AnnouncementBarMessages, SocketEvents, UserStatuses, ModalIdentifiers, WarnMetricTypes, PageLoadContext, StoragePrefixes} from 'utils/constants'; import {getSiteURL} from 'utils/url'; import {temporarilySetPageLoadContext} from './telemetry_actions'; @@ -786,6 +786,19 @@ async function handlePostDeleteEvent(msg) { dispatch(postDeleted(post)); + // remove draft associated with this post from store + const draftKey = `${StoragePrefixes.COMMENT_DRAFT}${post.id}`; + + // update the draft first to re-render + await dispatch(setGlobalItem(draftKey, { + message: '', + fileInfos: [], + uploadsInProgress: [], + })); + + // then remove it + await dispatch(removeGlobalItem(draftKey)); + // update thread when a comment is deleted and CRT is on if (post.root_id && collapsedThreads) { const thread = getThread(state, post.root_id); @@ -1757,11 +1770,15 @@ function handleDeleteDraftEvent(msg) { const draft = JSON.parse(msg.data.draft); const {key} = transformServerDraft(draft); - doDispatch(setGlobalItem(key, { + // update the draft first to re-render + await doDispatch(setGlobalItem(key, { message: '', fileInfos: [], uploadsInProgress: [], })); + + // then remove it + await doDispatch(removeGlobalItem(key)); }; }