MM-68140: Validate post read access before rewrite thread context (#35864)

Ensure thread context for message rewrite is only built when the session
may read the anchor post, and surface context build failures to the client.

Made-with: Cursor
This commit is contained in:
Nick Misasi 2026-04-01 05:07:45 -04:00 committed by GitHub
parent 47d2c6074d
commit f4d1abe7e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 266 additions and 8 deletions

View file

@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/http/httptest"
@ -6792,3 +6793,46 @@ func TestCreateCardPostWithFeatureFlagDisabled(t *testing.T) {
assert.Equal(t, "", rpost.Type)
})
}
// MM-68140: POST /posts/rewrite must reject root_id for threads in channels the user cannot read
// before any thread content is used (e.g. sent to the AI bridge).
func TestRewritePostRequiresReadAccessToRootThread(t *testing.T) {
mainHelper.Parallel(t)
th := Setup(t).InitBasic(t)
dm := th.CreateDmChannel(t, th.BasicUser2)
secretToken := "MM68140_SECRET_REWRITE_HTTP_" + model.NewId()
root := th.CreateMessagePostWithClient(t, th.Client, dm, secretToken)
attacker := th.CreateUserWithClient(t, th.SystemAdminClient)
th.LinkUserToTeam(t, attacker, th.BasicTeam)
attackerClient := th.CreateClient()
_, _, err := attackerClient.Login(context.Background(), attacker.Email, attacker.Password)
require.NoError(t, err)
req := model.RewriteRequest{
AgentID: model.NewId(),
Message: "text to shorten",
Action: model.RewriteActionShorten,
RootID: root.Id,
}
reqBody, err := json.Marshal(req)
require.NoError(t, err)
// Use raw HTTP so we can read the body on non-2xx (DoAPIPostJSON returns an error and closes the body for status >= 300).
httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodPost, attackerClient.APIURL+"/posts/rewrite", bytes.NewReader(reqBody))
require.NoError(t, err)
httpReq.Header.Set(model.HeaderAuth, attackerClient.AuthType+" "+attackerClient.AuthToken)
httpReq.Header.Set("Content-Type", "application/json")
respHTTP, err := attackerClient.HTTPClient.Do(httpReq)
require.NoError(t, err)
defer respHTTP.Body.Close()
bodyBytes, err := io.ReadAll(respHTTP.Body)
require.NoError(t, err)
resp := model.BuildResponse(respHTTP)
require.Equalf(t, http.StatusForbidden, resp.StatusCode,
"rewrite with root_id in an unreadable channel must return forbidden before using thread content; status=%d body=%s", resp.StatusCode, string(bodyBytes))
assert.NotContains(t, string(bodyBytes), secretToken, "response must not leak private thread content")
}

View file

