diff --git a/server/channels/store/sqlstore/post_store.go b/server/channels/store/sqlstore/post_store.go index 3e702b4b0a2..1582c1cc75b 100644 --- a/server/channels/store/sqlstore/post_store.go +++ b/server/channels/store/sqlstore/post_store.go @@ -691,7 +691,6 @@ func (s *SqlPostStore) Get(ctx context.Context, id string, opts model.GetPostsOp if id == "" { return nil, store.NewErrInvalidInput("Post", "id", id) } - var post model.Post postFetchQuery := "SELECT p.*, (SELECT count(*) FROM Posts WHERE Posts.RootId = (CASE WHEN p.RootId = '' THEN p.Id ELSE p.RootId END) AND Posts.DeleteAt = 0) as ReplyCount FROM Posts p WHERE p.Id = ? AND p.DeleteAt = 0" err := s.DBXFromContext(ctx).Get(&post, postFetchQuery, id) @@ -715,14 +714,40 @@ func (s *SqlPostStore) Get(ctx context.Context, id string, opts model.GetPostsOp return nil, errors.Wrapf(err, "invalid rootId with value=%s", rootId) } - query := s.getQueryBuilder(). - Select("p.*, (SELECT count(*) FROM Posts WHERE Posts.RootId = (CASE WHEN p.RootId = '' THEN p.Id ELSE p.RootId END) AND Posts.DeleteAt = 0) as ReplyCount"). - From("Posts p"). - Where(sq.Or{ - sq.Eq{"p.Id": rootId}, - sq.Eq{"p.RootId": rootId}, - }). - Where(sq.Eq{"p.DeleteAt": 0}) + var query sq.SelectBuilder + if s.DriverName() == model.DatabaseDriverMysql { + query = s.getQueryBuilder(). + Select("p.*, (SELECT count(*) FROM Posts WHERE Posts.RootId = (CASE WHEN p.RootId = '' THEN p.Id ELSE p.RootId END) AND Posts.DeleteAt = 0) as ReplyCount"). + From("Posts p"). + Where(sq.And{ + sq.Or{ + sq.Eq{"p.Id": rootId}, + sq.Eq{"p.RootId": rootId}, + }, + sq.Eq{"p.DeleteAt": 0}, + }) + } else { + query = s.getQueryBuilder(). + Select("p.*, replycount.num as ReplyCount"). + PrefixExpr(s.getQueryBuilder(). + Select(). + Prefix("WITH replycount as ("). + Columns("count(*) as num"). + From("posts"). + Where(sq.And{ + sq.Eq{"RootId": rootId}, + sq.Eq{"DeleteAt": 0}, + }).Suffix(")"), + ). + From("Posts p, replycount"). + Where(sq.And{ + sq.Or{ + sq.Eq{"p.Id": rootId}, + sq.Eq{"p.RootId": rootId}, + }, + sq.Eq{"p.DeleteAt": 0}, + }) + } var sort string if opts.Direction != "" { @@ -815,7 +840,7 @@ func (s *SqlPostStore) GetSingle(id string, inclDeleted bool) (*model.Post, erro Where(sq.Eq{"p.Id": id}) replyCountSubQuery := s.getQueryBuilder(). - Select("COUNT(Posts.Id)"). + Select("COUNT(*)"). From("Posts"). Where(sq.Expr("Posts.RootId = (CASE WHEN p.RootId = '' THEN p.Id ELSE p.RootId END) AND Posts.DeleteAt = 0")) diff --git a/server/channels/store/storetest/post_store.go b/server/channels/store/storetest/post_store.go index 0dde75add5c..646dc0bbbdc 100644 --- a/server/channels/store/storetest/post_store.go +++ b/server/channels/store/storetest/post_store.go @@ -821,6 +821,8 @@ func testPostStoreGetForThread(t *testing.T, rctx request.CTX, ss store.Store) { } r1, err = ss.Post().Get(context.Background(), o1.Id, opts, o1.UserId, map[string]bool{}) require.NoError(t, err) + require.Equal(t, r1.Posts[r1.Order[0]].ReplyCount, int64(4)) + require.Equal(t, r1.Posts[r1.Order[1]].ReplyCount, int64(4)) require.Len(t, r1.Order, 2) // including the root post require.Len(t, r1.Posts, 2) assert.LessOrEqual(t, r1.Posts[r1.Order[1]].CreateAt, m1.CreateAt) @@ -852,6 +854,10 @@ func testPostStoreGetForThread(t *testing.T, rctx request.CTX, ss store.Store) { } r1, err = ss.Post().Get(context.Background(), o1.Id, opts, o1.UserId, map[string]bool{}) require.NoError(t, err) + require.Equal(t, r1.Posts[r1.Order[0]].ReplyCount, int64(4)) + require.Equal(t, r1.Posts[r1.Order[1]].ReplyCount, int64(4)) + require.Equal(t, r1.Posts[r1.Order[2]].ReplyCount, int64(4)) + require.Equal(t, r1.Posts[r1.Order[3]].ReplyCount, int64(4)) require.Len(t, r1.Order, 4) // including the root post require.Len(t, r1.Posts, 4) assert.GreaterOrEqual(t, r1.Posts[r1.Order[len(r1.Order)-1]].CreateAt, lastPostCreateAt)