mirror of
https://github.com/mattermost/mattermost.git
synced 2026-05-28 04:35:04 -04:00
MM-68140: Validate post read access before rewrite thread context (#35864)
Ensure thread context for message rewrite is only built when the session may read the anchor post, and surface context build failures to the client. Made-with: Cursor
This commit is contained in:
parent
47d2c6074d
commit
f4d1abe7e8
3 changed files with 266 additions and 8 deletions
|
|
@ -9,6 +9,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -6792,3 +6793,46 @@ func TestCreateCardPostWithFeatureFlagDisabled(t *testing.T) {
|
|||
assert.Equal(t, "", rpost.Type)
|
||||
})
|
||||
}
|
||||
|
||||
// MM-68140: POST /posts/rewrite must reject root_id for threads in channels the user cannot read
|
||||
// before any thread content is used (e.g. sent to the AI bridge).
|
||||
func TestRewritePostRequiresReadAccessToRootThread(t *testing.T) {
|
||||
mainHelper.Parallel(t)
|
||||
th := Setup(t).InitBasic(t)
|
||||
|
||||
dm := th.CreateDmChannel(t, th.BasicUser2)
|
||||
secretToken := "MM68140_SECRET_REWRITE_HTTP_" + model.NewId()
|
||||
root := th.CreateMessagePostWithClient(t, th.Client, dm, secretToken)
|
||||
|
||||
attacker := th.CreateUserWithClient(t, th.SystemAdminClient)
|
||||
th.LinkUserToTeam(t, attacker, th.BasicTeam)
|
||||
|
||||
attackerClient := th.CreateClient()
|
||||
_, _, err := attackerClient.Login(context.Background(), attacker.Email, attacker.Password)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := model.RewriteRequest{
|
||||
AgentID: model.NewId(),
|
||||
Message: "text to shorten",
|
||||
Action: model.RewriteActionShorten,
|
||||
RootID: root.Id,
|
||||
}
|
||||
reqBody, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
// Use raw HTTP so we can read the body on non-2xx (DoAPIPostJSON returns an error and closes the body for status >= 300).
|
||||
httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodPost, attackerClient.APIURL+"/posts/rewrite", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
httpReq.Header.Set(model.HeaderAuth, attackerClient.AuthType+" "+attackerClient.AuthToken)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
respHTTP, err := attackerClient.HTTPClient.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer respHTTP.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(respHTTP.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := model.BuildResponse(respHTTP)
|
||||
require.Equalf(t, http.StatusForbidden, resp.StatusCode,
|
||||
"rewrite with root_id in an unreadable channel must return forbidden before using thread content; status=%d body=%s", resp.StatusCode, string(bodyBytes))
|
||||
assert.NotContains(t, string(bodyBytes), secretToken, "response must not leak private thread content")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3229,11 +3229,9 @@ func (a *App) RewriteMessage(
|
|||
if rootID != "" {
|
||||
context, appErr := a.buildThreadContextForRewrite(rctx, rootID)
|
||||
if appErr != nil {
|
||||
// Log error but continue without context rather than failing the rewrite
|
||||
rctx.Logger().Warn("Failed to build thread context for rewrite", mlog.String("root_id", rootID), mlog.Err(appErr))
|
||||
} else {
|
||||
threadContext = context
|
||||
return nil, appErr
|
||||
}
|
||||
threadContext = context
|
||||
}
|
||||
|
||||
userPrompt := getRewritePromptForAction(action, message, customPrompt, threadContext)
|
||||
|
|
@ -3290,8 +3288,17 @@ func (a *App) RewriteMessage(
|
|||
func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (string, *model.AppError) {
|
||||
const maxContextPosts = 10
|
||||
|
||||
// Get the thread posts
|
||||
postList, appErr := a.GetPostThread(rctx, rootID, model.GetPostsOptions{}, rctx.Session().UserId)
|
||||
anchorPost, appErr, _ := a.GetPostIfAuthorized(rctx, rootID, rctx.Session(), false)
|
||||
if appErr != nil {
|
||||
return "", appErr
|
||||
}
|
||||
threadRootID := anchorPost.RootId
|
||||
if threadRootID == "" {
|
||||
threadRootID = anchorPost.Id
|
||||
}
|
||||
|
||||
// Get the thread posts (only after confirming the session may read the anchor post's channel)
|
||||
postList, appErr := a.GetPostThread(rctx, anchorPost.Id, model.GetPostsOptions{}, rctx.Session().UserId)
|
||||
if appErr != nil {
|
||||
return "", appErr
|
||||
}
|
||||
|
|
@ -3301,7 +3308,7 @@ func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (str
|
|||
}
|
||||
|
||||
// Get root post
|
||||
rootPost, ok := postList.Posts[rootID]
|
||||
rootPost, ok := postList.Posts[threadRootID]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -3314,7 +3321,7 @@ func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (str
|
|||
// Collect reply posts, filtering out system posts and deleted posts
|
||||
var replies []*model.Post
|
||||
for _, postID := range postList.Order {
|
||||
if postID == rootID {
|
||||
if postID == threadRootID {
|
||||
continue // Skip root post
|
||||
}
|
||||
post, ok := postList.Posts[postID]
|
||||
|
|
|
|||
|
|
@ -3860,6 +3860,213 @@ func TestGetPostIfAuthorized(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// MM-68140: thread context for rewrite must not be built from posts in channels the session cannot read.
|
||||
func TestBuildThreadContextForRewriteRequiresChannelReadAccess(t *testing.T) {
|
||||
mainHelper.Parallel(t)
|
||||
th := Setup(t).InitBasic(t)
|
||||
|
||||
t.Run("direct message between other users", func(t *testing.T) {
|
||||
secretToken := "MM68140_SECRET_DM_THREAD_" + model.NewId()
|
||||
dm := th.CreateDmChannel(t, th.BasicUser2)
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: dm.Id,
|
||||
Message: secretToken,
|
||||
}, dm, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
_, _, err = th.App.CreatePost(th.Context, &model.Post{
|
||||
RootId: root.Id,
|
||||
UserId: th.BasicUser2.Id,
|
||||
ChannelId: dm.Id,
|
||||
Message: "reply only visible to DM participants",
|
||||
}, dm, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
attacker := th.CreateUser(t)
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: attacker.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
|
||||
|
||||
require.NotNil(t, appErr, "expected permission error when root_id is in a channel the user cannot read, got nil")
|
||||
assert.Equal(t, http.StatusForbidden, appErr.StatusCode)
|
||||
assert.NotContains(t, contextStr, secretToken)
|
||||
})
|
||||
|
||||
t.Run("private channel the user is not a member of", func(t *testing.T) {
|
||||
secretToken := "MM68140_SECRET_PRIVATE_THREAD_" + model.NewId()
|
||||
privateCh := th.CreatePrivateChannel(t, th.BasicTeam)
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: privateCh.Id,
|
||||
Message: secretToken,
|
||||
}, privateCh, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser2.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
|
||||
|
||||
require.NotNil(t, appErr)
|
||||
assert.Equal(t, http.StatusForbidden, appErr.StatusCode)
|
||||
assert.NotContains(t, contextStr, secretToken)
|
||||
})
|
||||
}
|
||||
|
||||
// MM-68140: additional edge cases for thread context authorization and anchor resolution.
|
||||
func TestBuildThreadContextForRewriteEdgeCasesMM68140(t *testing.T) {
|
||||
mainHelper.Parallel(t)
|
||||
th := Setup(t).InitBasic(t)
|
||||
|
||||
t.Run("reply post id as root_id resolves thread and includes root message", func(t *testing.T) {
|
||||
_, appErr := th.App.AddUserToChannel(th.Context, th.BasicUser2, th.BasicChannel, false)
|
||||
require.Nil(t, appErr)
|
||||
|
||||
rootSecret := "MM68140_ROOT_VIA_REPLY_ANCHOR_" + model.NewId()
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: rootSecret,
|
||||
}, th.BasicChannel, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
reply, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
RootId: root.Id,
|
||||
UserId: th.BasicUser2.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: "reply anchor",
|
||||
}, th.BasicChannel, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser2.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, reply.Id)
|
||||
require.Nil(t, appErr)
|
||||
assert.Contains(t, contextStr, rootSecret)
|
||||
assert.Contains(t, contextStr, "reply anchor")
|
||||
})
|
||||
|
||||
t.Run("nonexistent post id returns not found", func(t *testing.T) {
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
_, appErr := th.App.buildThreadContextForRewrite(ctx, model.NewId())
|
||||
require.NotNil(t, appErr)
|
||||
assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("soft-deleted anchor post returns not found", func(t *testing.T) {
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: "to be deleted",
|
||||
}, th.BasicChannel, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = th.App.DeletePost(th.Context, root.Id, th.BasicUser.Id)
|
||||
require.Nil(t, err)
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
_, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
|
||||
require.NotNil(t, appErr)
|
||||
assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("guest on team cannot use root_id for private channel they are not in", func(t *testing.T) {
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.GuestAccountsSettings.Enable = true
|
||||
})
|
||||
|
||||
guest := th.CreateGuest(t)
|
||||
_, _, appErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, guest.Id, "")
|
||||
require.Nil(t, appErr)
|
||||
|
||||
privateCh := th.CreatePrivateChannel(t, th.BasicTeam)
|
||||
secretToken := "MM68140_GUEST_PRIVATE_" + model.NewId()
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: privateCh.Id,
|
||||
Message: secretToken,
|
||||
}, privateCh, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: guest.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
|
||||
require.NotNil(t, appErr)
|
||||
assert.Equal(t, http.StatusForbidden, appErr.StatusCode)
|
||||
assert.NotContains(t, contextStr, secretToken)
|
||||
})
|
||||
|
||||
t.Run("system admin may read thread context for DM they do not participate in", func(t *testing.T) {
|
||||
dm := th.CreateDmChannel(t, th.BasicUser2)
|
||||
secretToken := "MM68140_ADMIN_DM_THREAD_" + model.NewId()
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: dm.Id,
|
||||
Message: secretToken,
|
||||
}, dm, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
_, _, err = th.App.CreatePost(th.Context, &model.Post{
|
||||
RootId: root.Id,
|
||||
UserId: th.BasicUser2.Id,
|
||||
ChannelId: dm.Id,
|
||||
Message: "dm reply",
|
||||
}, dm, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.SystemAdminUser.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
|
||||
require.Nil(t, appErr)
|
||||
assert.Contains(t, contextStr, secretToken)
|
||||
})
|
||||
|
||||
t.Run("member can build context after channel is archived", func(t *testing.T) {
|
||||
ch := th.CreateChannel(t, th.BasicTeam)
|
||||
root, _, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: ch.Id,
|
||||
Message: "MM68140_ARCHIVED_ROOT",
|
||||
}, ch, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
_, _, err = th.App.CreatePost(th.Context, &model.Post{
|
||||
RootId: root.Id,
|
||||
UserId: th.BasicUser.Id,
|
||||
ChannelId: ch.Id,
|
||||
Message: "reply in archived",
|
||||
}, ch, model.CreatePostFlags{})
|
||||
require.Nil(t, err)
|
||||
|
||||
appErr := th.App.DeleteChannel(th.Context, ch, th.SystemAdminUser.Id)
|
||||
require.Nil(t, appErr)
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{UserId: th.BasicUser.Id, Props: model.StringMap{}})
|
||||
require.Nil(t, err)
|
||||
ctx := th.Context.WithSession(session)
|
||||
|
||||
contextStr, appErr := th.App.buildThreadContextForRewrite(ctx, root.Id)
|
||||
require.Nil(t, appErr)
|
||||
assert.Contains(t, contextStr, "MM68140_ARCHIVED_ROOT")
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldNotRefollowOnOthersReply(t *testing.T) {
|
||||
mainHelper.Parallel(t)
|
||||
th := Setup(t).InitBasic(t)
|
||||
|
|
|
|||
Loading…
Reference in a new issue