@ -3229,11 +3229,9 @@ func (a *App) RewriteMessage(
if rootID != "" {
context, appErr := a.buildThreadContextForRewrite(rctx, rootID)
if appErr != nil {
// Log error but continue without context rather than failing the rewrite
rctx.Logger().Warn("Failed to build thread context for rewrite", mlog.String("root_id", rootID), mlog.Err(appErr))
} else {
threadContext = context
return nil, appErr
}
threadContext = context
}
userPrompt := getRewritePromptForAction(action, message, customPrompt, threadContext)
@ -3290,8 +3288,17 @@ func (a *App) RewriteMessage(
func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (string, *model.AppError) {
const maxContextPosts = 10
// Get the thread posts
postList, appErr := a.GetPostThread(rctx, rootID, model.GetPostsOptions{}, rctx.Session().UserId)
anchorPost, appErr, _ := a.GetPostIfAuthorized(rctx, rootID, rctx.Session(), false)
if appErr != nil {
return "", appErr
}
threadRootID := anchorPost.RootId
if threadRootID == "" {
threadRootID = anchorPost.Id
}
// Get the thread posts (only after confirming the session may read the anchor post's channel)
postList, appErr := a.GetPostThread(rctx, anchorPost.Id, model.GetPostsOptions{}, rctx.Session().UserId)
if appErr != nil {
return "", appErr
}
@ -3301,7 +3308,7 @@ func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (str
}
// Get root post
rootPost, ok := postList.Posts[rootID]
rootPost, ok := postList.Posts[threadRootID]
if !ok {
return "", nil
}
@ -3314,7 +3321,7 @@ func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (str
// Collect reply posts, filtering out system posts and deleted posts
var replies []*model.Post
for _, postID := range postList.Order {
if postID == rootID {
if postID == threadRootID {
continue // Skip root post
}
post, ok := postList.Posts[postID]

View file

@ -3860,6 +3860,213 @@ func TestGetPostIfAuthorized(t *testing.T) {
})
}
// MM-68140: thread context for rewrite must not be built from posts in channels the session cannot read.
func TestBuildThreadContextForRewriteRequiresChannelReadAccess(t *testing.T) {
mainHelper.Parallel(t)
th := Setup(t).InitBasic(t)
t.Run("direct message between other users", func(t *testing.T) {
secretToken := "MM68140_SECRET_DM_THREAD_" + model.NewId()
dm := th.CreateDmChannel(t, th.BasicUser2)
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: dm.Id,
Message: secretToken,
}, dm, model.CreatePostFlags{})
require.Nil(t, err)
_, _, err = th.App.CreatePost(th.Context, &model.Post{
RootId: root.Id,
UserId: th.BasicUser2.Id,
ChannelId: dm.Id,
Message: "reply only visible to DM participants",
}, dm, model.CreatePostFlags{})
require.Nil(t, err)
attacker := th.CreateUser(t)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: attacker.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
require.NotNil(t, appErr, "expected permission error when root_id is in a channel the user cannot read, got nil")
assert.Equal(t, http.StatusForbidden, appErr.StatusCode)
assert.NotContains(t, contextStr, secretToken)
})
t.Run("private channel the user is not a member of", func(t *testing.T) {
secretToken := "MM68140_SECRET_PRIVATE_THREAD_" + model.NewId()
privateCh := th.CreatePrivateChannel(t, th.BasicTeam)
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: privateCh.Id,
Message: secretToken,
}, privateCh, model.CreatePostFlags{})
require.Nil(t, err)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser2.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
require.NotNil(t, appErr)
assert.Equal(t, http.StatusForbidden, appErr.StatusCode)
assert.NotContains(t, contextStr, secretToken)
})
}
// MM-68140: additional edge cases for thread context authorization and anchor resolution.
func TestBuildThreadContextForRewriteEdgeCasesMM68140(t *testing.T) {
mainHelper.Parallel(t)
th := Setup(t).InitBasic(t)
t.Run("reply post id as root_id resolves thread and includes root message", func(t *testing.T) {
_, appErr := th.App.AddUserToChannel(th.Context, th.BasicUser2, th.BasicChannel, false)
require.Nil(t, appErr)
rootSecret := "MM68140_ROOT_VIA_REPLY_ANCHOR_" + model.NewId()
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: th.BasicChannel.Id,
Message: rootSecret,
}, th.BasicChannel, model.CreatePostFlags{})
require.Nil(t, err)
reply, _, err := th.App.CreatePost(th.Context, &model.Post{
RootId: root.Id,
UserId: th.BasicUser2.Id,
ChannelId: th.BasicChannel.Id,
Message: "reply anchor",
}, th.BasicChannel, model.CreatePostFlags{})
require.Nil(t, err)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser2.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, reply.Id)
require.Nil(t, appErr)
assert.Contains(t, contextStr, rootSecret)
assert.Contains(t, contextStr, "reply anchor")
})
t.Run("nonexistent post id returns not found", func(t *testing.T) {
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
_, appErr := th.App.buildThreadContextForRewrite(ctx, model.NewId())
require.NotNil(t, appErr)
assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
})
t.Run("soft-deleted anchor post returns not found", func(t *testing.T) {
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: th.BasicChannel.Id,
Message: "to be deleted",
}, th.BasicChannel, model.CreatePostFlags{})
require.Nil(t, err)
_, err = th.App.DeletePost(th.Context, root.Id, th.BasicUser.Id)
require.Nil(t, err)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
_, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
require.NotNil(t, appErr)
assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
})
t.Run("guest on team cannot use root_id for private channel they are not in", func(t *testing.T) {
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.GuestAccountsSettings.Enable = true
})
guest := th.CreateGuest(t)
_, _, appErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, guest.Id, "")
require.Nil(t, appErr)
privateCh := th.CreatePrivateChannel(t, th.BasicTeam)
secretToken := "MM68140_GUEST_PRIVATE_" + model.NewId()
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: privateCh.Id,
Message: secretToken,
}, privateCh, model.CreatePostFlags{})
require.Nil(t, err)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: guest.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
require.NotNil(t, appErr)
assert.Equal(t, http.StatusForbidden, appErr.StatusCode)
assert.NotContains(t, contextStr, secretToken)
})
t.Run("system admin may read thread context for DM they do not participate in", func(t *testing.T) {
dm := th.CreateDmChannel(t, th.BasicUser2)
secretToken := "MM68140_ADMIN_DM_THREAD_" + model.NewId()
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: dm.Id,
Message: secretToken,
}, dm, model.CreatePostFlags{})
require.Nil(t, err)
_, _, err = th.App.CreatePost(th.Context, &model.Post{
RootId: root.Id,
UserId: th.BasicUser2.Id,
ChannelId: dm.Id,
Message: "dm reply",
}, dm, model.CreatePostFlags{})
require.Nil(t, err)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.SystemAdminUser.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
require.Nil(t, appErr)
assert.Contains(t, contextStr, secretToken)
})
t.Run("member can build context after channel is archived", func(t *testing.T) {
ch := th.CreateChannel(t, th.BasicTeam)
root, _, err := th.App.CreatePost(th.Context, &model.Post{
UserId: th.BasicUser.Id,
ChannelId: ch.Id,
Message: "MM68140_ARCHIVED_ROOT",
}, ch, model.CreatePostFlags{})
require.Nil(t, err)
_, _, err = th.App.CreatePost(th.Context, &model.Post{
RootId: root.Id,
UserId: th.BasicUser.Id,
ChannelId: ch.Id,
Message: "reply in archived",
}, ch, model.CreatePostFlags{})
require.Nil(t, err)
appErr := th.App.DeleteChannel(th.Context, ch, th.SystemAdminUser.Id)
require.Nil(t, appErr)
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser.Id, Props: model.StringMap{}})
require.Nil(t, err)
ctx := th.Context.WithSession(session)
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
require.Nil(t, appErr)
assert.Contains(t, contextStr, "MM68140_ARCHIVED_ROOT")
})
}
func TestShouldNotRefollowOnOthersReply(t *testing.T) {
mainHelper.Parallel(t)
th := Setup(t).InitBasic(t)