From 0c98113a1708a9a73867795d52f755b4ec160ef3 Mon Sep 17 00:00:00 2001 From: Harrison Healey Date: Thu, 14 May 2026 10:16:48 -0400 Subject: [PATCH 01/80] MM-65058 Make Direct Messages modal load GMs when needed (#36548) * Changed batchGetProfilesInChannel to batchGetProfilesInGroupChannel and have it use bulk API * MM-65058 Add useUserIdsInGroupChannel and use to populate Direct Messages modal * Address feedback * Run Prettier * Fix types --- .../channels/direct_channels_modal.ts | 62 +++++++ .../ui/components/channels/sidebar_left.ts | 2 + .../playwright/lib/src/ui/components/index.ts | 3 + .../playwright/lib/src/ui/pages/channels.ts | 11 ++ .../group_message_profiles.spec.ts | 163 ++++++++++++++++++ .../common/hooks/useUserIdsInGroupChannel.ts | 22 +++ .../drafts/draft_title/draft_title.tsx | 4 +- .../components/more_direct_channels/index.ts | 2 - .../list_item/list_item.tsx | 4 + .../more_direct_channels.test.tsx | 1 - .../more_direct_channels.tsx | 6 +- .../mattermost-redux/src/actions/users.ts | 12 +- 12 files changed, 276 insertions(+), 16 deletions(-) create mode 100644 e2e-tests/playwright/lib/src/ui/components/channels/direct_channels_modal.ts create mode 100644 e2e-tests/playwright/specs/functional/channels/direct_messages_modal/group_message_profiles.spec.ts create mode 100644 webapp/channels/src/components/common/hooks/useUserIdsInGroupChannel.ts diff --git a/e2e-tests/playwright/lib/src/ui/components/channels/direct_channels_modal.ts b/e2e-tests/playwright/lib/src/ui/components/channels/direct_channels_modal.ts new file mode 100644 index 00000000000..5ef6f38477a --- /dev/null +++ b/e2e-tests/playwright/lib/src/ui/components/channels/direct_channels_modal.ts @@ -0,0 +1,62 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import type {UserProfile} from '@mattermost/types/users'; +import {Locator, expect} from '@playwright/test'; + +export default class DirectChannelsModal { + readonly container; + + readonly goButton; + readonly results; + readonly searchInput; + + constructor(container: Locator) { + this.container = container; + + this.goButton = container.getByRole('button', {name: 'Go'}); + this.results = container.locator('.more-modal__list'); + this.searchInput = container.getByRole('combobox', {name: 'Search for people'}); + } + + async toBeVisible() { + await expect(this.container).toBeVisible(); + } + + async selectUser(user: UserProfile) { + await this.fillSearchInput(user.username); + + // This may fail if there's too many group channels containing the provided user + const row = this.results + .locator('.more-modal__row:not(:has(.more-modal__gm-icon))') + .getByText(`@${user.username}`, {exact: false}); + + await row.click(); + + await expect(this.container.getByRole('button', {name: `Remove ${user.username}`})).toBeVisible(); + } + + async toHaveNUsersSelected(count: number) { + await expect(this.results.locator('.react-select_multi-value')).toHaveCount(count); + } + + async goToChannel() { + await this.goButton.click(); + + await expect(this.container).not.toBeAttached(); + } + + async toHaveNResults(count: number) { + await expect(this.results.locator('.more-modal__row')).toHaveCount(count); + } + + async fillSearchInput(text: string) { + await this.searchInput.fill(text); + } + + async toHaveUserAsNthResult(user: UserProfile, index: number) { + const row = this.results.locator('.more-modal__row').nth(index); + + await expect(row).toContainText(`@${user.username}`); + } +} diff --git a/e2e-tests/playwright/lib/src/ui/components/channels/sidebar_left.ts b/e2e-tests/playwright/lib/src/ui/components/channels/sidebar_left.ts index 7dc94566b5b..7d1ba570170 100644 --- a/e2e-tests/playwright/lib/src/ui/components/channels/sidebar_left.ts +++ b/e2e-tests/playwright/lib/src/ui/components/channels/sidebar_left.ts @@ -11,6 +11,7 @@ export default class ChannelsSidebarLeft { readonly findChannelButton; readonly scheduledPostBadge; readonly unreadChannelFilter; + readonly openDirectMessageButton; constructor(container: Locator) { this.container = container; @@ -20,6 +21,7 @@ export default class ChannelsSidebarLeft { this.findChannelButton = container.getByRole('button', {name: 'Find Channels'}); this.scheduledPostBadge = container.locator('span.scheduledPostBadge'); this.unreadChannelFilter = container.locator('.SidebarFilters_filterButton'); + this.openDirectMessageButton = container.getByRole('button', {name: 'Write a direct message'}); } async toBeVisible() { diff --git a/e2e-tests/playwright/lib/src/ui/components/index.ts b/e2e-tests/playwright/lib/src/ui/components/index.ts index ca4e2a49196..babea003a09 100644 --- a/e2e-tests/playwright/lib/src/ui/components/index.ts +++ b/e2e-tests/playwright/lib/src/ui/components/index.ts @@ -21,6 +21,7 @@ import ChannelsSidebarRight from './channels/sidebar_right'; import DeletePostConfirmationDialog from './channels/delete_post_confirmation_dialog'; import DeletePostModal from './channels/delete_post_modal'; import DeleteScheduledPostModal from './channels/delete_scheduled_post_modal'; +import DirectChannelsModal from './channels/direct_channels_modal'; import DraftPost from './channels/draft_post'; import EmojiGifPicker from './channels/emoji_gif_picker'; import FindChannelsModal from './channels/find_channels_modal'; @@ -89,6 +90,7 @@ const components = { DeletePostConfirmationDialog, DeletePostModal, DeleteScheduledPostModal, + DirectChannelsModal, DraftPost, EmojiGifPicker, FindChannelsModal, @@ -172,6 +174,7 @@ export { FlagPostConfirmationDialog, NewChannelModal, BrowseChannelsModal, + DirectChannelsModal, GenericConfirmModal, InvitePeopleModal, MembersInvitedModal, diff --git a/e2e-tests/playwright/lib/src/ui/pages/channels.ts b/e2e-tests/playwright/lib/src/ui/pages/channels.ts index 36cbb4f7dc1..3a6db4afc68 100644 --- a/e2e-tests/playwright/lib/src/ui/pages/channels.ts +++ b/e2e-tests/playwright/lib/src/ui/pages/channels.ts @@ -38,6 +38,7 @@ export default class ChannelsPage { readonly findChannelsModal; readonly newChannelModal; readonly browseChannelsModal; + readonly directChannelsModal; public invitePeopleModal: InvitePeopleModal | undefined; public membersInvitedModal: MembersInvitedModal | undefined; readonly profileModal; @@ -77,6 +78,9 @@ export default class ChannelsPage { this.findChannelsModal = new components.FindChannelsModal(page.getByRole('dialog', {name: 'Find Channels'})); this.newChannelModal = new NewChannelModal(page.getByRole('dialog', {name: 'Create a new channel'})); this.browseChannelsModal = new BrowseChannelsModal(page.getByRole('dialog', {name: 'Browse Channels'})); + this.directChannelsModal = new components.DirectChannelsModal( + page.getByRole('dialog', {name: 'Direct Messages'}), + ); this.profileModal = new components.ProfileModal(page.getByRole('dialog', {name: 'Profile'})); this.settingsModal = new components.SettingsModal(page.getByRole('dialog', {name: 'Settings'})); this.teamSettingsModal = new components.TeamSettingsModal(page.getByRole('dialog', {name: 'Team Settings'})); @@ -242,6 +246,13 @@ export default class ChannelsPage { return this.browseChannelsModal; } + async openDirectChannelsModal() { + await this.sidebarLeft.openDirectMessageButton.click(); + await this.directChannelsModal.toBeVisible(); + + return this.directChannelsModal; + } + async openCreateTeamForm(): Promise { await this.sidebarLeft.teamMenuButton.click(); await this.teamMenu.toBeVisible(); diff --git a/e2e-tests/playwright/specs/functional/channels/direct_messages_modal/group_message_profiles.spec.ts b/e2e-tests/playwright/specs/functional/channels/direct_messages_modal/group_message_profiles.spec.ts new file mode 100644 index 00000000000..0f079c5ba40 --- /dev/null +++ b/e2e-tests/playwright/specs/functional/channels/direct_messages_modal/group_message_profiles.spec.ts @@ -0,0 +1,163 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import {Channel} from '@mattermost/types/channels'; +import type {UserProfile} from '@mattermost/types/users'; +import type {Page} from '@playwright/test'; + +import {expect, test} from '@mattermost/playwright-lib'; + +/** + * @objective Verify that a group message whose channel has fallen out of the sidebar (because the user + * has more DMs/GMs than the configured "Number of direct messages to show" limit) still appears in the + * Direct Messages modal with its members fully loaded — i.e. with a non-zero member count and the + * participant usernames as its name. + */ +test( + "MM-65058 Direct Messages modal should load group members for GMs which haven't been loaded otherwise", + {tag: '@direct_messages'}, + async ({pw}) => { + const {adminClient, user, userClient, team} = await pw.initSetup({withDefaultProfileImage: false}); + + // Use a lower visible DM limit than the UI normally lets you use to speed up this test + const totalGms = 2; + const visibleLimit = 1; + + // # Limit the user's visible DMs/GMs in the sidebar so one GM falls off the sidebar + await userClient.savePreferences(user.id, [ + { + user_id: user.id, + category: 'sidebar_settings', + name: 'limit_visible_dms_gms', + value: visibleLimit.toString(), + }, + ]); + + // # Create enough users to populate 11 GMs with unique users + const users = []; + for (let i = 0; i < totalGms * 2; i++) { + const user = await pw.createNewUserProfile(adminClient, {prefix: `mm65058gm${i}`}); + users.push(user); + } + + // # Log the user in and open the channels page + const {page, channelsPage} = await pw.testBrowser.login(user); + await channelsPage.goto(team.name, 'town-square'); + await channelsPage.toBeVisible(); + + // # Create 11 GMs using the Direct Channels modal + const gmChannels = []; + for (let i = 0; i < totalGms; i++) { + const memberA = users[i * 2]; + const memberB = users[i * 2 + 1]; + + // # Open the modal + const dialog = await channelsPage.openDirectChannelsModal(); + + // # Select the users and create the channel + await dialog.selectUser(memberA); + await dialog.selectUser(memberB); + await dialog.goToChannel(); + + // # Make a post in the channel to ensure that it has a last_post_at value + await channelsPage.postMessage(`gm message ${i}`); + + // # Save the channel's information for later + gmChannels.push({ + channel: await getCurrentChannel(page), + members: [memberA, memberB], + }); + } + + const targetGm = gmChannels[0]; + const otherGms = gmChannels.slice(1); + + // # Refresh the app and go back to Town Square + await channelsPage.goto(team.name, 'town-square'); + + // * Verify the target GM is not present in the sidebar to ensure that the sidebar hasn't loaded it + await expect(page.locator(`#sidebarItem_${targetGm.channel.name}`)).toHaveCount(0); + + // * Wait until the other GMs are loaded and present in the sidebar + for (const otherGm of otherGms) { + const otherGmEntry = page.locator(`#sidebarItem_${otherGm.channel.name}`); + + await expect(otherGmEntry).toHaveCount(1); + await expect(otherGmEntry).toContainText(gmChannelDisplayName(otherGm.members)); + } + + // * Verify that the members of the target GM haven't been loaded and the members of other GMs have + await assertChannelUsersNotLoaded(page, targetGm.channel.id); + for (const otherGm of otherGms) { + await assertChannelUsersLoaded(page, otherGm.channel.id, otherGm.members); + } + + // # Open the Direct Messages modal again + const dialog = await channelsPage.openDirectChannelsModal(); + + // # Wait for the list to populate + const rows = dialog.container.locator('#multiSelectList .more-modal__row'); + await expect.poll(async () => rows.count()).toBeGreaterThanOrEqual(totalGms); + + // * Verify the modal contains an entry for every GM the user has, including the one that fell + // * out of the sidebar + for (const {channel, members} of gmChannels) { + // Each GM row renders the member usernames joined by ', '. We use the second member's + // username (which is unique per GM) to locate the corresponding row. + const usernameMarker = `@${members[1].username}`; + const gmRow = rows.filter({hasText: usernameMarker}); + + // * Verify the row is rendered + await expect(gmRow, `expected to find a row in the DM modal for GM ${channel.id}`).toHaveCount(1); + + // * Verify the GM icon shows the correct member count (channel members minus current user) + await expect( + gmRow.locator('.more-modal__gm-icon'), + `expected GM ${channel.id} to show a member count of ${members.length}`, + ).toHaveText(members.length.toString()); + + // * Verify the row's name section includes every participant's username + const nameContainer = gmRow.locator('.more-modal__name'); + for (const participant of members) { + await expect( + nameContainer, + `expected GM ${channel.id} to include @${participant.username} in its name`, + ).toContainText(`@${participant.username}`); + } + } + + // * Double check that the members of the target GM have been loaded now + await assertChannelUsersLoaded(page, targetGm.channel.id, targetGm.members); + }, +); + +async function getCurrentChannel(page: Page) { + return await page.evaluate( + 'store.getState().entities.channels.channels[store.getState().entities.channels.currentChannelId]', + ); +} + +function gmChannelDisplayName(users: UserProfile[]) { + return users + .toSorted((a, b) => { + return a.username.localeCompare(b.username, undefined, {numeric: true}); + }) + .map((user) => user.username) + .join(', '); +} + +async function assertChannelUsersLoaded(page: Page, channelId: string, expectedUsers: UserProfile[]) { + // profilesInChannel contains Sets which aren't serializable for return from page.evaluate + const loadedIds = await page.evaluate( + `Array.from(store.getState().entities.users.profilesInChannel['${channelId}'])`, + ); + + await expect(loadedIds).toHaveLength(expectedUsers.length); + await expect(loadedIds).toEqual(expect.arrayContaining(expectedUsers.map((user) => user.id))); +} + +async function assertChannelUsersNotLoaded(page: Page, channelId: string) { + const loadedIds = await page.evaluate(`store.getState().entities.users.profilesInChannel['${channelId}']`); + + await expect(loadedIds).toBeUndefined(); +} diff --git a/webapp/channels/src/components/common/hooks/useUserIdsInGroupChannel.ts b/webapp/channels/src/components/common/hooks/useUserIdsInGroupChannel.ts new file mode 100644 index 00000000000..15fc018155b --- /dev/null +++ b/webapp/channels/src/components/common/hooks/useUserIdsInGroupChannel.ts @@ -0,0 +1,22 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import type {UserProfile} from '@mattermost/types/users'; + +import {batchGetProfilesInGroupChannel} from 'mattermost-redux/actions/users'; +import {getUserIdsInChannels} from 'mattermost-redux/selectors/entities/users'; + +import type {GlobalState} from 'types/store'; + +import {makeUseEntity} from './useEntity'; + +/** + * Returns a Set of user IDs in a given group channel. Those users are loaded from the server when needed. + */ +export const useUserIdsInGroupChannel = makeUseEntity>({ + name: 'useUserIdsInGroupChannel', + fetch: (channelId: string) => batchGetProfilesInGroupChannel(channelId), + selector: (state: GlobalState, channelId: string) => { + return getUserIdsInChannels(state)[channelId]; + }, +}); diff --git a/webapp/channels/src/components/drafts/draft_title/draft_title.tsx b/webapp/channels/src/components/drafts/draft_title/draft_title.tsx index 8f45d6033e1..8afd8646cf2 100644 --- a/webapp/channels/src/components/drafts/draft_title/draft_title.tsx +++ b/webapp/channels/src/components/drafts/draft_title/draft_title.tsx @@ -8,7 +8,7 @@ import {useDispatch} from 'react-redux'; import type {Channel} from '@mattermost/types/channels'; import type {UserProfile} from '@mattermost/types/users'; -import {batchGetProfilesInChannel, getMissingProfilesByIds} from 'mattermost-redux/actions/users'; +import {batchGetProfilesInGroupChannel, getMissingProfilesByIds} from 'mattermost-redux/actions/users'; import Avatar from 'components/widgets/users/avatar'; @@ -51,7 +51,7 @@ function DraftTitle({ // The action uses a data loader so it is safe to call do this for multiple // scheduled posts for the same GM without causing any duplicate API calls. if (channel.type === Constants.GM_CHANNEL && !membersCount) { - dispatch(batchGetProfilesInChannel(channel.id)); + dispatch(batchGetProfilesInGroupChannel(channel.id)); } }, [channel.id, channel.type, dispatch, membersCount]); diff --git a/webapp/channels/src/components/more_direct_channels/index.ts b/webapp/channels/src/components/more_direct_channels/index.ts index f840e0639fd..9f0a78d5595 100644 --- a/webapp/channels/src/components/more_direct_channels/index.ts +++ b/webapp/channels/src/components/more_direct_channels/index.ts @@ -29,7 +29,6 @@ import { import {openDirectChannelToUserId, openGroupChannelToUserIds} from 'actions/channel_actions'; import {loadStatusesForProfilesList, loadProfilesMissingStatus} from 'actions/status_actions'; -import {loadProfilesForGroupChannels} from 'actions/user_actions'; import {setModalSearchTerm} from 'actions/views/search'; import type {GlobalState} from 'types/store'; @@ -98,7 +97,6 @@ function mapDispatchToProps(dispatch: Dispatch) { loadProfilesMissingStatus, getTotalUsersStats, loadStatusesForProfilesList, - loadProfilesForGroupChannels, openDirectChannelToUserId, openGroupChannelToUserIds, searchProfiles, diff --git a/webapp/channels/src/components/more_direct_channels/list_item/list_item.tsx b/webapp/channels/src/components/more_direct_channels/list_item/list_item.tsx index a6281aa36a9..18b39fefb66 100644 --- a/webapp/channels/src/components/more_direct_channels/list_item/list_item.tsx +++ b/webapp/channels/src/components/more_direct_channels/list_item/list_item.tsx @@ -5,6 +5,7 @@ import classNames from 'classnames'; import React, {useCallback} from 'react'; import {useIntl} from 'react-intl'; +import {useUserIdsInGroupChannel} from 'components/common/hooks/useUserIdsInGroupChannel'; import Timestamp from 'components/timestamp'; import UserDetails from './user_details'; @@ -97,6 +98,9 @@ export default ListItem; function GMDetails(props: {option: GroupChannel}) { const {option} = props; + // Indirectly populate option.profiles when needed + useUserIdsInGroupChannel(option.id); + return ( <>
diff --git a/webapp/channels/src/components/more_direct_channels/more_direct_channels.test.tsx b/webapp/channels/src/components/more_direct_channels/more_direct_channels.test.tsx index 2436746baeb..6500aaab943 100644 --- a/webapp/channels/src/components/more_direct_channels/more_direct_channels.test.tsx +++ b/webapp/channels/src/components/more_direct_channels/more_direct_channels.test.tsx @@ -74,7 +74,6 @@ describe('components/MoreDirectChannels', () => { searchGroupChannels: jest.fn().mockResolvedValue({data: true}), setModalSearchTerm: jest.fn().mockResolvedValue({data: true}), loadStatusesForProfilesList: jest.fn().mockResolvedValue({data: true}), - loadProfilesForGroupChannels: jest.fn().mockResolvedValue({data: true}), openDirectChannelToUserId: jest.fn().mockResolvedValue({data: {name: 'dm'}}), openGroupChannelToUserIds: jest.fn().mockResolvedValue({data: {name: 'group'}}), getTotalUsersStats: jest.fn().mockImplementation(() => { diff --git a/webapp/channels/src/components/more_direct_channels/more_direct_channels.tsx b/webapp/channels/src/components/more_direct_channels/more_direct_channels.tsx index b518eb037cc..7bdc31cce01 100644 --- a/webapp/channels/src/components/more_direct_channels/more_direct_channels.tsx +++ b/webapp/channels/src/components/more_direct_channels/more_direct_channels.tsx @@ -52,7 +52,6 @@ export type Props = { loadProfilesMissingStatus: (users: UserProfile[]) => void; getTotalUsersStats: () => void; loadStatusesForProfilesList: (users: UserProfile[]) => void; - loadProfilesForGroupChannels: (groupChannels: Channel[]) => void; openDirectChannelToUserId: (userId: string) => Promise; openGroupChannelToUserIds: (userIds: string[]) => Promise; searchProfiles: (term: string, options: any) => Promise>; @@ -157,16 +156,13 @@ export default class MoreDirectChannels extends React.PureComponent { this.setUsersLoadingState(true); - const [{data: profilesData}, {data: groupChannelsData}] = await Promise.all([ + const [{data: profilesData}] = await Promise.all([ this.props.actions.searchProfiles(searchTerm, {team_id: teamId}), this.props.actions.searchGroupChannels(searchTerm), ]); if (profilesData) { this.props.actions.loadStatusesForProfilesList(profilesData); } - if (groupChannelsData) { - this.props.actions.loadProfilesForGroupChannels(groupChannelsData); - } this.resetPaging(); this.setUsersLoadingState(false); }, diff --git a/webapp/channels/src/packages/mattermost-redux/src/actions/users.ts b/webapp/channels/src/packages/mattermost-redux/src/actions/users.ts index da15f5d6231..7545af3fc33 100644 --- a/webapp/channels/src/packages/mattermost-redux/src/actions/users.ts +++ b/webapp/channels/src/packages/mattermost-redux/src/actions/users.ts @@ -352,17 +352,17 @@ export function getProfilesInChannel(channelId: string, page: number, perPage: n }; } -export function batchGetProfilesInChannel(channelId: string): ActionFuncAsync> { +export function batchGetProfilesInGroupChannel(channelId: string): ActionFuncAsync> { return async (dispatch, getState, {loaders}: any) => { - if (!loaders.profilesInChannelLoader) { - loaders.profilesInChannelLoader = new DelayedDataLoader({ - fetchBatch: (channelIds) => dispatch(getProfilesInChannel(channelIds[0], 0)), - maxBatchSize: 1, + if (!loaders.profilesInGroupChannelLoader) { + loaders.profilesInGroupChannelLoader = new DelayedDataLoader({ + fetchBatch: (channelIds) => dispatch(getProfilesInGroupChannels(channelIds)), + maxBatchSize: General.MAX_GROUP_CHANNELS_FOR_PROFILES, wait: missingProfilesWait, }); } - await loaders.profilesInChannelLoader.queueAndWait([channelId]); + await loaders.profilesInGroupChannelLoader.queueAndWait([channelId]); return {}; }; } From d1fb57bc375db6e313596090438a6a2fa65cdfef Mon Sep 17 00:00:00 2001 From: Harrison Healey Date: Thu, 14 May 2026 10:19:03 -0400 Subject: [PATCH 02/80] Add .envrc to .gitignore (#36567) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7871bfe276b..be437655668 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,7 @@ docker-compose.override.yaml .notice-work/ .aider* .env +.envrc .planning/ **/CLAUDE.local.md From 4aa1c58e37bd8fcc01455d4f7747a2495fccc20f Mon Sep 17 00:00:00 2001 From: David Krauser Date: Thu, 14 May 2026 12:01:49 -0400 Subject: [PATCH 03/80] ci: invalidate poisoned shard-timing cache and guard future saves (#36568) --- .github/workflows/server-ci.yml | 1 + .../workflows/server-test-merge-template.yml | 15 ++++- .github/workflows/server-test-template.yml | 10 +++- server/channels/api4/post_test.go | 1 + server/channels/app/migrations.go | 16 +++++- server/scripts/shard-split.js | 56 +++++++++++++++---- 6 files changed, 83 insertions(+), 16 deletions(-) diff --git a/.github/workflows/server-ci.yml b/.github/workflows/server-ci.yml index c26247c67d5..f6e6cf138bb 100644 --- a/.github/workflows/server-ci.yml +++ b/.github/workflows/server-ci.yml @@ -247,6 +247,7 @@ jobs: artifact-pattern: postgres-server-test-logs-shard-* artifact-name: postgres-server-test-logs save-timing-cache: true + all-shards-passed: ${{ needs.test-postgres-normal.result == 'success' }} test-elasticsearch-v8: name: Elasticsearch v8 Compatibility diff --git a/.github/workflows/server-test-merge-template.yml b/.github/workflows/server-test-merge-template.yml index b007cf0929c..c9f6866f854 100644 --- a/.github/workflows/server-test-merge-template.yml +++ b/.github/workflows/server-test-merge-template.yml @@ -16,6 +16,11 @@ on: required: false type: boolean default: false + all-shards-passed: + description: "Whether every upstream shard succeeded. Used to gate the timing-cache save so a single shard failure doesn't poison the cache with missing-package data." + required: false + type: boolean + default: false jobs: merge: @@ -79,11 +84,17 @@ jobs: echo "has_timing=false" >> "$GITHUB_OUTPUT" fi + # Only save when every upstream shard succeeded. If even one shard + # failed/was killed, its gotestsum.json is missing and the merged report + # has no timings for that shard's packages — saving that would poison + # future shard splits (missing packages default to 1ms, all bin-pack + # onto the lightest shard, overloading it and repeating the failure). - name: Save test timing cache - if: inputs.save-timing-cache && steps.timing-prep.outputs.has_timing == 'true' && github.ref_name == github.event.repository.default_branch + if: inputs.save-timing-cache && inputs.all-shards-passed && steps.timing-prep.outputs.has_timing == 'true' && github.ref_name == github.event.repository.default_branch uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | server/prev-report.xml server/prev-gotestsum.json - key: server-test-timing-master-${{ github.run_id }} + # The v2 prefix matches the v2 restore prefix in server-test-template.yml. + key: server-test-timing-v2-master-${{ github.run_id }} diff --git a/.github/workflows/server-test-template.yml b/.github/workflows/server-test-template.yml index 775ab55b592..46e5e797888 100644 --- a/.github/workflows/server-test-template.yml +++ b/.github/workflows/server-test-template.yml @@ -93,9 +93,15 @@ jobs: server/prev-gotestsum.json # Always restore from master — timing is only saved on the default # branch and is stable enough for shard balancing. - key: server-test-timing-master + # NOTE: the v2 prefix invalidates pre-existing caches that were + # poisoned by shard failures (a killed shard loses its gotestsum.json, + # so the merged report was missing those packages' timings; on the + # next run they all defaulted to 1ms and bin-packed onto the lightest + # shard, overloading it and perpetuating the cycle). See also the + # all-shards-passed guard in server-test-merge-template.yml. + key: server-test-timing-v2-master restore-keys: | - server-test-timing- + server-test-timing-v2- - name: Setup BUILD_IMAGE id: build diff --git a/server/channels/api4/post_test.go b/server/channels/api4/post_test.go index cf20d5e0c2d..75147a4b33e 100644 --- a/server/channels/api4/post_test.go +++ b/server/channels/api4/post_test.go @@ -5511,6 +5511,7 @@ func TestGetEditHistoryForPost(t *testing.T) { } func TestCreatePostNotificationsWithCRT(t *testing.T) { + t.Skip("flaky") mainHelper.Parallel(t) th := Setup(t).InitBasic(t) diff --git a/server/channels/app/migrations.go b/server/channels/app/migrations.go index 4888a194491..524d54ba784 100644 --- a/server/channels/app/migrations.go +++ b/server/channels/app/migrations.go @@ -844,13 +844,25 @@ func (s *Server) doSetupBoardsProperties() error { for _, property := range propertiesToCreate { if _, err := s.propertyService.CreatePropertyField(nil, property); err != nil { - return fmt.Errorf("failed to create boards property: %q, error: %w", property.Name, err) + // Another server may have won the race and created this field + // concurrently (e.g. parallel tests sharing a database pool). + // Tolerate that but propagate any other error. + if _, retryErr := s.propertyService.GetPropertyFieldByName(nil, group.ID, "", property.Name); retryErr != nil { + return fmt.Errorf("failed to create boards property: %q, error: %w", property.Name, err) + } } } if len(propertiesToUpdate) > 0 { if _, _, err := s.propertyService.UpdatePropertyFields(nil, group.ID, propertiesToUpdate); err != nil { - return fmt.Errorf("failed to update boards property fields: %w", err) + // Another server may have won the race and updated these fields + // concurrently (e.g. parallel tests sharing a database pool). + // Both servers write the same expected values, so tolerate the + // conflict but propagate any other error. + var conflictErr *store.ErrConflict + if !errors.As(err, &conflictErr) { + return fmt.Errorf("failed to update boards property fields: %w", err) + } } } diff --git a/server/scripts/shard-split.js b/server/scripts/shard-split.js index 198f6d3e40e..8bbe59740f8 100644 --- a/server/scripts/shard-split.js +++ b/server/scripts/shard-split.js @@ -40,6 +40,16 @@ const SHARD_INDEX = parseInt(process.env.SHARD_INDEX); const SHARD_TOTAL = parseInt(process.env.SHARD_TOTAL); const HEAVY_MS = 600000; // 600s (10 min): packages above this get test-level splitting +// Packages that should always be split test-by-test, even on a cold cache. +// Without timing data the splitter falls through to alphabetical round-robin, +// which places these adjacent on the same runner and overwhelms postgres. +// Forcing them heavy lets `go test -list` enumerate their tests so the +// bin-packer can spread them across all shards. +const KNOWN_HEAVY_PKGS = new Set([ + "github.com/mattermost/mattermost/server/v8/channels/api4", + "github.com/mattermost/mattermost/server/v8/channels/app", +]); + if (isNaN(SHARD_INDEX) || isNaN(SHARD_TOTAL) || SHARD_TOTAL < 1) { console.error("ERROR: SHARD_INDEX and SHARD_TOTAL must be set"); process.exit(1); @@ -107,19 +117,30 @@ const hasTimingData = Object.keys(pkgTimes).length > 0; const hasTestTiming = Object.keys(testTimes).length > 0; // ── Identify heavy packages ── -// Only split at test level if we have per-test timing data +// Split at test level for packages above HEAVY_MS (requires per-test timing) +// AND for the KNOWN_HEAVY_PKGS list (which uses go test -list discovery +// to enumerate tests when no timing cache exists). +// +// Both checks gate on allPkgs membership so stale entries from the cached +// pkgTimes (renamed/deleted packages from a prior run) can't end up in +// heavyPkgs — otherwise the post-discovery fallback would emit them as +// whole-package items for nonexistent packages. +const allPkgsSet = new Set(allPkgs); const heavyPkgs = new Set(); if (hasTestTiming) { for (const [pkg, ms] of Object.entries(pkgTimes)) { - if (ms > HEAVY_MS) heavyPkgs.add(pkg); + if (ms > HEAVY_MS && allPkgsSet.has(pkg)) heavyPkgs.add(pkg); } } +for (const pkg of allPkgs) { + if (KNOWN_HEAVY_PKGS.has(pkg)) heavyPkgs.add(pkg); +} if (heavyPkgs.size > 0) { console.log("Heavy packages (test-level splitting):"); for (const p of heavyPkgs) { - console.log( - ` ${(pkgTimes[p] / 1000).toFixed(0)}s ${p.split("/").pop()}`, - ); + const t = pkgTimes[p]; + const label = t ? `${(t / 1000).toFixed(0)}s` : "no-timing"; + console.log(` ${label} ${p.split("/").pop()}`); } } @@ -134,10 +155,10 @@ for (const pkg of allPkgs) { .map(([k, ms]) => ({ ms, type: "T", pkg, test: k.split("::")[1] })); if (tests.length > 0) { items.push(...tests); - } else { - // Shouldn't happen, but fall back to whole package - items.push({ ms: pkgTimes[pkg] || 1, type: "P", pkg }); } + // If no per-test timing exists, the discovery step below enumerates + // tests via `go test -list`. A final fallback to whole-package is + // added after discovery for packages where both lookups failed. } else { items.push({ ms: pkgTimes[pkg] || 1, type: "P", pkg }); } @@ -186,6 +207,18 @@ if (heavyPkgs.size > 0) { ); } } + // Ensure every heavy package has at least one item. A package can reach + // this point with zero items if it has no per-test timing AND `go test + // -list` failed (e.g. sqlstore on a cold cache). + for (const pkg of heavyPkgs) { + const hasItems = items.some((it) => it.pkg === pkg); + if (!hasItems) { + console.log( + ` ${pkg.split("/").pop()}: no per-test data, running as whole package`, + ); + items.push({ ms: pkgTimes[pkg] || 1, type: "P", pkg }); + } + } console.log("::endgroup::"); } @@ -199,8 +232,11 @@ const shards = Array.from({ length: SHARD_TOTAL }, () => ({ heavy: {}, })); -if (!hasTimingData) { - // Round-robin fallback when no timing data exists +if (!hasTimingData && heavyPkgs.size === 0) { + // Round-robin fallback only when we have *no* signal — no timing cache + // and no known-heavy packages to test-level-split. With heavyPkgs we + // can still bin-pack: discovered tests (ms=1000 each) drive the + // distribution and whole-package items (ms=1) fill in evenly. console.log("No timing data — using round-robin"); allPkgs.forEach((pkg, i) => { shards[i % SHARD_TOTAL].whole.push(pkg); From 9f1fe90b69853f5f6111011bbbe02da4b404b1cc Mon Sep 17 00:00:00 2001 From: David Krauser Date: Thu, 14 May 2026 12:46:07 -0400 Subject: [PATCH 04/80] Migrate CPA to the v2 Property System (#36180) --- .../api4/custom_profile_attributes.go | 479 +++--- .../api4/custom_profile_attributes_test.go | 1132 ++++++++++---- server/channels/api4/properties.go | 377 ++--- server/channels/api4/properties_test.go | 91 +- server/channels/app/access_control.go | 8 +- server/channels/app/access_control_masking.go | 3 +- .../app/access_control_masking_test.go | 16 +- server/channels/app/authorization_test.go | 31 +- server/channels/app/content_flagging.go | 4 +- .../channels/app/custom_profile_attributes.go | 326 ---- .../app/custom_profile_attributes_test.go | 1326 ++--------------- server/channels/app/migrations.go | 4 +- server/channels/app/migrations_test.go | 106 +- server/channels/app/plugin_api.go | 22 +- server/channels/app/plugin_api_test.go | 67 + server/channels/app/plugin_properties_test.go | 138 +- .../channels/app/properties/access_control.go | 1075 +++++++------ .../access_control_attribute_validation.go | 514 +++++++ ...ccess_control_attribute_validation_test.go | 1093 ++++++++++++++ .../properties/access_control_field_test.go | 208 ++- .../properties/access_control_value_test.go | 325 +++- server/channels/app/properties/field_limit.go | 87 ++ .../app/properties/field_limit_test.go | 84 ++ server/channels/app/properties/helper_test.go | 23 +- server/channels/app/properties/hooks.go | 455 ++++++ server/channels/app/properties/hooks_test.go | 637 ++++++++ .../channels/app/properties/license_check.go | 148 ++ .../app/properties/license_check_test.go | 140 ++ server/channels/app/properties/migrations.go | 4 +- .../channels/app/properties/property_field.go | 201 ++- .../app/properties/property_field_test.go | 139 +- .../channels/app/properties/property_value.go | 112 +- server/channels/app/properties/service.go | 26 +- .../properties/type_change_value_cleanup.go | 66 + .../type_change_value_cleanup_test.go | 216 +++ server/channels/app/property_errors.go | 77 + server/channels/app/property_errors_test.go | 146 ++ server/channels/app/property_field.go | 201 ++- server/channels/app/property_field_helpers.go | 43 + .../app/property_field_helpers_test.go | 102 ++ server/channels/app/property_field_test.go | 371 ++++- server/channels/app/property_value.go | 133 +- server/channels/app/property_value_test.go | 100 ++ server/channels/app/server.go | 72 +- server/channels/db/migrations/migrations.list | 4 + ...176_migrate_cpa_to_access_control.down.sql | 16 + ...00176_migrate_cpa_to_access_control.up.sql | 22 + ...ter_attribute_view_by_object_type.down.sql | 30 + ...ilter_attribute_view_by_object_type.up.sql | 35 + .../store/sqlstore/migration_000172_test.go | 331 ++++ .../store/sqlstore/property_field_store.go | 26 +- .../store/sqlstore/property_value_store.go | 9 + server/channels/store/store.go | 1 + .../store/storetest/attributes_store.go | 26 +- .../storetest/mocks/PropertyFieldStore.go | 28 + .../store/storetest/property_field_store.go | 13 +- .../store/storetest/property_value_store.go | 10 +- server/channels/testlib/store.go | 9 +- .../user_attributes_field_e2e_test.go | 99 +- .../user_attributes_value_e2e_test.go | 85 +- server/i18n/en.json | 176 +-- .../public/model/custom_profile_attributes.go | 252 +--- .../model/custom_profile_attributes_test.go | 958 ++---------- .../public/model/property_access_control.go | 19 + server/public/model/property_field.go | 6 +- .../model/property_field_attrs_validation.go | 192 +++ .../property_field_attrs_validation_test.go | 157 ++ server/public/model/property_group.go | 17 + server/public/model/property_value.go | 58 + server/public/model/property_value_test.go | 35 + .../user_properties_dot_menu.test.tsx | 32 + .../user_properties_dot_menu.tsx | 5 +- .../user_properties_utils.ts | 10 +- 73 files changed, 8876 insertions(+), 4713 deletions(-) delete mode 100644 server/channels/app/custom_profile_attributes.go create mode 100644 server/channels/app/properties/access_control_attribute_validation.go create mode 100644 server/channels/app/properties/access_control_attribute_validation_test.go create mode 100644 server/channels/app/properties/field_limit.go create mode 100644 server/channels/app/properties/field_limit_test.go create mode 100644 server/channels/app/properties/hooks.go create mode 100644 server/channels/app/properties/hooks_test.go create mode 100644 server/channels/app/properties/license_check.go create mode 100644 server/channels/app/properties/license_check_test.go create mode 100644 server/channels/app/properties/type_change_value_cleanup.go create mode 100644 server/channels/app/properties/type_change_value_cleanup_test.go create mode 100644 server/channels/app/property_errors.go create mode 100644 server/channels/app/property_errors_test.go create mode 100644 server/channels/app/property_field_helpers.go create mode 100644 server/channels/app/property_field_helpers_test.go create mode 100644 server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql create mode 100644 server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql create mode 100644 server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql create mode 100644 server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql create mode 100644 server/channels/store/sqlstore/migration_000172_test.go create mode 100644 server/public/model/property_field_attrs_validation.go create mode 100644 server/public/model/property_field_attrs_validation_test.go diff --git a/server/channels/api4/custom_profile_attributes.go b/server/channels/api4/custom_profile_attributes.go index a58729dc8ed..186f0845e90 100644 --- a/server/channels/api4/custom_profile_attributes.go +++ b/server/channels/api4/custom_profile_attributes.go @@ -9,6 +9,7 @@ package api4 import ( "encoding/json" + "maps" "net/http" "strings" @@ -31,37 +32,37 @@ func (api *API) InitCustomProfileAttributes() { } func listCPAFields(c *Context, w http.ResponseWriter, r *http.Request) { - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.listCPAFields", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + group, appErr := c.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + c.Err = appErr return } - rctx := app.RequestContextWithCallerID(c.AppContext, c.AppContext.Session().UserId) - fields, appErr := c.App.ListCPAFields(rctx) + pfs, appErr := c.App.SearchPropertyFields(rctx, group.ID, model.PropertyFieldSearchOpts{ + GroupID: group.ID, + ObjectType: model.PropertyFieldObjectTypeUser, + PerPage: model.AccessControlGroupFieldLimit + 5, + }) if appErr != nil { c.Err = appErr return } + fields, convErr := model.CPAFieldsFromPropertyFields(pfs) + if convErr != nil { + c.Err = model.NewAppError("listCPAFields", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(convErr) + return + } + if err := json.NewEncoder(w).Encode(fields); err != nil { c.Logger.Warn("Error while writing response", mlog.Err(err)) } } func createCPAField(c *Context, w http.ResponseWriter, r *http.Request) { - if !c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) { - c.SetPermissionError(model.PermissionManageSystem) - return - } - - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.createCPAField", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) - return - } - var pf *model.CPAField - err := json.NewDecoder(r.Body).Decode(&pf) - if err != nil || pf == nil { + if err := json.NewDecoder(r.Body).Decode(&pf); err != nil || pf == nil { c.SetInvalidParamWithErr("property_field", err) return } @@ -72,42 +73,71 @@ func createCPAField(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) model.AddEventParameterAuditableToAuditRec(auditRec, "property_field", pf) - rctx := app.RequestContextWithCallerID(c.AppContext, c.AppContext.Session().UserId) - createdField, appErr := c.App.CreateCPAField(rctx, pf) - if appErr != nil { - c.Err = appErr - return - } - - auditRec.Success() - auditRec.AddEventResultState(createdField) - auditRec.AddEventObjectType("property_field") - - w.WriteHeader(http.StatusCreated) - if err := json.NewEncoder(w).Encode(createdField); err != nil { - c.Logger.Warn("Error while writing response", mlog.Err(err)) - } -} - -func patchCPAField(c *Context, w http.ResponseWriter, r *http.Request) { + // CPA fields are system-scoped; only a system administrator may create + // them. This mirrors the scope-based permission check the shared generic + // handler enforces for system-typed fields. if !c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) { c.SetPermissionError(model.PermissionManageSystem) return } - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.patchCPAField", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) + // Translate to PropertyField and route through the generic property API. + // Server-controlled fields (group, type, target shape, creator) are + // stamped here; ID/TargetID/Protected are stripped so a caller can't + // inject them. Permissions and timestamps are filled in by lower layers. + field := pf.ToPropertyField() + group, appErr := c.App.GetPropertyGroup(c.AppContext, model.AccessControlPropertyGroupName) + if appErr != nil { + c.Err = appErr + return + } + field.ID = "" + field.GroupID = group.ID + field.ObjectType = model.PropertyFieldObjectTypeUser + field.TargetType = string(model.PropertyFieldTargetLevelSystem) + field.TargetID = "" + field.Protected = false + field.CreatedBy = c.AppContext.Session().UserId + field.UpdatedBy = c.AppContext.Session().UserId + + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + connectionID := r.Header.Get(model.ConnectionId) + + createdField, appErr := c.App.CreatePropertyField(rctx, field, false, connectionID) + if appErr != nil { + c.Err = appErr return } + cpaField, convErr := model.NewCPAFieldFromPropertyField(createdField) + if convErr != nil { + c.Err = model.NewAppError("createCPAField", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(convErr) + return + } + + // Send CPA-specific websocket event for backwards compatibility + message := model.NewWebSocketEvent(model.WebsocketEventCPAFieldCreated, "", "", "", nil, "") + message.Add("field", cpaField) + c.App.Publish(message) + + auditRec.AddEventObjectType("property_field") + auditRec.AddEventResultState(cpaField) + auditRec.Success() + + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(cpaField); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func patchCPAField(c *Context, w http.ResponseWriter, r *http.Request) { c.RequireFieldId() if c.Err != nil { return } var patch *model.PropertyFieldPatch - err := json.NewDecoder(r.Body).Decode(&patch) - if err != nil || patch == nil { + if err := json.NewDecoder(r.Body).Decode(&patch); err != nil || patch == nil { c.SetInvalidParamWithErr("property_field_patch", err) return } @@ -115,11 +145,15 @@ func patchCPAField(c *Context, w http.ResponseWriter, r *http.Request) { if patch.Name != nil { *patch.Name = strings.TrimSpace(*patch.Name) } + // Target fields are server-controlled; prevent the caller from patching them. + patch.TargetID = nil + patch.TargetType = nil + if err := patch.IsValid(); err != nil { if appErr, ok := err.(*model.AppError); ok { c.Err = appErr } else { - c.Err = model.NewAppError("createCPAField", "api.custom_profile_attributes.invalid_field_patch", nil, "", http.StatusBadRequest) + c.Err = model.NewAppError("patchCPAField", "api.custom_profile_attributes.invalid_field_patch", nil, "", http.StatusBadRequest) } return } @@ -128,41 +162,86 @@ func patchCPAField(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) model.AddEventParameterAuditableToAuditRec(auditRec, "property_field_patch", patch) - rctx := app.RequestContextWithCallerID(c.AppContext, c.AppContext.Session().UserId) - originalField, appErr := c.App.GetCPAField(rctx, c.Params.FieldId) + group, appErr := c.App.GetPropertyGroup(c.AppContext, model.AccessControlPropertyGroupName) if appErr != nil { c.Err = appErr return } - auditRec.AddEventPriorState(originalField) + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) - patchedField, appErr := c.App.PatchCPAField(rctx, c.Params.FieldId, patch) + existingField, appErr := c.App.GetPropertyField(rctx, group.ID, c.Params.FieldId) if appErr != nil { c.Err = appErr return } + if existingField.ObjectType != model.PropertyFieldObjectTypeUser { + c.Err = model.NewAppError("patchCPAField", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusNotFound) + return + } + + // Permission branching (session-bound). + isOptionsOnly := isOptionsOnlyPatch(patch) + if isOptionsOnly && existingField.Type != model.PropertyFieldTypeSelect && existingField.Type != model.PropertyFieldTypeMultiselect { + isOptionsOnly = false + } + if isOptionsOnly { + if !c.App.SessionHasPermissionToManagePropertyFieldOptions(rctx, *c.AppContext.Session(), existingField) { + c.Err = model.NewAppError("patchCPAField", "api.property_field.update.no_options_permission.app_error", nil, "", http.StatusForbidden) + return + } + } else { + if !c.App.SessionHasPermissionToEditPropertyField(rctx, *c.AppContext.Session(), existingField) { + c.Err = model.NewAppError("patchCPAField", "api.property_field.update.no_field_permission.app_error", nil, "", http.StatusForbidden) + return + } + } + + // Capture original state for audit before in-place patch (Attrs is + // shallow-copied because Patch mutates it). + orig := *existingField + if existingField.Attrs != nil { + orig.Attrs = make(model.StringInterface, len(existingField.Attrs)) + maps.Copy(orig.Attrs, existingField.Attrs) + } + auditRec.AddEventPriorState(&orig) + + existingField.Patch(patch, true) + existingField.UpdatedBy = c.AppContext.Session().UserId + connectionID := r.Header.Get(model.ConnectionId) + + updatedField, clearedIDs, updateErr := c.App.UpdatePropertyField(rctx, group.ID, existingField, false, connectionID) + if updateErr != nil { + c.Err = updateErr + return + } + + cpaField, convErr := model.NewCPAFieldFromPropertyField(updatedField) + if convErr != nil { + c.Err = model.NewAppError("patchCPAField", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(convErr) + return + } + + // CPA-specific websocket event (backward compat). delete_values:true tells + // pre-PSAv2 webapp clients to clear cached values for this field; PSAv2 + // clients receive the same signal via WebsocketEventPropertyValuesUpdated + // fired by App.UpdatePropertyField. + message := model.NewWebSocketEvent(model.WebsocketEventCPAFieldUpdated, "", "", "", nil, "") + message.Add("field", cpaField) + message.Add("delete_values", len(clearedIDs) > 0) + c.App.Publish(message) + auditRec.Success() - auditRec.AddEventResultState(patchedField) + auditRec.AddEventResultState(cpaField) auditRec.AddEventObjectType("property_field") - if err := json.NewEncoder(w).Encode(patchedField); err != nil { + if err := json.NewEncoder(w).Encode(cpaField); err != nil { c.Logger.Warn("Error while writing response", mlog.Err(err)) } } func deleteCPAField(c *Context, w http.ResponseWriter, r *http.Request) { - if !c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) { - c.SetPermissionError(model.PermissionManageSystem) - return - } - - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.deleteCPAField", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) - return - } - c.RequireFieldId() if c.Err != nil { return @@ -172,54 +251,185 @@ func deleteCPAField(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) model.AddEventParameterToAuditRec(auditRec, "field_id", c.Params.FieldId) - rctx := app.RequestContextWithCallerID(c.AppContext, c.AppContext.Session().UserId) - field, appErr := c.App.GetCPAField(rctx, c.Params.FieldId) + group, appErr := c.App.GetPropertyGroup(c.AppContext, model.AccessControlPropertyGroupName) if appErr != nil { c.Err = appErr return } - auditRec.AddEventPriorState(field) - if appErr := c.App.DeleteCPAField(rctx, c.Params.FieldId); appErr != nil { + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + + existingField, appErr := c.App.GetPropertyField(rctx, group.ID, c.Params.FieldId) + if appErr != nil { c.Err = appErr return } + if existingField.ObjectType != model.PropertyFieldObjectTypeUser { + c.Err = model.NewAppError("deleteCPAField", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusNotFound) + return + } + + if !c.App.SessionHasPermissionToEditPropertyField(rctx, *c.AppContext.Session(), existingField) { + c.Err = model.NewAppError("deleteCPAField", "api.property_field.delete.no_permission.app_error", nil, "", http.StatusForbidden) + return + } + + connectionID := r.Header.Get(model.ConnectionId) + if deleteErr := c.App.DeletePropertyField(rctx, group.ID, c.Params.FieldId, false, connectionID); deleteErr != nil { + c.Err = deleteErr + return + } + + // CPA-specific websocket event (backward compat) + message := model.NewWebSocketEvent(model.WebsocketEventCPAFieldDeleted, "", "", "", nil, "") + message.Add("field_id", c.Params.FieldId) + c.App.Publish(message) + + auditRec.AddEventPriorState(existingField) auditRec.Success() - auditRec.AddEventResultState(field) + auditRec.AddEventResultState(existingField) auditRec.AddEventObjectType("property_field") ReturnStatusOK(w) } func getCPAGroup(c *Context, w http.ResponseWriter, r *http.Request) { - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.getCPAGroup", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) + // Every other CPA endpoint enforces MinimumEnterpriseLicense via the + // LicenseCheckHook on field/value operations. GetPropertyGroup is not + // hooked, so we enforce the same contract here inline. + if !model.MinimumEnterpriseLicense(c.App.License()) { + c.Err = model.NewAppError("getCPAGroup", "app.property.license_error", nil, "an Enterprise license is required", http.StatusForbidden) return } - groupID, appErr := c.App.CpaGroupID() + group, appErr := c.App.GetPropertyGroup(c.AppContext, model.AccessControlPropertyGroupName) if appErr != nil { c.Err = appErr return } - if err := json.NewEncoder(w).Encode(map[string]string{"id": groupID}); err != nil { + if err := json.NewEncoder(w).Encode(map[string]string{"id": group.ID}); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +// cpaPatchValues is the shared implementation for patchCPAValues and +// patchCPAValuesForUser. It translates the CPA request format to the generic +// property API, performs the same session-bound checks as the generic value +// patch handler (target access, batch caps, per-field permission), routes +// the upsert through App.UpsertPropertyValues, and emits the CPA-specific +// websocket event. +func cpaPatchValues(c *Context, w http.ResponseWriter, r *http.Request, userID string, updates map[string]json.RawMessage) { + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + group, appErr := c.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + c.Err = appErr + return + } + + if !hasTargetAccess(c, model.PropertyFieldObjectTypeUser, userID, true) { + return + } + + // Translate CPA format → generic PropertyValuePatchItem list. Map + // iteration is unordered, but FieldID uniqueness is guaranteed by the + // JSON object key constraint, so we cannot hit duplicate-FieldID; still, + // we keep the same shape as the generic handler for parity. + items := make([]model.PropertyValuePatchItem, 0, len(updates)) + for fieldID, value := range updates { + items = append(items, model.PropertyValuePatchItem{ + FieldID: fieldID, + Value: value, + }) + } + + if len(items) == 0 { + c.Err = model.NewAppError("cpaPatchValues", "api.property_value.patch.empty_body.app_error", nil, "", http.StatusBadRequest) + return + } + if len(items) > maxPropertyValuePatchItems { + c.Err = model.NewAppError("cpaPatchValues", "api.property_value.patch.too_many_items.request_error", map[string]any{ + "Max": maxPropertyValuePatchItems, + }, "", http.StatusBadRequest) + return + } + + fieldIDs := make([]string, 0, len(items)) + for _, item := range items { + if !model.IsValidId(item.FieldID) { + c.Err = model.NewAppError("cpaPatchValues", "api.property_value.patch.invalid_field_id.app_error", nil, "", http.StatusBadRequest) + return + } + fieldIDs = append(fieldIDs, item.FieldID) + } + + fields, fieldsErr := c.App.GetPropertyFields(rctx, group.ID, fieldIDs) + if fieldsErr != nil { + c.Err = fieldsErr + return + } + fieldByID := make(map[string]*model.PropertyField, len(fields)) + for _, f := range fields { + fieldByID[f.ID] = f + } + for _, item := range items { + f, ok := fieldByID[item.FieldID] + if !ok { + c.Err = model.NewAppError("cpaPatchValues", "api.property_value.patch.field_not_found.app_error", + map[string]any{"FieldID": item.FieldID}, "", http.StatusNotFound) + return + } + if f.ObjectType != model.PropertyFieldObjectTypeUser { + c.Err = model.NewAppError("cpaPatchValues", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusNotFound) + return + } + if !c.App.SessionHasPermissionToSetPropertyFieldValues(rctx, *c.AppContext.Session(), f) { + c.Err = model.NewAppError("cpaPatchValues", "api.property_value.patch.no_values_permission.app_error", nil, "", http.StatusForbidden) + return + } + } + + callerID := c.AppContext.Session().UserId + values := make([]*model.PropertyValue, len(items)) + for i, item := range items { + values[i] = &model.PropertyValue{ + TargetID: userID, + TargetType: model.PropertyFieldObjectTypeUser, + GroupID: group.ID, + FieldID: item.FieldID, + Value: item.Value, + CreatedBy: callerID, + UpdatedBy: callerID, + } + } + connectionID := r.Header.Get(model.ConnectionId) + + upserted, upsertErr := c.App.UpsertPropertyValues(rctx, values, model.PropertyFieldObjectTypeUser, userID, connectionID) + if upsertErr != nil { + c.Err = upsertErr + return + } + + // Translate response to CPA format: {fieldID: value} + results := make(map[string]json.RawMessage, len(upserted)) + for _, value := range upserted { + results[value.FieldID] = value.Value + } + + // CPA-specific websocket event (backward compat) + message := model.NewWebSocketEvent(model.WebsocketEventCPAValuesUpdated, "", "", "", nil, "") + message.Add("user_id", userID) + message.Add("values", results) + c.App.Publish(message) + + if err := json.NewEncoder(w).Encode(results); err != nil { c.Logger.Warn("Error while writing response", mlog.Err(err)) } } func patchCPAValues(c *Context, w http.ResponseWriter, r *http.Request) { - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.patchCPAValues", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) - return - } - userID := c.AppContext.Session().UserId - if !c.App.SessionHasPermissionToUser(*c.AppContext.Session(), userID) { - c.SetPermissionError(model.PermissionEditOtherUsers) - return - } var updates map[string]json.RawMessage if err := json.NewDecoder(r.Body).Decode(&updates); err != nil { @@ -231,72 +441,38 @@ func patchCPAValues(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) model.AddEventParameterToAuditRec(auditRec, "user_id", userID) - // if the user is not an admin, we need to check that there are no - // admin-managed fields - session := *c.AppContext.Session() - rctx := app.RequestContextWithCallerID(c.AppContext, session.UserId) - - if !c.App.SessionHasPermissionTo(session, model.PermissionManageSystem) { - fields, appErr := c.App.ListCPAFields(rctx) - if appErr != nil { - c.Err = appErr - return - } - - // Check if any of the fields being updated are admin-managed - for _, field := range fields { - if _, isBeingUpdated := updates[field.ID]; isBeingUpdated { - if field.IsAdminManaged() { - c.Err = model.NewAppError("Api4.patchCPAValues", "app.custom_profile_attributes.property_field_is_managed.app_error", nil, "", http.StatusForbidden) - return - } - } - } - } - - results := make(map[string]json.RawMessage, len(updates)) - for fieldID, rawValue := range updates { - patchedValue, appErr := c.App.PatchCPAValue(rctx, userID, fieldID, rawValue, false) - if appErr != nil { - c.Err = appErr - return - } - results[fieldID] = patchedValue.Value + cpaPatchValues(c, w, r, userID, updates) + if c.Err != nil { + return } auditRec.Success() auditRec.AddEventObjectType("patchCPAValues") - - if err := json.NewEncoder(w).Encode(results); err != nil { - c.Logger.Warn("Error while writing response", mlog.Err(err)) - } } func listCPAValues(c *Context, w http.ResponseWriter, r *http.Request) { - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.listCPAValues", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) - return - } - c.RequireUserId() if c.Err != nil { return } - targetUserID := c.Params.UserId - callerUserID := c.AppContext.Session().UserId - - // we check unrestricted sessions to allow local mode requests to go through - if !c.AppContext.Session().IsUnrestricted() { - canSee, err := c.App.UserCanSeeOtherUser(c.AppContext, callerUserID, targetUserID) - if err != nil || !canSee { - c.SetPermissionError(model.PermissionViewMembers) - return - } + if !hasTargetAccess(c, model.PropertyFieldObjectTypeUser, c.Params.UserId, false) { + return } - rctx := app.RequestContextWithCallerID(c.AppContext, callerUserID) - values, appErr := c.App.ListCPAValues(rctx, targetUserID) + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + group, appErr := c.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + c.Err = appErr + return + } + + values, appErr := c.App.SearchPropertyValues(rctx, group.ID, model.PropertyValueSearchOpts{ + TargetIDs: []string{c.Params.UserId}, + TargetType: model.PropertyValueTargetTypeUser, + // Single-target search: at most one value per (target, field), so the field cap bounds the page. + PerPage: model.AccessControlGroupFieldLimit + 5, + }) if appErr != nil { c.Err = appErr return @@ -312,23 +488,12 @@ func listCPAValues(c *Context, w http.ResponseWriter, r *http.Request) { } func patchCPAValuesForUser(c *Context, w http.ResponseWriter, r *http.Request) { - if !model.MinimumEnterpriseLicense(c.App.Channels().License()) { - c.Err = model.NewAppError("Api4.patchCPAValuesForUser", "api.custom_profile_attributes.license_error", nil, "", http.StatusForbidden) - return - } - - // Get userID from URL c.RequireUserId() if c.Err != nil { return } userID := c.Params.UserId - if !c.App.SessionHasPermissionToUser(*c.AppContext.Session(), userID) { - c.SetPermissionError(model.PermissionEditOtherUsers) - return - } - var updates map[string]json.RawMessage if err := json.NewDecoder(r.Body).Decode(&updates); err != nil { c.SetInvalidParamWithErr("value", err) @@ -339,47 +504,11 @@ func patchCPAValuesForUser(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) model.AddEventParameterToAuditRec(auditRec, "user_id", userID) - // Check for admin-managed fields - session := *c.AppContext.Session() - rctx := app.RequestContextWithCallerID(c.AppContext, session.UserId) - - isAdmin := c.App.SessionHasPermissionTo(session, model.PermissionManageSystem) - if !isAdmin { - fields, appErr := c.App.ListCPAFields(rctx) - if appErr != nil { - c.Err = appErr - return - } - - for _, field := range fields { - if _, isBeingUpdated := updates[field.ID]; !isBeingUpdated { - continue - } - // Check for admin-managed fields - if field.IsAdminManaged() { - c.Err = model.NewAppError("Api4.patchCPAValuesForUser", - "app.custom_profile_attributes.property_field_is_managed.app_error", - nil, "", - http.StatusForbidden) - return - } - } - } - - results := make(map[string]json.RawMessage, len(updates)) - for fieldID, rawValue := range updates { - patchedValue, appErr := c.App.PatchCPAValue(rctx, userID, fieldID, rawValue, false) - if appErr != nil { - c.Err = appErr - return - } - results[fieldID] = patchedValue.Value + cpaPatchValues(c, w, r, userID, updates) + if c.Err != nil { + return } auditRec.Success() auditRec.AddEventObjectType("patchCPAValues") - - if err := json.NewEncoder(w).Encode(results); err != nil { - c.Logger.Warn("Error while writing response", mlog.Err(err)) - } } diff --git a/server/channels/api4/custom_profile_attributes_test.go b/server/channels/api4/custom_profile_attributes_test.go index c6259b033fe..37869c8e603 100644 --- a/server/channels/api4/custom_profile_attributes_test.go +++ b/server/channels/api4/custom_profile_attributes_test.go @@ -16,6 +16,10 @@ import ( "github.com/stretchr/testify/require" ) +// celSafeName returns a CPA field name guaranteed to satisfy the CEL identifier +// rule the AccessControlAttributeValidationHook enforces. model.NewId() uses a base32 +// alphabet that includes digits, so a raw NewId can start with a digit and trip +// the ^[A-Za-z_]… pattern; the leading "f_" sidesteps that. func celSafeName() string { return "f_" + model.NewId() } @@ -32,7 +36,7 @@ func TestCreateCPAField(t *testing.T) { createdField, resp, err := client.CreateCPAField(context.Background(), field) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") require.Empty(t, createdField) }, "endpoint should not work if no valid license is present") @@ -116,6 +120,66 @@ func TestCreateCPAField(t *testing.T) { require.Equal(t, "admin", createdManagedField.Attrs[model.CustomProfileAttributesPropertyAttrsManaged]) require.Equal(t, "when_set", createdManagedField.Attrs["visibility"]) }, "admin should be able to create a managed field") + + t.Run("server zeroes DeleteAt even if input has a non-zero value", func(t *testing.T) { + field := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + DeleteAt: time.Now().UnixMilli(), + } + require.NotZero(t, field.DeleteAt) + + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.Zero(t, created.DeleteAt) + }) +} + +func TestCPAFieldLimit(t *testing.T) { + mainHelper.Parallel(t) + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.CustomProfileAttributes = true + }).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + // Create 20 fields — the maximum allowed by FieldLimitHook. + createdIDs := make([]string, 0, 20) + for i := 1; i <= 20; i++ { + field := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + createdIDs = append(createdIDs, created.ID) + } + + t.Run("creating a 21st field is rejected", func(t *testing.T) { + field := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + } + _, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckUnprocessableEntityStatus(t, resp) + require.Error(t, err) + }) + + t.Run("deleted fields do not count toward the limit", func(t *testing.T) { + resp, err := th.SystemAdminClient.DeleteCPAField(context.Background(), createdIDs[0]) + CheckOKStatus(t, resp) + require.NoError(t, err) + + replacement := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), replacement) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.NotZero(t, created.ID) + }) } func TestListCPAFields(t *testing.T) { @@ -124,28 +188,31 @@ func TestListCPAFields(t *testing.T) { cfg.FeatureFlags.CustomProfileAttributes = true }) - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + // License required for field creation (LicenseCheckHook) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, Attrs: map[string]any{"visibility": "when_set"}, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) t.Run("endpoint should not work if no valid license is present", func(t *testing.T) { + th.App.Srv().SetLicense(nil) + defer th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + fields, resp, err := th.Client.ListCPAFields(context.Background()) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") require.Empty(t, fields) }) - // add a valid license - th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) - t.Run("any user should be able to list fields", func(t *testing.T) { fields, resp, err := th.Client.ListCPAFields(context.Background()) CheckOKStatus(t, resp) @@ -156,7 +223,10 @@ func TestListCPAFields(t *testing.T) { }) t.Run("the endpoint should only list non deleted fields", func(t *testing.T) { - require.Nil(t, th.App.DeleteCPAField(request.TestContext(t), createdField.ID)) + resp, err := th.SystemAdminClient.DeleteCPAField(context.Background(), createdField.ID) + CheckOKStatus(t, resp) + require.NoError(t, err) + fields, resp, err := th.Client.ListCPAFields(context.Background()) CheckOKStatus(t, resp) require.NoError(t, err) @@ -171,11 +241,20 @@ func TestPatchCPAField(t *testing.T) { }) th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { - patch := &model.PropertyFieldPatch{Name: new(celSafeName())} - patchedField, resp, err := client.PatchCPAField(context.Background(), model.NewId(), patch) + // Create a field with a license so we can test the license check on patch. + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + field := &model.PropertyField{Name: celSafeName(), Type: model.PropertyFieldTypeText} + createdField, _, createErr := th.SystemAdminClient.CreateCPAField(context.Background(), field) + require.NoError(t, createErr) + require.NotNil(t, createdField) + + // Remove the license and verify patch is blocked. + th.App.Srv().SetLicense(nil) + patch := &model.PropertyFieldPatch{Name: model.NewPointer(celSafeName())} + patchedField, resp, err := client.PatchCPAField(context.Background(), createdField.ID, patch) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") require.Empty(t, patchedField) }, "endpoint should not work if no valid license is present") @@ -183,18 +262,18 @@ func TestPatchCPAField(t *testing.T) { th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) t.Run("a user without admin permissions should not be able to patch a field", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) - patch := &model.PropertyFieldPatch{Name: new(celSafeName())} - _, resp, err := th.Client.PatchCPAField(context.Background(), createdField.ID, patch) + patch := &model.PropertyFieldPatch{Name: model.NewPointer(celSafeName())} + _, resp, err = th.Client.PatchCPAField(context.Background(), createdField.ID, patch) CheckForbiddenStatus(t, resp) require.Error(t, err) }) @@ -202,18 +281,18 @@ func TestPatchCPAField(t *testing.T) { th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { webSocketClient := th.CreateConnectedWebSocketClient(t) - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) newName := celSafeName() - patch := &model.PropertyFieldPatch{Name: new(fmt.Sprintf(" %s \t ", newName))} // name should be sanitized + patch := &model.PropertyFieldPatch{Name: model.NewPointer(fmt.Sprintf(" %s \t ", newName))} // name should be sanitized patchedField, resp, err := client.PatchCPAField(context.Background(), createdField.ID, patch) CheckOKStatus(t, resp) require.NoError(t, err) @@ -239,85 +318,22 @@ func TestPatchCPAField(t *testing.T) { require.NotEmpty(t, wsField.ID) require.Equal(t, patchedField, &wsField) }) - - t.Run("sanitization should remove options and sync details when necessary", func(t *testing.T) { - // Create a select field with options - optionID1 := model.NewId() - optionID2 := model.NewId() - selectField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: celSafeName(), - Type: model.PropertyFieldTypeSelect, - Attrs: model.StringInterface{ - "options": []map[string]any{ - {"id": optionID1, "name": "Option 1", "color": "#FF0000"}, - {"id": optionID2, "name": "Option 2", "color": "#00FF00"}, - }, - }, - }) - require.NoError(t, err) - - createdField, _, err := client.CreateCPAField(context.Background(), selectField.ToPropertyField()) - require.NoError(t, err) - require.NotNil(t, createdField) - - // Verify options were created - options, ok := createdField.Attrs["options"] - require.True(t, ok) - require.NotNil(t, options) - - // Patch to change type to text with LDAP attribute - // Options should be automatically removed even though we don't explicitly remove them - ldapAttr := "user_attribute" - textPatch := &model.PropertyFieldPatch{ - Type: model.NewPointer(model.PropertyFieldTypeText), - Attrs: &model.StringInterface{"ldap": ldapAttr}, - } - - patchedTextField, resp, err := client.PatchCPAField(context.Background(), createdField.ID, textPatch) - CheckOKStatus(t, resp) - require.NoError(t, err) - require.Equal(t, model.PropertyFieldTypeText, patchedTextField.Type) - - // Verify options were removed - options = patchedTextField.Attrs["options"] - require.Empty(t, options) - - // Verify LDAP attribute was set - ldap, ok := patchedTextField.Attrs["ldap"] - require.True(t, ok) - require.Equal(t, ldapAttr, ldap) - - // Now patch to change type to date - // LDAP attribute should be automatically removed even though we don't explicitly remove it - datePatch := &model.PropertyFieldPatch{ - Type: model.NewPointer(model.PropertyFieldTypeDate), - } - - patchedDateField, resp, err := client.PatchCPAField(context.Background(), patchedTextField.ID, datePatch) - CheckOKStatus(t, resp) - require.NoError(t, err) - require.Equal(t, model.PropertyFieldTypeDate, patchedDateField.Type) - - // Verify LDAP attribute was removed - ldap = patchedDateField.Attrs["ldap"] - require.Empty(t, ldap) - }) }, "a user with admin permissions should be able to patch the field") th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { // Create a regular field first - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) // Verify field is not isManaged initially - require.Empty(t, createdField.Attrs.Managed) + require.Empty(t, createdField.Attrs[model.CustomProfileAttributesPropertyAttrsManaged]) // Patch to make it managed managedPatch := &model.PropertyFieldPatch{ @@ -345,6 +361,171 @@ func TestPatchCPAField(t *testing.T) { // Verify managed attribute is removed or empty require.Empty(t, patchedUnmanagedField.Attrs[model.CustomProfileAttributesPropertyAttrsManaged]) }, "admin should be able to toggle managed attribute on existing field") + + t.Run("patching select options preserves existing option IDs and assigns new IDs to added options", func(t *testing.T) { + selectField := &model.PropertyField{ + Name: "select_field_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"name": "Option 1", "color": "#111111"}, + map[string]any{"name": "Option 2", "color": "#222222"}, + }, + }, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), selectField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + + createdCPA, err := model.NewCPAFieldFromPropertyField(created) + require.NoError(t, err) + require.Len(t, createdCPA.Attrs.Options, 2) + id1 := createdCPA.Attrs.Options[0].ID + id2 := createdCPA.Attrs.Options[1].ID + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + + patch := &model.PropertyFieldPatch{ + Attrs: model.NewPointer(model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": id1, "name": "Updated Option 1", "color": "#333333"}, + map[string]any{"name": "New Option 1.5", "color": "#353535"}, + map[string]any{"id": id2, "name": "Updated Option 2", "color": "#444444"}, + }, + }), + } + patched, resp, err := th.SystemAdminClient.PatchCPAField(context.Background(), created.ID, patch) + CheckOKStatus(t, resp) + require.NoError(t, err) + + patchedCPA, err := model.NewCPAFieldFromPropertyField(patched) + require.NoError(t, err) + require.Len(t, patchedCPA.Attrs.Options, 3) + + require.Equal(t, id1, patchedCPA.Attrs.Options[0].ID) + require.Equal(t, "Updated Option 1", patchedCPA.Attrs.Options[0].Name) + require.Equal(t, "#333333", patchedCPA.Attrs.Options[0].Color) + require.NotEmpty(t, patchedCPA.Attrs.Options[1].ID) + require.Equal(t, "New Option 1.5", patchedCPA.Attrs.Options[1].Name) + require.Equal(t, id2, patchedCPA.Attrs.Options[2].ID) + require.Equal(t, "Updated Option 2", patchedCPA.Attrs.Options[2].Name) + }) + + t.Run("changing a field's type deletes dependent values and emits delete_values:true", func(t *testing.T) { + webSocketClient := th.CreateConnectedWebSocketClient(t) + + selectField := &model.PropertyField{ + Name: "select_type_change_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"name": "Option 1", "color": "#FF5733"}, + }, + }, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), selectField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + + createdCPA, err := model.NewCPAFieldFromPropertyField(created) + require.NoError(t, err) + require.NotEmpty(t, createdCPA.Attrs.Options) + optionID := createdCPA.Attrs.Options[0].ID + require.NotEmpty(t, optionID) + + // Seed a value for BasicUser referencing the option. + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + created.ID: json.RawMessage(fmt.Sprintf(`"%s"`, optionID)), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) + + // Patch type from select → text. + typePatch := &model.PropertyFieldPatch{Type: model.NewPointer(model.PropertyFieldTypeText)} + _, resp, err = th.SystemAdminClient.PatchCPAField(context.Background(), created.ID, typePatch) + CheckOKStatus(t, resp) + require.NoError(t, err) + + // The dependent value must be gone. + retrieved, resp, err := th.SystemAdminClient.ListCPAValues(context.Background(), th.BasicUser.Id) + CheckOKStatus(t, resp) + require.NoError(t, err) + _, present := retrieved[created.ID] + require.False(t, present, "value should be deleted when the field's type changes") + + // The legacy CPA WS event must carry delete_values:true. + var sawDeleteValues bool + require.Eventually(t, func() bool { + for { + select { + case event := <-webSocketClient.EventChannel: + if event.EventType() != model.WebsocketEventCPAFieldUpdated { + continue + } + if dv, ok := event.GetData()["delete_values"].(bool); ok && dv { + sawDeleteValues = true + return true + } + default: + return false + } + } + }, 5*time.Second, 100*time.Millisecond) + require.True(t, sawDeleteValues, "expected cpa_field_updated to carry delete_values:true on a type change") + }) + + t.Run("patching a field without changing its type preserves existing values", func(t *testing.T) { + selectField := &model.PropertyField{ + Name: "select_with_values_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"name": "Option 1", "color": "#FF5733"}, + map[string]any{"name": "Option 2", "color": "#33FF57"}, + }, + }, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), selectField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + + createdCPA, err := model.NewCPAFieldFromPropertyField(created) + require.NoError(t, err) + require.NotEmpty(t, createdCPA.Attrs.Options) + optionID := createdCPA.Attrs.Options[0].ID + require.NotEmpty(t, optionID) + + // Admin writes a value on behalf of BasicUser. + values := map[string]json.RawMessage{ + created.ID: json.RawMessage(fmt.Sprintf(`"%s"`, optionID)), + } + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) + CheckOKStatus(t, resp) + require.NoError(t, err) + + // Rename field + add an option, keeping Type unchanged. + patch := &model.PropertyFieldPatch{ + Name: model.NewPointer("renamed_" + model.NewId()), + Attrs: model.NewPointer(model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID, "name": "Renamed Option 1", "color": "#FF5733"}, + map[string]any{"name": "Option 2", "color": "#33FF57"}, + map[string]any{"name": "Option 3", "color": "#5733FF"}, + }, + }), + } + _, resp, err = th.SystemAdminClient.PatchCPAField(context.Background(), created.ID, patch) + CheckOKStatus(t, resp) + require.NoError(t, err) + + // BasicUser's value for this field should still be present. + retrieved, resp, err := th.SystemAdminClient.ListCPAValues(context.Background(), th.BasicUser.Id) + CheckOKStatus(t, resp) + require.NoError(t, err) + rawValue, ok := retrieved[created.ID] + require.True(t, ok, "value should still exist after a non-type-changing patch") + require.Equal(t, json.RawMessage(fmt.Sprintf(`"%s"`, optionID)), rawValue) + }) } func TestDeleteCPAField(t *testing.T) { @@ -354,10 +535,19 @@ func TestDeleteCPAField(t *testing.T) { }) th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { - resp, err := client.DeleteCPAField(context.Background(), model.NewId()) + // Create a field with a license so we can test the license check on delete. + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + field := &model.PropertyField{Name: celSafeName(), Type: model.PropertyFieldTypeText} + createdField, _, createErr := th.SystemAdminClient.CreateCPAField(context.Background(), field) + require.NoError(t, createErr) + require.NotNil(t, createdField) + + // Remove the license and verify delete is blocked. + th.App.Srv().SetLicense(nil) + resp, err := client.DeleteCPAField(context.Background(), createdField.ID) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") }, "endpoint should not work if no valid license is present") // add a valid license @@ -393,7 +583,12 @@ func TestDeleteCPAField(t *testing.T) { CheckOKStatus(t, resp) require.NoError(t, err) - deletedField, appErr := th.App.GetCPAField(request.TestContext(t), createdField.ID) + // The list endpoint filters out deleted fields, so read at the app layer + // to confirm the soft-delete landed on the record itself. + rctx := request.TestContext(t) + group, appErr := th.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + require.Nil(t, appErr) + deletedField, appErr := th.App.GetPropertyField(rctx, group.ID, createdField.ID) require.Nil(t, appErr) require.NotZero(t, deletedField.DeleteAt) @@ -426,33 +621,39 @@ func TestListCPAValues(t *testing.T) { cfg.FeatureFlags.CustomProfileAttributes = true }).InitBasic(t) + // License required for field/value creation (LicenseCheckHook) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + th.RemovePermissionFromRole(t, model.PermissionViewMembers.Id, model.SystemUserRoleId) defer th.AddPermissionToRole(t, model.PermissionViewMembers.Id, model.SystemUserRoleId) - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) - _, appErr = th.App.PatchCPAValue(request.TestContext(t), th.BasicUser.Id, createdField.ID, json.RawMessage(`"Field Value"`), true) - require.Nil(t, appErr) + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + createdField.ID: json.RawMessage(`"Field Value"`), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) t.Run("endpoint should not work if no valid license is present", func(t *testing.T) { + th.App.Srv().SetLicense(nil) + defer th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + values, resp, err := th.Client.ListCPAValues(context.Background(), th.BasicUser.Id) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") require.Empty(t, values) }) - // add a valid license - th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) - // login with Client2 from this point on th.LoginBasic2(t) @@ -467,7 +668,7 @@ func TestListCPAValues(t *testing.T) { t.Run("should handle array values correctly", func(t *testing.T) { optionID1 := model.NewId() optionID2 := model.NewId() - arrayField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + arrayField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeMultiselect, Attrs: model.StringInterface{ @@ -476,15 +677,18 @@ func TestListCPAValues(t *testing.T) { {"id": optionID2, "name": "option2"}, }, }, - }) - require.NoError(t, err) + } - createdArrayField, appErr := th.App.CreateCPAField(request.TestContext(t), arrayField) - require.Nil(t, appErr) + createdArrayField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), arrayField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdArrayField) - _, appErr = th.App.PatchCPAValue(request.TestContext(t), th.BasicUser.Id, createdArrayField.ID, json.RawMessage(fmt.Sprintf(`["%s", "%s"]`, optionID1, optionID2)), true) - require.Nil(t, appErr) + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + createdArrayField.ID: json.RawMessage(fmt.Sprintf(`["%s", "%s"]`, optionID1, optionID2)), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) values, resp, err := th.Client.ListCPAValues(context.Background(), th.BasicUser.Id) CheckOKStatus(t, resp) @@ -514,28 +718,31 @@ func TestPatchCPAValues(t *testing.T) { cfg.FeatureFlags.CustomProfileAttributes = true }).InitBasic(t) - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + // License required for field creation (LicenseCheckHook) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) t.Run("endpoint should not work if no valid license is present", func(t *testing.T) { + th.App.Srv().SetLicense(nil) + defer th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + values := map[string]json.RawMessage{createdField.ID: json.RawMessage(`"Field Value"`)} patchedValues, resp, err := th.Client.PatchCPAValues(context.Background(), values) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") require.Empty(t, patchedValues) }) - // add a valid license - th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) - t.Run("any team member should be able to create their own values", func(t *testing.T) { webSocketClient := th.CreateConnectedWebSocketClient(t) @@ -609,7 +816,7 @@ func TestPatchCPAValues(t *testing.T) { t.Run("should handle array values correctly", func(t *testing.T) { optionsID := []string{model.NewId(), model.NewId(), model.NewId(), model.NewId()} - arrayField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + arrayField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeMultiselect, Attrs: model.StringInterface{ @@ -620,11 +827,11 @@ func TestPatchCPAValues(t *testing.T) { {"id": optionsID[3], "name": "option4"}, }, }, - }) - require.NoError(t, err) + } - createdArrayField, appErr := th.App.CreateCPAField(request.TestContext(t), arrayField) - require.Nil(t, appErr) + createdArrayField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), arrayField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdArrayField) values := map[string]json.RawMessage{ @@ -652,50 +859,50 @@ func TestPatchCPAValues(t *testing.T) { t.Run("should fail if any of the values belongs to a field that is LDAP/SAML synced", func(t *testing.T) { // Create a field with LDAP attribute - ldapField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + ldapField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{ model.CustomProfileAttributesPropertyAttrsLDAP: "ldap_attr", }, - }) - require.NoError(t, err) + } - createdLDAPField, appErr := th.App.CreateCPAField(request.TestContext(t), ldapField) - require.Nil(t, appErr) + createdLDAPField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), ldapField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdLDAPField) // Create a field with SAML attribute - samlField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + samlField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{ model.CustomProfileAttributesPropertyAttrsSAML: "saml_attr", }, - }) - require.NoError(t, err) + } - createdSAMLField, appErr := th.App.CreateCPAField(request.TestContext(t), samlField) - require.Nil(t, appErr) + createdSAMLField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), samlField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdSAMLField) // Test LDAP field values := map[string]json.RawMessage{ createdLDAPField.ID: json.RawMessage(`"LDAP Value"`), } - _, resp, err := th.Client.PatchCPAValues(context.Background(), values) - CheckBadRequestStatus(t, resp) + _, resp, err = th.Client.PatchCPAValues(context.Background(), values) + CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_synced.app_error") + CheckErrorID(t, err, "app.property.sync_lock.app_error") // Test SAML field values = map[string]json.RawMessage{ createdSAMLField.ID: json.RawMessage(`"SAML Value"`), } _, resp, err = th.Client.PatchCPAValues(context.Background(), values) - CheckBadRequestStatus(t, resp) + CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_synced.app_error") + CheckErrorID(t, err, "app.property.sync_lock.app_error") // Test multiple fields with one being LDAP synced values = map[string]json.RawMessage{ @@ -703,20 +910,20 @@ func TestPatchCPAValues(t *testing.T) { createdLDAPField.ID: json.RawMessage(`"LDAP Value"`), } _, resp, err = th.Client.PatchCPAValues(context.Background(), values) - CheckBadRequestStatus(t, resp) + CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_synced.app_error") + CheckErrorID(t, err, "app.property.sync_lock.app_error") }) t.Run("an invalid patch should be rejected", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) // Create a value that's too long (over 64 characters) @@ -725,16 +932,16 @@ func TestPatchCPAValues(t *testing.T) { createdField.ID: json.RawMessage(fmt.Sprintf(`"%s"`, tooLongValue)), } - _, resp, err := th.Client.PatchCPAValues(context.Background(), values) + _, resp, err = th.Client.PatchCPAValues(context.Background(), values) CheckBadRequestStatus(t, resp) require.Error(t, err) - require.Contains(t, err.Error(), "Failed to validate property value") + CheckErrorID(t, err, "app.property_value.validate.app_error") }) t.Run("admin-managed fields", func(t *testing.T) { // Create a managed field (only admins can create fields) managedField := &model.PropertyField{ - Name: "managed_field", + Name: "managed_field_" + model.NewId(), Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{ model.CustomProfileAttributesPropertyAttrsManaged: "admin", @@ -748,7 +955,7 @@ func TestPatchCPAValues(t *testing.T) { // Create a non-managed field for comparison regularField := &model.PropertyField{ - Name: "regular_field", + Name: "regular_field_" + model.NewId(), Type: model.PropertyFieldTypeText, } @@ -765,7 +972,7 @@ func TestPatchCPAValues(t *testing.T) { _, resp, err := th.Client.PatchCPAValues(context.Background(), values) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_managed.app_error") + CheckErrorID(t, err, "api.property_value.patch.no_values_permission.app_error") }) t.Run("regular user can update non-managed field", func(t *testing.T) { @@ -799,13 +1006,18 @@ func TestPatchCPAValues(t *testing.T) { }) t.Run("batch update with managed fields fails for regular user", func(t *testing.T) { - // First set some initial values to ensure we can verify they don't change - // Set initial values for both fields using th.App (admins can set managed field values) - _, appErr := th.App.PatchCPAValue(request.TestContext(t), th.BasicUser.Id, createdRegularField.ID, json.RawMessage(`"Initial Regular Value"`), false) - require.Nil(t, appErr) + // Admin seeds initial values for both fields on BasicUser. + _, resp, err := th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + createdRegularField.ID: json.RawMessage(`"Initial Regular Value"`), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) - _, appErr = th.App.PatchCPAValue(request.TestContext(t), th.BasicUser.Id, createdManagedField.ID, json.RawMessage(`"Initial Managed Value"`), true) - require.Nil(t, appErr) + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + createdManagedField.ID: json.RawMessage(`"Initial Managed Value"`), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) // Try to batch update both managed and regular fields - this should fail attemptedValues := map[string]json.RawMessage{ @@ -813,43 +1025,21 @@ func TestPatchCPAValues(t *testing.T) { createdRegularField.ID: json.RawMessage(`"Regular Batch Value"`), } - _, resp, err := th.Client.PatchCPAValues(context.Background(), attemptedValues) + _, resp, err = th.Client.PatchCPAValues(context.Background(), attemptedValues) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_managed.app_error") + CheckErrorID(t, err, "api.property_value.patch.no_values_permission.app_error") - // Verify that no values were updated when the batch operation failed - currentValues, appErr := th.App.ListCPAValues(request.TestContext(t), th.BasicUser.Id) - require.Nil(t, appErr) + // Verify that no values were updated when the batch operation failed. + currentValues, resp, err := th.SystemAdminClient.ListCPAValues(context.Background(), th.BasicUser.Id) + CheckOKStatus(t, resp) + require.NoError(t, err) - // Check that values remain unchanged - both fields should retain their initial values - regularFieldHasOriginalValue := false - managedFieldHasOriginalValue := false - - for _, value := range currentValues { - if value.FieldID == createdManagedField.ID { - var currentValue string - require.NoError(t, json.Unmarshal(value.Value, ¤tValue)) - if currentValue == "Initial Managed Value" { - managedFieldHasOriginalValue = true - } - // Verify it's not the attempted update value - require.NotEqual(t, "Managed Batch Value", currentValue, "Managed field should not have been updated in failed batch operation") - } - if value.FieldID == createdRegularField.ID { - var currentValue string - require.NoError(t, json.Unmarshal(value.Value, ¤tValue)) - if currentValue == "Initial Regular Value" { - regularFieldHasOriginalValue = true - } - // Verify it's not the attempted update value - require.NotEqual(t, "Regular Batch Value", currentValue, "Regular field should not have been updated in failed batch operation") - } - } - - // Both fields should retain their original values after the failed batch operation - require.True(t, regularFieldHasOriginalValue, "Regular field should retain its original value") - require.True(t, managedFieldHasOriginalValue, "Managed field should retain its original value") + var managedValue, regularValue string + require.NoError(t, json.Unmarshal(currentValues[createdManagedField.ID], &managedValue)) + require.NoError(t, json.Unmarshal(currentValues[createdRegularField.ID], ®ularValue)) + require.Equal(t, "Initial Managed Value", managedValue, "Managed field should not have been updated in failed batch operation") + require.Equal(t, "Initial Regular Value", regularValue, "Regular field should not have been updated in failed batch operation") }) t.Run("batch update with managed fields succeeds for admin", func(t *testing.T) { @@ -870,6 +1060,59 @@ func TestPatchCPAValues(t *testing.T) { require.Equal(t, "Admin Regular Batch", regularValue) }) }) + + t.Run("patch fails if any field in the map does not exist", func(t *testing.T) { + // App.GetPropertyFields rejects an unknown id with a 404 before the + // handler's per-field 404 check runs. The property service wraps the + // store's ErrResultsMismatch with the ErrFieldNotFound sentinel, + // which mapPropertyServiceError translates into a not-found error. + values := map[string]json.RawMessage{ + model.NewId(): json.RawMessage(`"any value"`), + } + _, resp, err := th.Client.PatchCPAValues(context.Background(), values) + CheckNotFoundStatus(t, resp) + require.Error(t, err) + CheckErrorID(t, err, "app.property_field.not_found.app_error") + }) + + t.Run("rejects values that fail hook validation", func(t *testing.T) { + optionsID := []string{model.NewId(), model.NewId(), model.NewId()} + arrayField := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeMultiselect, + Attrs: model.StringInterface{ + "options": []map[string]any{ + {"id": optionsID[0], "name": "option1"}, + {"id": optionsID[1], "name": "option2"}, + {"id": optionsID[2], "name": "option3"}, + }, + }, + } + + createdArrayField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), arrayField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.NotNil(t, createdArrayField) + + t.Run("invalid option ID", func(t *testing.T) { + unknownOption := model.NewId() + values := map[string]json.RawMessage{ + createdArrayField.ID: json.RawMessage(fmt.Sprintf(`["%s", "%s"]`, optionsID[0], unknownOption)), + } + _, resp, err := th.Client.PatchCPAValues(context.Background(), values) + CheckBadRequestStatus(t, resp) + require.Error(t, err) + }) + + t.Run("wrong value type (string instead of array)", func(t *testing.T) { + values := map[string]json.RawMessage{ + createdArrayField.ID: json.RawMessage(`"not an array"`), + } + _, resp, err := th.Client.PatchCPAValues(context.Background(), values) + CheckBadRequestStatus(t, resp) + require.Error(t, err) + }) + }) } func TestPatchCPAValuesForUser(t *testing.T) { @@ -879,22 +1122,28 @@ func TestPatchCPAValuesForUser(t *testing.T) { cfg.FeatureFlags.CustomProfileAttributes = true }).InitBasic(t) - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + // License required for field creation (LicenseCheckHook) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) t.Run("endpoint should not work if no valid license is present", func(t *testing.T) { + th.App.Srv().SetLicense(nil) + defer th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + values := map[string]json.RawMessage{createdField.ID: json.RawMessage(`"Field Value"`)} patchedValues, resp, err := th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "api.custom_profile_attributes.license_error") + CheckErrorID(t, err, "app.property.license_error") require.Empty(t, patchedValues) }) @@ -974,7 +1223,7 @@ func TestPatchCPAValuesForUser(t *testing.T) { t.Run("should handle array values correctly", func(t *testing.T) { optionsID := []string{model.NewId(), model.NewId(), model.NewId(), model.NewId()} - arrayField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + arrayField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeMultiselect, Attrs: model.StringInterface{ @@ -985,11 +1234,11 @@ func TestPatchCPAValuesForUser(t *testing.T) { {"id": optionsID[3], "name": "option4"}, }, }, - }) - require.NoError(t, err) + } - createdArrayField, appErr := th.App.CreateCPAField(request.TestContext(t), arrayField) - require.Nil(t, appErr) + createdArrayField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), arrayField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdArrayField) values := map[string]json.RawMessage{ @@ -1017,50 +1266,50 @@ func TestPatchCPAValuesForUser(t *testing.T) { t.Run("should fail if any of the values belongs to a field that is LDAP/SAML synced", func(t *testing.T) { // Create a field with LDAP attribute - ldapField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + ldapField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{ model.CustomProfileAttributesPropertyAttrsLDAP: "ldap_attr", }, - }) - require.NoError(t, err) + } - createdLDAPField, appErr := th.App.CreateCPAField(request.TestContext(t), ldapField) - require.Nil(t, appErr) + createdLDAPField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), ldapField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdLDAPField) // Create a field with SAML attribute - samlField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + samlField := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{ model.CustomProfileAttributesPropertyAttrsSAML: "saml_attr", }, - }) - require.NoError(t, err) + } - createdSAMLField, appErr := th.App.CreateCPAField(request.TestContext(t), samlField) - require.Nil(t, appErr) + createdSAMLField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), samlField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdSAMLField) // Test LDAP field values := map[string]json.RawMessage{ createdLDAPField.ID: json.RawMessage(`"LDAP Value"`), } - _, resp, err := th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) - CheckBadRequestStatus(t, resp) + _, resp, err = th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) + CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_synced.app_error") + CheckErrorID(t, err, "app.property.sync_lock.app_error") // Test SAML field values = map[string]json.RawMessage{ createdSAMLField.ID: json.RawMessage(`"SAML Value"`), } _, resp, err = th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) - CheckBadRequestStatus(t, resp) + CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_synced.app_error") + CheckErrorID(t, err, "app.property.sync_lock.app_error") // Test multiple fields with one being LDAP synced values = map[string]json.RawMessage{ @@ -1068,20 +1317,20 @@ func TestPatchCPAValuesForUser(t *testing.T) { createdLDAPField.ID: json.RawMessage(`"LDAP Value"`), } _, resp, err = th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) - CheckBadRequestStatus(t, resp) + CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_synced.app_error") + CheckErrorID(t, err, "app.property.sync_lock.app_error") }) t.Run("an invalid patch should be rejected", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ + field := &model.PropertyField{ Name: celSafeName(), Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) + } - createdField, appErr := th.App.CreateCPAField(request.TestContext(t), field) - require.Nil(t, appErr) + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) require.NotNil(t, createdField) // Create a value that's too long (over 64 characters) @@ -1090,16 +1339,16 @@ func TestPatchCPAValuesForUser(t *testing.T) { createdField.ID: json.RawMessage(fmt.Sprintf(`"%s"`, tooLongValue)), } - _, resp, err := th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) + _, resp, err = th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) CheckBadRequestStatus(t, resp) require.Error(t, err) - require.Contains(t, err.Error(), "Failed to validate property value") + CheckErrorID(t, err, "app.property_value.validate.app_error") }) t.Run("admin-managed fields", func(t *testing.T) { // Create a managed field (only admins can create fields) managedField := &model.PropertyField{ - Name: "managed_field_v2", + Name: "managed_field_" + model.NewId(), Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{ model.CustomProfileAttributesPropertyAttrsManaged: "admin", @@ -1113,7 +1362,7 @@ func TestPatchCPAValuesForUser(t *testing.T) { // Create a non-managed field for comparison regularField := &model.PropertyField{ - Name: "regular_field_v2", + Name: "regular_field_" + model.NewId(), Type: model.PropertyFieldTypeText, } @@ -1130,7 +1379,7 @@ func TestPatchCPAValuesForUser(t *testing.T) { _, resp, err := th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, values) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_managed.app_error") + CheckErrorID(t, err, "api.property_value.patch.no_values_permission.app_error") }) t.Run("regular user can update non-managed field", func(t *testing.T) { @@ -1149,9 +1398,12 @@ func TestPatchCPAValuesForUser(t *testing.T) { }) th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { - // Set initial value through the app layer that we will be replacing during the test - _, appErr := th.App.PatchCPAValue(request.TestContext(t), th.SystemAdminUser.Id, createdManagedField.ID, json.RawMessage(`"Initial Admin Value"`), true) - require.Nil(t, appErr) + // Seed a baseline value that the test run then replaces. + _, resp, err := th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.SystemAdminUser.Id, map[string]json.RawMessage{ + createdManagedField.ID: json.RawMessage(`"Initial Admin Value"`), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) values := map[string]json.RawMessage{ createdManagedField.ID: json.RawMessage(`"Admin Updated Value"`), @@ -1205,13 +1457,18 @@ func TestPatchCPAValuesForUser(t *testing.T) { }) t.Run("batch update with managed fields fails for regular user", func(t *testing.T) { - // First set some initial values to ensure we can verify they don't change - // Set initial values for both fields using th.App (admins can set managed field values) - _, appErr := th.App.PatchCPAValue(request.TestContext(t), th.BasicUser.Id, createdRegularField.ID, json.RawMessage(`"Initial Regular Value"`), false) - require.Nil(t, appErr) + // Admin seeds initial values for both fields on BasicUser. + _, resp, err := th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + createdRegularField.ID: json.RawMessage(`"Initial Regular Value"`), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) - _, appErr = th.App.PatchCPAValue(request.TestContext(t), th.BasicUser.Id, createdManagedField.ID, json.RawMessage(`"Initial Managed Value"`), true) - require.Nil(t, appErr) + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, map[string]json.RawMessage{ + createdManagedField.ID: json.RawMessage(`"Initial Managed Value"`), + }) + CheckOKStatus(t, resp) + require.NoError(t, err) // Try to batch update both managed and regular fields - this should fail attemptedValues := map[string]json.RawMessage{ @@ -1219,43 +1476,21 @@ func TestPatchCPAValuesForUser(t *testing.T) { createdRegularField.ID: json.RawMessage(`"Regular Batch Value"`), } - _, resp, err := th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, attemptedValues) + _, resp, err = th.Client.PatchCPAValuesForUser(context.Background(), th.BasicUser.Id, attemptedValues) CheckForbiddenStatus(t, resp) require.Error(t, err) - CheckErrorID(t, err, "app.custom_profile_attributes.property_field_is_managed.app_error") + CheckErrorID(t, err, "api.property_value.patch.no_values_permission.app_error") - // Verify that no values were updated when the batch operation failed - currentValues, appErr := th.App.ListCPAValues(request.TestContext(t), th.BasicUser.Id) - require.Nil(t, appErr) + // Verify that no values were updated when the batch operation failed. + currentValues, resp, err := th.SystemAdminClient.ListCPAValues(context.Background(), th.BasicUser.Id) + CheckOKStatus(t, resp) + require.NoError(t, err) - // Check that values remain unchanged - both fields should retain their initial values - regularFieldHasOriginalValue := false - managedFieldHasOriginalValue := false - - for _, value := range currentValues { - if value.FieldID == createdManagedField.ID { - var currentValue string - require.NoError(t, json.Unmarshal(value.Value, ¤tValue)) - if currentValue == "Initial Managed Value" { - managedFieldHasOriginalValue = true - } - // Verify it's not the attempted update value - require.NotEqual(t, "Managed Batch Value", currentValue, "Managed field should not have been updated in failed batch operation") - } - if value.FieldID == createdRegularField.ID { - var currentValue string - require.NoError(t, json.Unmarshal(value.Value, ¤tValue)) - if currentValue == "Initial Regular Value" { - regularFieldHasOriginalValue = true - } - // Verify it's not the attempted update value - require.NotEqual(t, "Regular Batch Value", currentValue, "Regular field should not have been updated in failed batch operation") - } - } - - // Both fields should retain their original values after the failed batch operation - require.True(t, regularFieldHasOriginalValue, "Regular field should retain its original value") - require.True(t, managedFieldHasOriginalValue, "Managed field should retain its original value") + var managedValue, regularValue string + require.NoError(t, json.Unmarshal(currentValues[createdManagedField.ID], &managedValue)) + require.NoError(t, json.Unmarshal(currentValues[createdRegularField.ID], ®ularValue)) + require.Equal(t, "Initial Managed Value", managedValue, "Managed field should not have been updated in failed batch operation") + require.Equal(t, "Initial Regular Value", regularValue, "Regular field should not have been updated in failed batch operation") }) th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { @@ -1277,3 +1512,346 @@ func TestPatchCPAValuesForUser(t *testing.T) { }, "batch update with managed fields succeeds for admin") }) } + +// TestCPANonAdminWriteOwnValueViaGenericAPI confirms a non-admin user can set +// their own value on a regular CPA field via the generic property API. +func TestCPANonAdminWriteOwnValueViaGenericAPI(t *testing.T) { + mainHelper.Parallel(t) + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.CustomProfileAttributes = true + cfg.FeatureFlags.IntegratedBoards = true + }).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + field := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + } + createdField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.NotNil(t, createdField) + + value := "Self Value" + items := []model.PropertyValuePatchItem{{ + FieldID: createdField.ID, + Value: json.RawMessage(fmt.Sprintf(`%q`, value)), + }} + + upserted, resp, err := th.Client.PatchPropertyValues( + context.Background(), + model.AccessControlPropertyGroupName, + model.PropertyFieldObjectTypeUser, + th.BasicUser.Id, + items, + ) + CheckOKStatus(t, resp) + require.NoError(t, err) + require.Len(t, upserted, 1) + require.Equal(t, createdField.ID, upserted[0].FieldID) + require.Equal(t, th.BasicUser.Id, upserted[0].TargetID) + require.Equal(t, model.PropertyValueTargetTypeUser, upserted[0].TargetType) + + var actualValue string + require.NoError(t, json.Unmarshal(upserted[0].Value, &actualValue)) + require.Equal(t, value, actualValue) + + // Verify the write persisted via a generic-API read on the same target. + stored, resp, err := th.Client.GetPropertyValues( + context.Background(), + model.AccessControlPropertyGroupName, + model.PropertyFieldObjectTypeUser, + th.BasicUser.Id, + model.PropertyValueSearch{PerPage: 60}, + ) + CheckOKStatus(t, resp) + require.NoError(t, err) + require.Len(t, stored, 1) + require.Equal(t, createdField.ID, stored[0].FieldID) + + var readValue string + require.NoError(t, json.Unmarshal(stored[0].Value, &readValue)) + require.Equal(t, value, readValue) +} + +// TestCPANonAdminBlockedFromAdminManagedViaGenericAPI confirms a non-admin user +// is blocked from writing their own value on an admin-only CPA field via the +// generic property API. +func TestCPANonAdminBlockedFromAdminManagedViaGenericAPI(t *testing.T) { + mainHelper.Parallel(t) + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.CustomProfileAttributes = true + cfg.FeatureFlags.IntegratedBoards = true + }).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + managedField := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + } + createdManagedField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), managedField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.NotNil(t, createdManagedField) + + items := []model.PropertyValuePatchItem{{ + FieldID: createdManagedField.ID, + Value: json.RawMessage(`"Non-Admin Value"`), + }} + + t.Run("non-admin writing own admin-managed value is forbidden", func(t *testing.T) { + _, resp, err := th.Client.PatchPropertyValues( + context.Background(), + model.AccessControlPropertyGroupName, + model.PropertyFieldObjectTypeUser, + th.BasicUser.Id, + items, + ) + CheckForbiddenStatus(t, resp) + require.Error(t, err) + CheckErrorID(t, err, "api.property_value.patch.no_values_permission.app_error") + }) + + t.Run("admin writing same admin-managed value succeeds", func(t *testing.T) { + adminItems := []model.PropertyValuePatchItem{{ + FieldID: createdManagedField.ID, + Value: json.RawMessage(`"Admin Value"`), + }} + upserted, resp, err := th.SystemAdminClient.PatchPropertyValues( + context.Background(), + model.AccessControlPropertyGroupName, + model.PropertyFieldObjectTypeUser, + th.BasicUser.Id, + adminItems, + ) + CheckOKStatus(t, resp) + require.NoError(t, err) + require.Len(t, upserted, 1) + + var actualValue string + require.NoError(t, json.Unmarshal(upserted[0].Value, &actualValue)) + require.Equal(t, "Admin Value", actualValue) + }) +} + +// TestCPACrossAPIFieldRoundtrip verifies that a CPA field created via one +// API surface reads back equivalently from the other. We deliberately do +// not do a full map-equality on Attrs: ToPropertyField packs empty-string +// defaults for every CPA-known key, so CPA→generic→CPA is lossy at the +// map level. Compare the explicit set of fields that should match instead. +func TestCPACrossAPIFieldRoundtrip(t *testing.T) { + mainHelper.Parallel(t) + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.CustomProfileAttributes = true + cfg.FeatureFlags.IntegratedBoards = true + }).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + t.Run("create via CPA API, read via generic API", func(t *testing.T) { + name := celSafeName() + field := &model.PropertyField{ + Name: name, + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsValueType: model.CustomProfileAttributesValueTypeEmail, + model.CustomProfileAttributesPropertyAttrsSortOrder: 5, + model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityWhenSet, + }, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.NotNil(t, created) + + listed, resp, err := th.SystemAdminClient.GetPropertyFields( + context.Background(), + model.AccessControlPropertyGroupName, + model.PropertyFieldObjectTypeUser, + model.PropertyFieldSearch{ + TargetType: string(model.PropertyFieldTargetLevelSystem), + PerPage: 60, + }, + ) + CheckOKStatus(t, resp) + require.NoError(t, err) + + var found *model.PropertyField + for _, pf := range listed { + if pf.ID == created.ID { + found = pf + break + } + } + require.NotNil(t, found, "field created via CPA API should be readable via generic API") + + require.Equal(t, created.ID, found.ID) + require.Equal(t, name, found.Name) + require.Equal(t, created.Type, found.Type) + require.Equal(t, created.GroupID, found.GroupID) + require.Equal(t, model.PropertyFieldObjectTypeUser, found.ObjectType) + require.Equal(t, string(model.PropertyFieldTargetLevelSystem), found.TargetType) + require.Empty(t, found.TargetID) + require.Equal(t, created.CreatedBy, found.CreatedBy) + require.Equal(t, created.CreateAt, found.CreateAt) + require.Equal(t, int64(0), found.DeleteAt) + require.Equal(t, created.PermissionField, found.PermissionField) + require.Equal(t, created.PermissionValues, found.PermissionValues) + require.Equal(t, created.PermissionOptions, found.PermissionOptions) + + require.Equal(t, model.CustomProfileAttributesValueTypeEmail, found.Attrs[model.CustomProfileAttributesPropertyAttrsValueType]) + require.EqualValues(t, 5, found.Attrs[model.CustomProfileAttributesPropertyAttrsSortOrder]) + require.Equal(t, model.CustomProfileAttributesVisibilityWhenSet, found.Attrs[model.CustomProfileAttributesPropertyAttrsVisibility]) + }) + + t.Run("create via generic API, read via CPA API", func(t *testing.T) { + name := celSafeName() + field := &model.PropertyField{ + Name: name, + Type: model.PropertyFieldTypeText, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsSortOrder: 3, + model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityAlways, + }, + } + created, resp, err := th.SystemAdminClient.CreatePropertyField( + context.Background(), + model.AccessControlPropertyGroupName, + model.PropertyFieldObjectTypeUser, + field, + ) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + require.NotNil(t, created) + + listed, resp, err := th.SystemAdminClient.ListCPAFields(context.Background()) + CheckOKStatus(t, resp) + require.NoError(t, err) + + var found *model.PropertyField + for _, pf := range listed { + if pf.ID == created.ID { + found = pf + break + } + } + require.NotNil(t, found, "field created via generic API should be readable via CPA ListCPAFields") + + require.Equal(t, created.ID, found.ID) + require.Equal(t, name, found.Name) + require.Equal(t, created.Type, found.Type) + require.Equal(t, created.GroupID, found.GroupID) + require.Equal(t, created.CreateAt, found.CreateAt) + require.Equal(t, int64(0), found.DeleteAt) + + // The CPA list response is CPAField-shaped: unmarshal to confirm + // the typed attrs struct round-trips cleanly. + cpaField, err := model.NewCPAFieldFromPropertyField(found) + require.NoError(t, err) + require.EqualValues(t, 3, cpaField.Attrs.SortOrder) + require.Equal(t, model.CustomProfileAttributesVisibilityAlways, cpaField.Attrs.Visibility) + }) +} + +// TestCPABackwardCompatAfterRefactor spot-checks invariants that could have +// drifted in the Phase 7 refactor of the CPA handlers into thin shims. Broad +// behavioral equivalence is already covered by the existing CPA tests (they +// still pass); these subtests target invariants that those tests don't +// exercise directly. +func TestCPABackwardCompatAfterRefactor(t *testing.T) { + mainHelper.Parallel(t) + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.CustomProfileAttributes = true + cfg.FeatureFlags.IntegratedBoards = true + }).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + t.Run("ListCPAFields preserves sort_order ordering", func(t *testing.T) { + // Create in a non-sorted order; ListCPAFields should return them + // sorted ascending by sort_order via CPAFieldsFromPropertyFields. + ids := make([]string, 3) + for _, order := range []int{2, 0, 1} { + field := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsSortOrder: order, + }, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + ids[order] = created.ID + } + + listed, resp, err := th.SystemAdminClient.ListCPAFields(context.Background()) + CheckOKStatus(t, resp) + require.NoError(t, err) + require.GreaterOrEqual(t, len(listed), 3) + + // Extract the three fields we just created, preserving ListCPAFields + // return order, and verify they match ids[0], ids[1], ids[2]. + var observed []string + for _, pf := range listed { + for _, expected := range ids { + if pf.ID == expected { + observed = append(observed, pf.ID) + } + } + } + require.Equal(t, ids, observed, "ListCPAFields must return fields in ascending sort_order") + }) + + t.Run("CPA create response has typed CPAField attrs, with defaults filled", func(t *testing.T) { + field := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsValueType: model.CustomProfileAttributesValueTypeEmail, + }, + } + created, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), field) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + + // The CPA response goes through ToPropertyField on the server side, + // so every CPA-known attrs key is present — including defaults like + // Visibility="when_set" that the caller did not send. + require.Contains(t, created.Attrs, model.CustomProfileAttributesPropertyAttrsValueType) + require.Contains(t, created.Attrs, model.CustomProfileAttributesPropertyAttrsVisibility) + require.Contains(t, created.Attrs, model.CustomProfileAttributesPropertyAttrsSortOrder) + require.Contains(t, created.Attrs, model.CustomProfileAttributesPropertyAttrsLDAP) + require.Contains(t, created.Attrs, model.CustomProfileAttributesPropertyAttrsSAML) + require.Contains(t, created.Attrs, model.CustomProfileAttributesPropertyAttrsManaged) + + cpaField, err := model.NewCPAFieldFromPropertyField(created) + require.NoError(t, err) + require.Equal(t, model.CustomProfileAttributesValueTypeEmail, cpaField.Attrs.ValueType) + require.Equal(t, model.CustomProfileAttributesVisibilityWhenSet, cpaField.Attrs.Visibility) + }) + + t.Run("AccessControlHook still blocks LDAP-synced writes via CPA path", func(t *testing.T) { + ldapField := &model.PropertyField{ + Name: celSafeName(), + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsLDAP: "ldap_attr", + }, + } + createdLDAPField, resp, err := th.SystemAdminClient.CreateCPAField(context.Background(), ldapField) + CheckCreatedStatus(t, resp) + require.NoError(t, err) + + _, resp, err = th.SystemAdminClient.PatchCPAValuesForUser( + context.Background(), + th.BasicUser.Id, + map[string]json.RawMessage{createdLDAPField.ID: json.RawMessage(`"attempted write"`)}, + ) + CheckForbiddenStatus(t, resp) + require.Error(t, err) + CheckErrorID(t, err, "app.property.sync_lock.app_error") + }) +} diff --git a/server/channels/api4/properties.go b/server/channels/api4/properties.go index 5d647bb792b..a5e3e521e08 100644 --- a/server/channels/api4/properties.go +++ b/server/channels/api4/properties.go @@ -6,12 +6,14 @@ package api4 import ( "encoding/json" "errors" + "maps" "net/http" "strconv" "strings" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/channels/app" ) const maxPropertyValuePatchItems = 50 @@ -63,43 +65,33 @@ func createPropertyField(c *Context, w http.ResponseWriter, r *http.Request) { return } - auditRec := c.MakeAuditRecord(model.AuditEventCreatePropertyField, model.AuditStatusFail) - defer c.LogAuditRec(auditRec) - - // Set ObjectType and GroupID from URL field.ObjectType = c.Params.ObjectType field.GroupID = group.ID - // System-object fields attach to the system itself; canonicalize the - // target fields so clients can't submit inconsistent combinations. - // Permissions are likewise pinned to sysadmin: a system field's - // TargetType is "system", which makes member-level scope checks resolve - // to "any authenticated user" (see hasPropertyFieldScopeAccess), so - // honouring a member-level permission on a system field would expose - // the field's definition, options, and values to every logged-in user. - if field.ObjectType == model.PropertyFieldObjectTypeSystem { - field.TargetType = string(model.PropertyFieldTargetLevelSystem) - field.TargetID = "" - sysadmin := model.PermissionLevelSysadmin - field.PermissionField = &sysadmin - field.PermissionValues = &sysadmin - field.PermissionOptions = &sysadmin - } + auditRec := c.MakeAuditRecord(model.AuditEventCreatePropertyField, model.AuditStatusFail) + defer c.LogAuditRec(auditRec) + model.AddEventParameterAuditableToAuditRec(auditRec, "property_field", field) + + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) - // Reject protected field creation via API if field.Protected { c.Err = model.NewAppError("createPropertyField", "api.property_field.create.protected_via_api.app_error", nil, "", http.StatusBadRequest) return } - // Template creation is always sysadmin-only, regardless of target_type. + // Pre-canonicalize system objects so the scope check below cannot be + // bypassed by submitting ObjectType=system with TargetType=channel. The + // App layer re-canonicalizes defensively for plugin/internal callers. + app.CanonicalizeSystemObjectField(field) + + // Templates are always sysadmin-only, regardless of TargetType. if field.ObjectType == model.PropertyFieldObjectTypeTemplate && !c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) { c.SetPermissionError(model.PermissionManageSystem) return } - // Check scope access for creation based on target_type + // Scope-based create permission. switch field.TargetType { case "channel": if field.TargetID == "" { @@ -130,27 +122,15 @@ func createPropertyField(c *Context, w http.ResponseWriter, r *http.Request) { return } - // Trim whitespace from name - field.Name = strings.TrimSpace(field.Name) - - // Set permissions based on admin status. - // Permissions are not accepted from the request body; they're set by the server. - // Templates default to sysadmin since they define the schema linked fields inherit. - // System-object fields likewise default to sysadmin since they attach to the - // Mattermost instance and only a system administrator should write them. + // Default permission levels: pin all three for non-admins, nil-fill for + // admins. Stays in API because it is session-bound. isAdmin := c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) - defaultLevel := model.PermissionLevelMember - if field.ObjectType == model.PropertyFieldObjectTypeTemplate || - field.ObjectType == model.PropertyFieldObjectTypeSystem { - defaultLevel = model.PermissionLevelSysadmin - } + defaultLevel := app.DefaultPropertyFieldPermissionLevel(field) if !isAdmin { - // Non-admin: force all permissions to the default level field.PermissionField = &defaultLevel field.PermissionValues = &defaultLevel field.PermissionOptions = &defaultLevel } else { - // Admin with nil fields: set defaults if field.PermissionField == nil { field.PermissionField = &defaultLevel } @@ -162,17 +142,13 @@ func createPropertyField(c *Context, w http.ResponseWriter, r *http.Request) { } } - // Set creator field.CreatedBy = c.AppContext.Session().UserId field.UpdatedBy = c.AppContext.Session().UserId - - model.AddEventParameterAuditableToAuditRec(auditRec, "property_field", field) - connectionID := r.Header.Get(model.ConnectionId) - createdField, err := c.App.CreatePropertyField(c.AppContext, field, false, connectionID) - if err != nil { - c.Err = err + createdField, appErr := c.App.CreatePropertyField(rctx, field, false, connectionID) + if appErr != nil { + c.Err = appErr return } @@ -289,7 +265,6 @@ func patchPropertyField(c *Context, w http.ResponseWriter, r *http.Request) { if c.Err != nil { return } - groupID := group.ID var patch *model.PropertyFieldPatch if err := json.NewDecoder(r.Body).Decode(&patch); err != nil || patch == nil { @@ -301,8 +276,6 @@ func patchPropertyField(c *Context, w http.ResponseWriter, r *http.Request) { *patch.Name = strings.TrimSpace(*patch.Name) } - // target_id and target_type are identity fields that define the - // property's scope and cannot be modified via patch patch.TargetID = nil patch.TargetType = nil @@ -316,94 +289,65 @@ func patchPropertyField(c *Context, w http.ResponseWriter, r *http.Request) { return } - // Get existing field - existingField, err := c.App.GetPropertyField(c.AppContext, groupID, c.Params.FieldId) - if err != nil { - c.Err = err + auditRec := c.MakeAuditRecord(model.AuditEventPatchPropertyField, model.AuditStatusFail) + defer c.LogAuditRec(auditRec) + model.AddEventParameterAuditableToAuditRec(auditRec, "property_field_patch", patch) + + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + + existingField, appErr := c.App.GetPropertyField(rctx, group.ID, c.Params.FieldId) + if appErr != nil { + c.Err = appErr return } - // FIXME: IsPSAv1 currently includes template fields (ObjectType="template"), but - // templates are valid PSAv2 objects and must be patchable. Once the FIXME in - // model.PropertyField.IsPSAv1 is resolved, this extra condition can be removed. - if existingField.IsPSAv1() && existingField.ObjectType == "" { + // PSAv2 routes only operate on PSAv2 fields. Reject legacy fields. + if existingField.IsPSAv1() { c.Err = model.NewAppError("patchPropertyField", "api.property_field.patch.legacy_field.app_error", nil, "", http.StatusBadRequest) return } - // Verify ObjectType matches + // HTTP-routing: a 404 indistinguishable from "no such field" lets us + // bucket fields by URL ObjectType without leaking cross-bucket existence. if existingField.ObjectType != c.Params.ObjectType { - c.Err = model.NewAppError("patchPropertyField", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusBadRequest) + c.Err = model.NewAppError("patchPropertyField", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusNotFound) return } - auditRec := c.MakeAuditRecord(model.AuditEventPatchPropertyField, model.AuditStatusFail) - defer c.LogAuditRec(auditRec) - model.AddEventParameterAuditableToAuditRec(auditRec, "property_field_patch", patch) - auditRec.AddEventPriorState(existingField) - - // Reject update of protected field - if existingField.Protected { - c.Err = model.NewAppError("patchPropertyField", "api.property_field.update.protected_via_api.app_error", nil, "", http.StatusForbidden) - return + // Permission branching (session-bound): options-only patches use a + // narrower permission than full edits. + isOptionsOnly := isOptionsOnlyPatch(patch) + if isOptionsOnly && existingField.Type != model.PropertyFieldTypeSelect && existingField.Type != model.PropertyFieldTypeMultiselect { + isOptionsOnly = false } - - // Linked field restrictions - if existingField.LinkedFieldID != nil && *existingField.LinkedFieldID != "" { - if patch.Type != nil { - c.Err = model.NewAppError("patchPropertyField", "api.property_field.patch.linked_type_change.app_error", nil, "cannot modify type of a linked field", http.StatusBadRequest) - return - } - if patch.Attrs != nil { - if _, hasOpts := (*patch.Attrs)[model.PropertyFieldAttributeOptions]; hasOpts { - c.Err = model.NewAppError("patchPropertyField", "api.property_field.patch.linked_options_change.app_error", nil, "cannot modify options of a linked field", http.StatusBadRequest) - return - } - } - // LinkedFieldID patch validation: only allow unlink (empty string) or same value (no-op) - if patch.LinkedFieldID != nil && *patch.LinkedFieldID != "" && *patch.LinkedFieldID != *existingField.LinkedFieldID { - c.Err = model.NewAppError("patchPropertyField", "api.property_field.patch.linked_field_change.app_error", nil, "cannot change link target; unlink first then create a new linked field", http.StatusBadRequest) - return - } - } else { - // Field is not linked — reject attempts to set a new LinkedFieldID - if patch.LinkedFieldID != nil && *patch.LinkedFieldID != "" { - c.Err = model.NewAppError("patchPropertyField", "api.property_field.patch.cannot_link_existing.app_error", nil, "linked_field_id can only be set at creation time", http.StatusBadRequest) - return - } - } - - // Detect if this is an options-only update - isOptionsOnlyUpdate := isOptionsOnlyPatch(patch) - - // Options-only permission path only applies to select/multiselect fields. - // For other field types, treat options changes as a field update. - if isOptionsOnlyUpdate && existingField.Type != model.PropertyFieldTypeSelect && existingField.Type != model.PropertyFieldTypeMultiselect { - isOptionsOnlyUpdate = false - } - - // Check permissions - if isOptionsOnlyUpdate { - if !c.App.SessionHasPermissionToManagePropertyFieldOptions(c.AppContext, *c.AppContext.Session(), existingField) { + if isOptionsOnly { + if !c.App.SessionHasPermissionToManagePropertyFieldOptions(rctx, *c.AppContext.Session(), existingField) { c.Err = model.NewAppError("patchPropertyField", "api.property_field.update.no_options_permission.app_error", nil, "", http.StatusForbidden) return } } else { - if !c.App.SessionHasPermissionToEditPropertyField(c.AppContext, *c.AppContext.Session(), existingField) { + if !c.App.SessionHasPermissionToEditPropertyField(rctx, *c.AppContext.Session(), existingField) { c.Err = model.NewAppError("patchPropertyField", "api.property_field.update.no_field_permission.app_error", nil, "", http.StatusForbidden) return } } - // Apply patch + // Capture original state for audit before the in-place patch. Attrs is + // shallow-copied because Patch mutates it. + orig := *existingField + if existingField.Attrs != nil { + orig.Attrs = make(model.StringInterface, len(existingField.Attrs)) + maps.Copy(orig.Attrs, existingField.Attrs) + } + auditRec.AddEventPriorState(&orig) + existingField.Patch(patch, true) existingField.UpdatedBy = c.AppContext.Session().UserId - connectionID := r.Header.Get(model.ConnectionId) - updatedField, err := c.App.UpdatePropertyField(c.AppContext, groupID, existingField, false, connectionID) - if err != nil { - c.Err = err + updatedField, _, updateErr := c.App.UpdatePropertyField(rctx, group.ID, existingField, false, connectionID) + if updateErr != nil { + c.Err = updateErr return } @@ -426,42 +370,34 @@ func deletePropertyField(c *Context, w http.ResponseWriter, r *http.Request) { if c.Err != nil { return } - groupID := group.ID - - // Get existing field - existingField, err := c.App.GetPropertyField(c.AppContext, groupID, c.Params.FieldId) - if err != nil { - c.Err = err - return - } - - // Verify ObjectType matches - if existingField.ObjectType != c.Params.ObjectType { - c.Err = model.NewAppError("deletePropertyField", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusBadRequest) - return - } auditRec := c.MakeAuditRecord(model.AuditEventDeletePropertyField, model.AuditStatusFail) defer c.LogAuditRec(auditRec) model.AddEventParameterToAuditRec(auditRec, "field_id", c.Params.FieldId) - auditRec.AddEventPriorState(existingField) - // Reject deletion of protected field - if existingField.Protected { - c.Err = model.NewAppError("deletePropertyField", "api.property_field.delete.protected_via_api.app_error", nil, "", http.StatusForbidden) + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + + existingField, appErr := c.App.GetPropertyField(rctx, group.ID, c.Params.FieldId) + if appErr != nil { + c.Err = appErr return } - // Check field edit permission - if !c.App.SessionHasPermissionToEditPropertyField(c.AppContext, *c.AppContext.Session(), existingField) { + if existingField.ObjectType != c.Params.ObjectType { + c.Err = model.NewAppError("deletePropertyField", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusNotFound) + return + } + + if !c.App.SessionHasPermissionToEditPropertyField(rctx, *c.AppContext.Session(), existingField) { c.Err = model.NewAppError("deletePropertyField", "api.property_field.delete.no_permission.app_error", nil, "", http.StatusForbidden) return } - connectionID := r.Header.Get(model.ConnectionId) + auditRec.AddEventPriorState(existingField) - if err := c.App.DeletePropertyField(c.AppContext, groupID, c.Params.FieldId, false, connectionID); err != nil { - c.Err = err + connectionID := r.Header.Get(model.ConnectionId) + if deleteErr := c.App.DeletePropertyField(rctx, group.ID, c.Params.FieldId, false, connectionID); deleteErr != nil { + c.Err = deleteErr return } @@ -594,12 +530,6 @@ func patchPropertyValuesCore(c *Context, w http.ResponseWriter, r *http.Request, if c.Err != nil { return } - groupID := group.ID - - // Check target access based on object type - if !hasTargetAccess(c, objectType, targetID, true) { - return - } var items []model.PropertyValuePatchItem if err := json.NewDecoder(r.Body).Decode(&items); err != nil { @@ -607,104 +537,91 @@ func patchPropertyValuesCore(c *Context, w http.ResponseWriter, r *http.Request, return } - if len(items) == 0 { - c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.empty_body.app_error", nil, "", http.StatusBadRequest) - return - } - - if len(items) > maxPropertyValuePatchItems { - c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.too_many_items.request_error", map[string]any{ - "Max": maxPropertyValuePatchItems, - }, "", http.StatusBadRequest) - return - } - - // Collect and validate field IDs - idMap := map[string]bool{} - fieldIDs := make([]string, 0, len(items)) - for _, item := range items { - if !model.IsValidId(item.FieldID) { - c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.invalid_field_id.app_error", nil, "", http.StatusBadRequest) - return - } - if idMap[item.FieldID] { - c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.duplicate_field_id.app_error", nil, "", http.StatusBadRequest) - return - } - idMap[item.FieldID] = true - fieldIDs = append(fieldIDs, item.FieldID) - } - - // Load all fields and verify they belong to this group. - // GetPropertyFields scopes the lookup by groupID, so fields from - // a different group won't be found, causing a mismatch error. - fields, err := c.App.GetPropertyFields(c.AppContext, groupID, fieldIDs) - if err != nil { - c.Err = err - return - } - - // Each field's ObjectType must match the route's objectType. Without - // this, a caller could reference a field of one type via another - // type's route (e.g. a system field via the user-values route), - // bypassing the route-level access checks and persisting values whose - // TargetType disagrees with field.ObjectType. Templates are always - // rejected because objectType is never "template" on these routes; - // keep a dedicated error for that case so the cause is clear. - for _, f := range fields { - if f.ObjectType == model.PropertyFieldObjectTypeTemplate { - c.Err = model.NewAppError("patchPropertyValues", "api.property_value.template_no_values.app_error", nil, "template fields cannot have values", http.StatusBadRequest) - return - } - if f.ObjectType != objectType { - c.Err = model.NewAppError("patchPropertyValues", "api.property_value.field_object_type_mismatch.app_error", nil, "", http.StatusBadRequest) - return - } - } - - // Build field map for permission checks - fieldMap := make(map[string]*model.PropertyField, len(fields)) - for _, f := range fields { - fieldMap[f.ID] = f - } - auditRec := c.MakeAuditRecord(model.AuditEventPatchPropertyValues, model.AuditStatusFail) defer c.LogAuditRec(auditRec) model.AddEventParameterToAuditRec(auditRec, "group_name", c.Params.GroupName) model.AddEventParameterToAuditRec(auditRec, "object_type", objectType) model.AddEventParameterToAuditRec(auditRec, "target_id", targetID) - // Check values permission on each field (all-or-nothing) + rctx := app.RequestContextWithCallerID(c.AppContext, sessionCallerID(c)) + + if !hasTargetAccess(c, objectType, targetID, true) { + return + } + + if len(items) == 0 { + c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.empty_body.app_error", nil, "", http.StatusBadRequest) + return + } + if len(items) > maxPropertyValuePatchItems { + c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.too_many_items.request_error", map[string]any{ + "Max": maxPropertyValuePatchItems, + }, "", http.StatusBadRequest) + return + } + + // Pre-validate IDs and de-dup so we can bulk-load fields for the + // session-bound permission check below. The App layer re-validates these + // invariants (defense for plugin/internal callers). + seen := map[string]bool{} + fieldIDs := make([]string, 0, len(items)) for _, item := range items { - field := fieldMap[item.FieldID] - if !c.App.SessionHasPermissionToSetPropertyFieldValues(c.AppContext, *c.AppContext.Session(), field) { + if !model.IsValidId(item.FieldID) { + c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.invalid_field_id.app_error", nil, "", http.StatusBadRequest) + return + } + if seen[item.FieldID] { + c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.duplicate_field_id.app_error", nil, "", http.StatusBadRequest) + return + } + seen[item.FieldID] = true + fieldIDs = append(fieldIDs, item.FieldID) + } + + fields, fieldsErr := c.App.GetPropertyFields(rctx, group.ID, fieldIDs) + if fieldsErr != nil { + c.Err = fieldsErr + return + } + fieldByID := make(map[string]*model.PropertyField, len(fields)) + for _, f := range fields { + fieldByID[f.ID] = f + } + for _, item := range items { + f, ok := fieldByID[item.FieldID] + if !ok { + c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.field_not_found.app_error", + map[string]any{"FieldID": item.FieldID}, "", http.StatusNotFound) + return + } + if f.ObjectType != objectType { + c.Err = model.NewAppError("patchPropertyValues", "api.property_field.object_type_mismatch.app_error", nil, "", http.StatusNotFound) + return + } + if !c.App.SessionHasPermissionToSetPropertyFieldValues(rctx, *c.AppContext.Session(), f) { c.Err = model.NewAppError("patchPropertyValues", "api.property_value.patch.no_values_permission.app_error", nil, "", http.StatusForbidden) return } } - // Build PropertyValue objects for upsert userID := c.AppContext.Session().UserId values := make([]*model.PropertyValue, len(items)) for i, item := range items { values[i] = &model.PropertyValue{ - TargetID: targetID, - // in PSAv2, values always point to entities of the same - // type that their field.ObjectType + TargetID: targetID, TargetType: objectType, - GroupID: groupID, + GroupID: group.ID, FieldID: item.FieldID, Value: item.Value, CreatedBy: userID, UpdatedBy: userID, } } - connectionID := r.Header.Get(model.ConnectionId) - upserted, err := c.App.UpsertPropertyValues(c.AppContext, values, objectType, targetID, connectionID) - if err != nil { - c.Err = err + upserted, upsertErr := c.App.UpsertPropertyValues(rctx, values, objectType, targetID, connectionID) + if upsertErr != nil { + c.Err = upsertErr return } @@ -767,11 +684,25 @@ func hasTargetAccess(c *Context, objectType, targetID string, write bool) bool { return false } case model.PropertyFieldObjectTypeUser: - // Any authenticated user can read another user's property values. - // Only the user themselves or a system admin can write values. - if write && targetID != c.AppContext.Session().UserId { - if !c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) { - c.Err = model.NewAppError("hasTargetAccess", "api.property_value.target_user.forbidden.app_error", nil, "", http.StatusForbidden) + // Self-access and unrestricted sessions (local mode) always pass. + if targetID == c.AppContext.Session().UserId || c.AppContext.Session().IsUnrestricted() { + return true + } + if write { + // Writing another user's values requires PermissionEditOtherUsers. + if !c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionEditOtherUsers) { + c.SetPermissionError(model.PermissionEditOtherUsers) + return false + } + } else { + // Reading another user's values requires being able to see them. + canSee, appErr := c.App.UserCanSeeOtherUser(c.AppContext, c.AppContext.Session().UserId, targetID) + if appErr != nil { + c.Err = appErr + return false + } + if !canSee { + c.SetPermissionError(model.PermissionViewMembers) return false } } @@ -794,6 +725,18 @@ func hasTargetAccess(c *Context, objectType, targetID string, write bool) bool { return true } +// sessionCallerID returns the caller ID to attach to a request-derived rctx +// for property-service hook identification. Local-mode (unrestricted) +// sessions have an empty Session.UserId but full admin privileges, so they +// are tagged with CallerIDLocalAdmin instead. +func sessionCallerID(c *Context) string { + session := c.AppContext.Session() + if session.IsUnrestricted() { + return model.CallerIDLocalAdmin + } + return session.UserId +} + // isOptionsOnlyPatch checks if the patch only modifies the options attribute. // Returns true if the only change is to attrs.options. func isOptionsOnlyPatch(patch *model.PropertyFieldPatch) bool { diff --git a/server/channels/api4/properties_test.go b/server/channels/api4/properties_test.go index e673916f916..73489f9a612 100644 --- a/server/channels/api4/properties_test.go +++ b/server/channels/api4/properties_test.go @@ -901,13 +901,14 @@ func TestPatchPropertyField(t *testing.T) { newName := model.NewId() patch := &model.PropertyFieldPatch{Name: &newName} - // Try to update with wrong object_type in URL + // Try to update with wrong object_type in URL. Expected 404 to match + // the shape of a non-existent field. _, resp, err := th.SystemAdminClient.PatchPropertyField(context.Background(), group.Name, "channel", createdField.ID, patch) require.Error(t, err) - CheckBadRequestStatus(t, resp) + CheckNotFoundStatus(t, resp) }) - t.Run("patch with wrong group name should fail", func(t *testing.T) { + t.Run("patch with wrong group name should fail 404", func(t *testing.T) { field := &model.PropertyField{ Name: model.NewId(), Type: model.PropertyFieldTypeText, @@ -924,11 +925,12 @@ func TestPatchPropertyField(t *testing.T) { newName := model.NewId() patch := &model.PropertyFieldPatch{Name: &newName} - // Try to patch using the other group's name — field belongs to `group`, not `otherGroup` + // Try to patch using the other group's name — field belongs to `group`, not `otherGroup`. + // A field not found because of a wrong group must surface as 404, not a generic 500. _, resp, err := th.SystemAdminClient.PatchPropertyField(context.Background(), otherGroup.Name, "post", createdField.ID, patch) require.Error(t, err) - // GetPropertyField with the wrong groupID should not find the field - require.NotEqual(t, http.StatusOK, resp.StatusCode) + CheckNotFoundStatus(t, resp) + require.Equal(t, "app.property.not_found.app_error", err.(*model.AppError).Id) }) t.Run("options-only update should check options permission", func(t *testing.T) { @@ -1435,13 +1437,14 @@ func TestDeletePropertyField(t *testing.T) { createdField, appErr := th.App.CreatePropertyField(th.Context, field, false, "") require.Nil(t, appErr) - // Try to delete with wrong object_type in URL + // Try to delete with wrong object_type in URL. Expected 404 to match + // the shape of a non-existent field. resp, err := th.SystemAdminClient.DeletePropertyField(context.Background(), group.Name, "channel", createdField.ID) require.Error(t, err) - CheckBadRequestStatus(t, resp) + CheckNotFoundStatus(t, resp) }) - t.Run("delete with wrong group name should fail", func(t *testing.T) { + t.Run("delete with wrong group name should fail 404", func(t *testing.T) { field := &model.PropertyField{ Name: model.NewId(), Type: model.PropertyFieldTypeText, @@ -1455,12 +1458,13 @@ func TestDeletePropertyField(t *testing.T) { createdField, appErr := th.App.CreatePropertyField(th.Context, field, false, "") require.Nil(t, appErr) - // Try to delete using the other group's name — field belongs to `group`, not `otherGroup` - th.LoginBasic(t) - resp, err := th.Client.DeletePropertyField(context.Background(), otherGroup.Name, "post", createdField.ID) + // Try to delete using the other group's name — field belongs to `group`, not `otherGroup`. + // A field not found because of a wrong group must surface as 404, not a generic 500. + th.LoginSystemAdmin(t) + resp, err := th.SystemAdminClient.DeletePropertyField(context.Background(), otherGroup.Name, "post", createdField.ID) require.Error(t, err) - // GetPropertyField with the wrong groupID should not find the field - require.NotEqual(t, http.StatusOK, resp.StatusCode) + CheckNotFoundStatus(t, resp) + require.Equal(t, "app.property.not_found.app_error", err.(*model.AppError).Id) }) t.Run("user without permission should not be able to delete", func(t *testing.T) { @@ -1978,7 +1982,7 @@ func TestPatchPropertyValues(t *testing.T) { } _, resp, patchErr := th.Client.PatchPropertyValues(context.Background(), group.Name, "post", targetID, items) require.Error(t, patchErr) - CheckBadRequestStatus(t, resp) + CheckNotFoundStatus(t, resp) }) t.Run("nonexistent group should fail", func(t *testing.T) { @@ -1992,6 +1996,35 @@ func TestPatchPropertyValues(t *testing.T) { CheckNotFoundStatus(t, resp) }) + t.Run("field with mismatched object type should fail 404", func(t *testing.T) { + // A field in the same group but scoped to a different ObjectType must not + // be patchable through the URL of a peer ObjectType; the mismatch collapses + // to 404 so callers cannot distinguish "no such field" from "field exists + // but in a different object-type bucket". + userField := &model.PropertyField{ + Name: model.NewId(), + Type: model.PropertyFieldTypeText, + GroupID: group.ID, + ObjectType: "user", + TargetType: "system", + PermissionField: &memberLevel, + PermissionValues: &memberLevel, + PermissionOptions: &memberLevel, + } + createdUserField, appErr := th.App.CreatePropertyField(th.Context, userField, false, "") + require.Nil(t, appErr) + + th.LoginSystemAdmin(t) + + items := []model.PropertyValuePatchItem{ + {FieldID: createdUserField.ID, Value: json.RawMessage(`"test"`)}, + } + _, resp, err := th.SystemAdminClient.PatchPropertyValues(context.Background(), group.Name, "post", targetID, items) + require.Error(t, err) + CheckNotFoundStatus(t, resp) + require.Equal(t, "api.property_field.object_type_mismatch.app_error", err.(*model.AppError).Id) + }) + t.Run("channel member can set values on channel-scoped field with values permission member", func(t *testing.T) { th.LoginBasic(t) @@ -2246,6 +2279,23 @@ func TestGetPropertyValuesUserTargetAccess(t *testing.T) { CheckOKStatus(t, resp) require.NotEmpty(t, values) }) + + t.Run("non-admin cannot get values of a user they cannot see", func(t *testing.T) { + // Strip system-wide view_members so UserCanSeeOtherUser falls back to team/channel membership. + th.RemovePermissionFromRole(t, model.PermissionViewMembers.Id, model.SystemUserRoleId) + defer th.AddPermissionToRole(t, model.PermissionViewMembers.Id, model.SystemUserRoleId) + + // Drop BasicUser2 from BasicTeam so they no longer share a team with BasicUser. + resp, err := th.SystemAdminClient.RemoveTeamMember(context.Background(), th.BasicTeam.Id, th.BasicUser2.Id) + CheckOKStatus(t, resp) + require.NoError(t, err) + + th.LoginBasic2(t) + + _, resp, err = th.Client.GetPropertyValues(context.Background(), group.Name, "user", th.BasicUser.Id, model.PropertyValueSearch{PerPage: 60}) + CheckForbiddenStatus(t, resp) + require.Error(t, err) + }) } func TestPatchPropertyValuesUserTargetAccess(t *testing.T) { @@ -3353,7 +3403,9 @@ func TestSystemObjectType(t *testing.T) { } _, resp, patchErr := th.SystemAdminClient.PatchSystemPropertyValues(context.Background(), group.Name, items) require.Error(t, patchErr) - CheckBadRequestStatus(t, resp) + // Mismatch (template field ObjectType != system route's objectType) + // collapses to 404 to match the executePatchPropertyField shape. + CheckNotFoundStatus(t, resp) }) t.Run("system field round-trips a value via the dedicated route", func(t *testing.T) { @@ -3502,10 +3554,11 @@ func TestSystemObjectType(t *testing.T) { {FieldID: systemField.ID, Value: json.RawMessage(`"smuggled"`)}, } // Even sysadmin should be rejected — this is a structural check on - // the route, not a permission check. + // the route, not a permission check. Mismatch collapses to 404 to + // match the executePatchPropertyField/executeDeletePropertyField shape. _, resp, patchErr := th.SystemAdminClient.PatchPropertyValues(context.Background(), group.Name, model.PropertyFieldObjectTypeUser, th.SystemAdminUser.Id, items) require.Error(t, patchErr) - CheckBadRequestStatus(t, resp) + CheckNotFoundStatus(t, resp) }) t.Run("system values PATCH route rejects body referencing a non-system field ID", func(t *testing.T) { @@ -3531,6 +3584,6 @@ func TestSystemObjectType(t *testing.T) { } _, resp, patchErr := th.SystemAdminClient.PatchSystemPropertyValues(context.Background(), group.Name, items) require.Error(t, patchErr) - CheckBadRequestStatus(t, resp) + CheckNotFoundStatus(t, resp) }) } diff --git a/server/channels/app/access_control.go b/server/channels/app/access_control.go index 885c1ee745c..eb942db16ba 100644 --- a/server/channels/app/access_control.go +++ b/server/channels/app/access_control.go @@ -340,14 +340,14 @@ func (a *App) GetAccessControlPolicyAttributes(rctx request.CTX, channelID strin } func (a *App) GetAccessControlFieldsAutocomplete(rctx request.CTX, after string, limit int, callerID string) ([]*model.PropertyField, *model.AppError) { - cpaGroupID, appErr := a.CpaGroupID() + group, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) if appErr != nil { return nil, model.NewAppError("GetAccessControlAutoComplete", "app.pap.get_access_control_auto_complete.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) } // Use property app layer to enforce access control rctxWithCaller := RequestContextWithCallerID(rctx, callerID) - fields, appErr := a.SearchPropertyFields(rctxWithCaller, cpaGroupID, model.PropertyFieldSearchOpts{ + fields, appErr := a.SearchPropertyFields(rctxWithCaller, group.ID, model.PropertyFieldSearchOpts{ Cursor: model.PropertyFieldSearchCursor{ PropertyFieldID: after, CreateAt: 1, @@ -686,12 +686,12 @@ func (a *App) ValidateExpressionAgainstRequester(rctx request.CTX, expression st func (a *App) BuildAccessControlSubject(rctx request.CTX, userID string, roles string) (*model.Subject, *model.AppError) { a.refreshAttributeViewIfStale(rctx) - groupID, err := a.CpaGroupID() + group, err := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) if err != nil { return nil, model.NewAppError("BuildAccessControlSubject", "app.access_control.build_subject.group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } - subject, storeErr := a.Srv().Store().Attributes().GetSubject(rctx, userID, groupID) + subject, storeErr := a.Srv().Store().Attributes().GetSubject(rctx, userID, group.ID) if storeErr != nil { var nfErr *store.ErrNotFound if errors.As(storeErr, &nfErr) { diff --git a/server/channels/app/access_control_masking.go b/server/channels/app/access_control_masking.go index b3d9bd315f7..c64af04d9a6 100644 --- a/server/channels/app/access_control_masking.go +++ b/server/channels/app/access_control_masking.go @@ -30,10 +30,11 @@ func (a *App) GetMaskedVisualAST(rctx request.CTX, expression string, callerID s return visualAST, nil } - cpaGroupID, appErr := a.CpaGroupID() + cpaGroup, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) if appErr != nil { return nil, model.NewAppError("GetMaskedVisualAST", "app.pap.get_masked_visual_ast.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) } + cpaGroupID := cpaGroup.ID // Embed callerID in context so GetPropertyFieldByName applies per-caller option filtering. rctxWithCaller := RequestContextWithCallerID(rctx, callerID) diff --git a/server/channels/app/access_control_masking_test.go b/server/channels/app/access_control_masking_test.go index 96242489e0e..e75b411b86e 100644 --- a/server/channels/app/access_control_masking_test.go +++ b/server/channels/app/access_control_masking_test.go @@ -609,20 +609,24 @@ func TestMaskConditionValues_SharedOnlyText(t *testing.T) { func TestGetMaskedVisualAST_Wiring(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) rctx := request.TestContext(t) + cpaGroup, cErr := th.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + require.Nil(t, cErr) + cpaID := cpaGroup.ID + callerID := model.NewId() // Create a plain public text field in the CPA group (no access mode = public). // Non-protected fields are writable by any caller in the CPA group. fieldName := "f_" + model.NewId() field := &model.PropertyField{ - GroupID: cpaID, - Name: fieldName, - Type: model.PropertyFieldTypeText, + GroupID: cpaID, + Name: fieldName, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } _, appErr := th.App.CreatePropertyField(rctx, field, false, "") require.Nil(t, appErr) diff --git a/server/channels/app/authorization_test.go b/server/channels/app/authorization_test.go index 35020f7e9b3..4dc921a22e6 100644 --- a/server/channels/app/authorization_test.go +++ b/server/channels/app/authorization_test.go @@ -17,6 +17,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/plugin/plugintest/mock" + "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks" ) @@ -1202,8 +1203,9 @@ func TestHasPermissionToEditPropertyField(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + groupID := cpaGroup.ID testCases := []struct { name string @@ -1340,8 +1342,9 @@ func TestHasPermissionToSetPropertyFieldValues(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + groupID := cpaGroup.ID // Create a user that is not a member of any channel for the non-member test case nonMember := th.CreateUser(t) @@ -1563,8 +1566,9 @@ func TestHasPermissionToManagePropertyFieldOptions(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + groupID := cpaGroup.ID testCases := []struct { name string @@ -1701,8 +1705,9 @@ func TestSessionHasPermissionToEditPropertyField(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + groupID := cpaGroup.ID testCases := []struct { name string @@ -1853,8 +1858,9 @@ func TestSessionHasPermissionToSetPropertyFieldValues(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + groupID := cpaGroup.ID // Create a user that is not a member of any channel for the non-member test case nonMember := th.CreateUser(t) @@ -2075,8 +2081,9 @@ func TestSessionHasPermissionToManagePropertyFieldOptions(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + groupID := cpaGroup.ID testCases := []struct { name string diff --git a/server/channels/app/content_flagging.go b/server/channels/app/content_flagging.go index efad00bbbaf..eac209d023f 100644 --- a/server/channels/app/content_flagging.go +++ b/server/channels/app/content_flagging.go @@ -1167,14 +1167,14 @@ func (a *App) AssignFlaggedPostReviewer(rctx request.CTX, flaggedPostId, flagged Value: json.RawMessage(fmt.Sprintf(`"%s"`, reviewerId)), } - assigneePropertyValue, appErr = a.UpsertPropertyValue(nil, assigneePropertyValue) + assigneePropertyValue, appErr = a.UpsertPropertyValue(rctx, assigneePropertyValue) if appErr != nil { return model.NewAppError("AssignFlaggedPostReviewer", "app.data_spillage.assign_reviewer.upsert_property_value.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) } if status == model.ContentFlaggingStatusPending { statusPropertyValue.Value = json.RawMessage(fmt.Sprintf(`"%s"`, model.ContentFlaggingStatusAssigned)) - statusPropertyValue, appErr = a.UpdatePropertyValue(nil, groupId, statusPropertyValue) + statusPropertyValue, appErr = a.UpdatePropertyValue(rctx, groupId, statusPropertyValue) if appErr != nil { return model.NewAppError("AssignFlaggedPostReviewer", "app.data_spillage.assign_reviewer.update_status_property_value.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) } diff --git a/server/channels/app/custom_profile_attributes.go b/server/channels/app/custom_profile_attributes.go deleted file mode 100644 index e27cdf3c6f2..00000000000 --- a/server/channels/app/custom_profile_attributes.go +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See LICENSE.txt for license information. - -// This file implements the "User Attributes" feature (formerly "Custom -// Profile Attributes" / CPA). Internal identifiers retain the old naming -// for backward compatibility. See MM-68235. - -package app - -import ( - "encoding/json" - "errors" - "net/http" - "sort" - - "github.com/mattermost/mattermost/server/public/model" - "github.com/mattermost/mattermost/server/public/shared/mlog" - "github.com/mattermost/mattermost/server/public/shared/request" - "github.com/mattermost/mattermost/server/v8/channels/store" -) - -const ( - CustomProfileAttributesFieldLimit = 20 -) - -func (a *App) CpaGroupID() (string, *model.AppError) { - group, appErr := a.GetPropertyGroup(nil, model.CustomProfileAttributesPropertyGroupName) - if appErr != nil { - return "", appErr - } - return group.ID, nil -} - -func (a *App) GetCPAField(rctx request.CTX, fieldID string) (*model.CPAField, *model.AppError) { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("GetCPAField", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - field, appErr := a.GetPropertyField(rctx, groupID, fieldID) - if appErr != nil { - var notFoundErr *store.ErrNotFound - if errors.As(appErr, ¬FoundErr) { - return nil, model.NewAppError("GetCPAField", "app.custom_profile_attributes.property_field_not_found.app_error", nil, "", http.StatusNotFound).Wrap(appErr) - } - return nil, model.NewAppError("GetCPAField", "app.custom_profile_attributes.get_property_field.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - cpaField, err := model.NewCPAFieldFromPropertyField(field) - if err != nil { - return nil, model.NewAppError("GetCPAField", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - } - - return cpaField, nil -} - -func (a *App) ListCPAFields(rctx request.CTX) ([]*model.CPAField, *model.AppError) { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("ListCPAFields", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - opts := model.PropertyFieldSearchOpts{ - GroupID: groupID, - PerPage: CustomProfileAttributesFieldLimit, - } - - fields, appErr := a.SearchPropertyFields(rctx, groupID, opts) - if appErr != nil { - return nil, model.NewAppError("ListCPAFields", "app.custom_profile_attributes.search_property_fields.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - // Convert PropertyFields to CPAFields - cpaFields := make([]*model.CPAField, 0, len(fields)) - for _, field := range fields { - cpaField, convErr := model.NewCPAFieldFromPropertyField(field) - if convErr != nil { - return nil, model.NewAppError("ListCPAFields", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(convErr) - } - cpaFields = append(cpaFields, cpaField) - } - - sort.Slice(cpaFields, func(i, j int) bool { - return cpaFields[i].Attrs.SortOrder < cpaFields[j].Attrs.SortOrder - }) - - return cpaFields, nil -} - -func (a *App) CreateCPAField(rctx request.CTX, field *model.CPAField) (*model.CPAField, *model.AppError) { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("CreateCPAField", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - fieldCount, appErr := a.CountPropertyFieldsForGroup(rctx, groupID, false) - if appErr != nil { - return nil, model.NewAppError("CreateCPAField", "app.custom_profile_attributes.count_property_fields.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - if fieldCount >= CustomProfileAttributesFieldLimit { - return nil, model.NewAppError("CreateCPAField", "app.custom_profile_attributes.limit_reached.app_error", nil, "", http.StatusUnprocessableEntity) - } - - field.GroupID = groupID - - if appErr = field.SanitizeAndValidate(); appErr != nil { - return nil, appErr - } - - if appErr = model.ValidateCPAFieldName(field.Name); appErr != nil { - return nil, appErr - } - - newField, appErr := a.CreatePropertyField(rctx, field.ToPropertyField(), false, "") - if appErr != nil { - return nil, appErr - } - - cpaField, err := model.NewCPAFieldFromPropertyField(newField) - if err != nil { - return nil, model.NewAppError("CreateCPAField", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - } - - message := model.NewWebSocketEvent(model.WebsocketEventCPAFieldCreated, "", "", "", nil, "") - message.Add("field", cpaField) - a.Publish(message) - - return cpaField, nil -} - -func (a *App) PatchCPAField(rctx request.CTX, fieldID string, patch *model.PropertyFieldPatch) (*model.CPAField, *model.AppError) { - existingField, appErr := a.GetCPAField(rctx, fieldID) - if appErr != nil { - return nil, appErr - } - originalName := existingField.Name - - shouldDeleteValues := false - if patch.Type != nil && *patch.Type != existingField.Type { - shouldDeleteValues = true - } - - if err := existingField.Patch(patch); err != nil { - return nil, model.NewAppError("PatchCPAField", "app.custom_profile_attributes.patch_field.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - } - - if appErr = existingField.SanitizeAndValidate(); appErr != nil { - return nil, appErr - } - - // Lenient grandfather: only validate Name against CEL rules when it actually changes. - // Pre-existing fields with invalid names remain editable on all other attrs. - if existingField.Name != originalName { - if appErr = model.ValidateCPAFieldName(existingField.Name); appErr != nil { - return nil, appErr - } - } - - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("PatchCPAField", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - patchedField, appErr := a.UpdatePropertyField(rctx, groupID, existingField.ToPropertyField(), false, "") - if appErr != nil { - var notFoundErr *store.ErrNotFound - if errors.As(appErr, ¬FoundErr) { - return nil, model.NewAppError("PatchCPAField", "app.custom_profile_attributes.property_field_not_found.app_error", nil, "", http.StatusNotFound).Wrap(appErr) - } - return nil, model.NewAppError("PatchCPAField", "app.custom_profile_attributes.property_field_update.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - cpaField, err := model.NewCPAFieldFromPropertyField(patchedField) - if err != nil { - return nil, model.NewAppError("PatchCPAField", "app.custom_profile_attributes.property_field_conversion.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - } - - if shouldDeleteValues { - if dErr := a.DeletePropertyValuesForField(rctx, groupID, cpaField.ID); dErr != nil { - a.Log().Error("Error deleting property values when updating field", - mlog.String("fieldID", cpaField.ID), - mlog.Err(dErr), - ) - } - } - - message := model.NewWebSocketEvent(model.WebsocketEventCPAFieldUpdated, "", "", "", nil, "") - message.Add("field", cpaField) - message.Add("delete_values", shouldDeleteValues) - a.Publish(message) - - return cpaField, nil -} - -func (a *App) DeleteCPAField(rctx request.CTX, id string) *model.AppError { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return model.NewAppError("DeleteCPAField", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - if appErr := a.DeletePropertyField(rctx, groupID, id, false, ""); appErr != nil { - var notFoundErr *store.ErrNotFound - if errors.As(appErr, ¬FoundErr) { - return model.NewAppError("DeleteCPAField", "app.custom_profile_attributes.property_field_not_found.app_error", nil, "", http.StatusNotFound).Wrap(appErr) - } - return model.NewAppError("DeleteCPAField", "app.custom_profile_attributes.property_field_delete.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - message := model.NewWebSocketEvent(model.WebsocketEventCPAFieldDeleted, "", "", "", nil, "") - message.Add("field_id", id) - a.Publish(message) - - return nil -} - -func (a *App) ListCPAValues(rctx request.CTX, targetUserID string) ([]*model.PropertyValue, *model.AppError) { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("ListCPAValues", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - values, appErr := a.SearchPropertyValues(rctx, groupID, model.PropertyValueSearchOpts{ - TargetIDs: []string{targetUserID}, - PerPage: CustomProfileAttributesFieldLimit, - }) - if appErr != nil { - return nil, model.NewAppError("ListCPAValues", "app.custom_profile_attributes.list_property_values.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - return values, nil -} - -func (a *App) GetCPAValue(rctx request.CTX, valueID string) (*model.PropertyValue, *model.AppError) { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("GetCPAValue", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - value, appErr := a.GetPropertyValue(rctx, groupID, valueID) - if appErr != nil { - return nil, model.NewAppError("GetCPAValue", "app.custom_profile_attributes.get_property_value.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - return value, nil -} - -func (a *App) PatchCPAValue(rctx request.CTX, userID string, fieldID string, value json.RawMessage, allowSynced bool) (*model.PropertyValue, *model.AppError) { - values, appErr := a.PatchCPAValues(rctx, userID, map[string]json.RawMessage{fieldID: value}, allowSynced) - if appErr != nil { - return nil, appErr - } - - return values[0], nil -} - -func (a *App) PatchCPAValues(rctx request.CTX, userID string, fieldValueMap map[string]json.RawMessage, allowSynced bool) ([]*model.PropertyValue, *model.AppError) { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return nil, model.NewAppError("PatchCPAValues", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - valuesToUpdate := []*model.PropertyValue{} - for fieldID, rawValue := range fieldValueMap { - // make sure field exists in this group - cpaField, fieldErr := a.GetCPAField(rctx, fieldID) - if fieldErr != nil { - return nil, model.NewAppError("PatchCPAValues", "app.custom_profile_attributes.property_field_not_found.app_error", nil, "", http.StatusNotFound).Wrap(fieldErr) - } else if cpaField.DeleteAt > 0 { - return nil, model.NewAppError("PatchCPAValues", "app.custom_profile_attributes.property_field_not_found.app_error", nil, "", http.StatusNotFound) - } - - if !allowSynced && cpaField.IsSynced() { - return nil, model.NewAppError("PatchCPAValues", "app.custom_profile_attributes.property_field_is_synced.app_error", nil, "", http.StatusBadRequest) - } - - sanitizedValue, sErr := model.SanitizeAndValidatePropertyValue(cpaField, rawValue) - if sErr != nil { - return nil, model.NewAppError("PatchCPAValues", "app.custom_profile_attributes.validate_value.app_error", nil, "", http.StatusBadRequest).Wrap(sErr) - } - - value := &model.PropertyValue{ - GroupID: groupID, - TargetType: model.PropertyValueTargetTypeUser, - TargetID: userID, - FieldID: fieldID, - Value: sanitizedValue, - } - valuesToUpdate = append(valuesToUpdate, value) - } - - updatedValues, appErr := a.UpsertPropertyValues(rctx, valuesToUpdate, "", "", "") - if appErr != nil { - return nil, model.NewAppError("PatchCPAValues", "app.custom_profile_attributes.property_value_upsert.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - updatedFieldValueMap := map[string]json.RawMessage{} - for _, value := range updatedValues { - updatedFieldValueMap[value.FieldID] = value.Value - } - - message := model.NewWebSocketEvent(model.WebsocketEventCPAValuesUpdated, "", "", "", nil, "") - message.Add("user_id", userID) - message.Add("values", updatedFieldValueMap) - a.Publish(message) - - return updatedValues, nil -} - -func (a *App) DeleteCPAValues(rctx request.CTX, userID string) *model.AppError { - groupID, appErr := a.CpaGroupID() - if appErr != nil { - return model.NewAppError("DeleteCPAValues", "app.custom_profile_attributes.cpa_group_id.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - if appErr := a.DeletePropertyValuesForTarget(rctx, groupID, "user", userID); appErr != nil { - return model.NewAppError("DeleteCPAValues", "app.custom_profile_attributes.delete_property_values_for_user.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) - } - - message := model.NewWebSocketEvent(model.WebsocketEventCPAValuesUpdated, "", "", "", nil, "") - message.Add("user_id", userID) - message.Add("values", map[string]json.RawMessage{}) - a.Publish(message) - - return nil -} diff --git a/server/channels/app/custom_profile_attributes_test.go b/server/channels/app/custom_profile_attributes_test.go index ebdc85d26b3..74688ce4707 100644 --- a/server/channels/app/custom_profile_attributes_test.go +++ b/server/channels/app/custom_profile_attributes_test.go @@ -6,810 +6,37 @@ package app import ( "encoding/json" "fmt" - "net/http" "testing" - "time" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" "github.com/stretchr/testify/require" ) -func celSafeName() string { - return "f_" + model.NewId() -} - -func TestGetCPAField(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - t.Run("should fail when getting a non-existent field", func(t *testing.T) { - field, appErr := th.App.GetCPAField(rctx, model.NewId()) - require.NotNil(t, appErr) - require.Equal(t, "app.custom_profile_attributes.property_field_not_found.app_error", appErr.Id) - require.Empty(t, field) - }) - - t.Run("should fail when getting a field from a different group", func(t *testing.T) { - otherGroup, gErr := th.App.RegisterPropertyGroup(rctx, &model.PropertyGroup{ - Name: "test_get_cpa_other_group_" + model.NewId(), - Version: model.PropertyGroupVersionV1, - }) - require.Nil(t, gErr) - - field := &model.PropertyField{ - GroupID: otherGroup.ID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, - } - createdField, err := th.App.CreatePropertyField(rctx, field, false, "") - require.Nil(t, err) - - fetchedField, appErr := th.App.GetCPAField(rctx, createdField.ID) - require.NotNil(t, appErr) - require.Equal(t, "app.custom_profile_attributes.property_field_not_found.app_error", appErr.Id) - require.Empty(t, fetchedField) - }) - - t.Run("should get an existing CPA field", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: "test_field", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityHidden}, - }) - require.NoError(t, err) - - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - require.NotEmpty(t, createdField.ID) - - fetchedField, appErr := th.App.GetCPAField(rctx, createdField.ID) - require.Nil(t, appErr) - require.Equal(t, createdField.ID, fetchedField.ID) - require.Equal(t, "test_field", fetchedField.Name) - require.Equal(t, model.CustomProfileAttributesVisibilityHidden, fetchedField.Attrs.Visibility) - }) - - t.Run("should initialize default attrs when field has nil Attrs", func(t *testing.T) { - // Create a field with nil Attrs directly via property service (bypassing CPA validation) - field := &model.PropertyField{ - GroupID: cpaID, - Name: "Field with nil attrs", - Type: model.PropertyFieldTypeText, - Attrs: nil, - } - createdField, err := th.App.CreatePropertyField(rctx, field, false, "") - require.Nil(t, err) - - // GetCPAField should initialize Attrs with defaults - fetchedField, appErr := th.App.GetCPAField(rctx, createdField.ID) - require.Nil(t, appErr) - require.Equal(t, model.CustomProfileAttributesVisibilityDefault, fetchedField.Attrs.Visibility) - require.Equal(t, float64(0), fetchedField.Attrs.SortOrder) - }) - - t.Run("should initialize default attrs when field has empty Attrs", func(t *testing.T) { - // Create a field with empty Attrs directly via property service - field := &model.PropertyField{ - GroupID: cpaID, - Name: "Field with empty attrs", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{}, - } - createdField, err := th.App.CreatePropertyField(rctx, field, false, "") - require.Nil(t, err) - - // GetCPAField should add missing default attrs - fetchedField, appErr := th.App.GetCPAField(rctx, createdField.ID) - require.Nil(t, appErr) - require.Equal(t, model.CustomProfileAttributesVisibilityDefault, fetchedField.Attrs.Visibility) - require.Equal(t, float64(0), fetchedField.Attrs.SortOrder) - }) - - t.Run("should validate LDAP/SAML synced fields", func(t *testing.T) { - // Create LDAP synced field - ldapField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: "ldap_field", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{ - model.CustomProfileAttributesPropertyAttrsLDAP: "ldap_attribute", - }, - }) - require.NoError(t, err) - createdLDAPField, appErr := th.App.CreateCPAField(rctx, ldapField) - require.Nil(t, appErr) - - // Create SAML synced field - samlField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: "saml_field", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{ - model.CustomProfileAttributesPropertyAttrsSAML: "saml_attribute", - }, - }) - require.NoError(t, err) - createdSAMLField, appErr := th.App.CreateCPAField(rctx, samlField) - require.Nil(t, appErr) - - // Test with allowSynced=false - userID := model.NewId() - - // Test LDAP field - _, appErr = th.App.PatchCPAValue(rctx, userID, createdLDAPField.ID, json.RawMessage(`"test value"`), false) - require.NotNil(t, appErr) - require.Equal(t, "app.custom_profile_attributes.property_field_is_synced.app_error", appErr.Id) - - // Test SAML field - _, appErr = th.App.PatchCPAValue(rctx, userID, createdSAMLField.ID, json.RawMessage(`"test value"`), false) - require.NotNil(t, appErr) - require.Equal(t, "app.custom_profile_attributes.property_field_is_synced.app_error", appErr.Id) - - // Test with allowSynced=true - // LDAP field should work - patchedValue, appErr := th.App.PatchCPAValue(rctx, userID, createdLDAPField.ID, json.RawMessage(`"test value"`), true) - require.Nil(t, appErr) - require.NotNil(t, patchedValue) - require.Equal(t, json.RawMessage(`"test value"`), patchedValue.Value) - - // SAML field should work - patchedValue, appErr = th.App.PatchCPAValue(rctx, userID, createdSAMLField.ID, json.RawMessage(`"test value"`), true) - require.Nil(t, appErr) - require.NotNil(t, patchedValue) - require.Equal(t, json.RawMessage(`"test value"`), patchedValue.Value) - }) -} - -func TestListCPAFields(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - t.Run("should list the CPA property fields", func(t *testing.T) { - field1 := model.PropertyField{ - GroupID: cpaID, - Name: "Field 1", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsSortOrder: 1}, - } - - _, err := th.App.CreatePropertyField(rctx, &field1, false, "") - require.Nil(t, err) - - otherGroup, gErr := th.App.RegisterPropertyGroup(rctx, &model.PropertyGroup{ - Name: "test_list_cpa_other_group_" + model.NewId(), - Version: model.PropertyGroupVersionV1, - }) - require.Nil(t, gErr) - - field2 := &model.PropertyField{ - GroupID: otherGroup.ID, - Name: "Field 2", - Type: model.PropertyFieldTypeText, - } - _, err = th.App.CreatePropertyField(rctx, field2, false, "") - require.Nil(t, err) - - field3 := model.PropertyField{ - GroupID: cpaID, - Name: "Field 3", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsSortOrder: 0}, - } - _, err = th.App.CreatePropertyField(rctx, &field3, false, "") - require.Nil(t, err) - - fields, appErr := th.App.ListCPAFields(rctx) - require.Nil(t, appErr) - require.Len(t, fields, 2) - require.Equal(t, "Field 3", fields[0].Name) - require.Equal(t, "Field 1", fields[1].Name) - }) - - t.Run("should initialize default attrs for fields with nil or empty Attrs", func(t *testing.T) { - // Create a field with nil Attrs - fieldWithNilAttrs := &model.PropertyField{ - GroupID: cpaID, - Name: "Field with nil attrs", - Type: model.PropertyFieldTypeText, - Attrs: nil, - } - _, err := th.App.CreatePropertyField(rctx, fieldWithNilAttrs, false, "") - require.Nil(t, err) - - // Create a field with empty Attrs - fieldWithEmptyAttrs := &model.PropertyField{ - GroupID: cpaID, - Name: "Field with empty attrs", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{}, - } - _, err = th.App.CreatePropertyField(rctx, fieldWithEmptyAttrs, false, "") - require.Nil(t, err) - - // ListCPAFields should initialize Attrs with defaults - fields, appErr := th.App.ListCPAFields(rctx) - require.Nil(t, appErr) - require.NotEmpty(t, fields) - - // Find our test fields and verify default attrs are set - for _, field := range fields { - if field.Name == "Field with nil attrs" || field.Name == "Field with empty attrs" { - require.Equal(t, model.CustomProfileAttributesVisibilityDefault, field.Attrs.Visibility) - require.Equal(t, float64(0), field.Attrs.SortOrder) - } - } - }) - - t.Run("list fields should return defaults for fields created without visibility and sort_order", func(t *testing.T) { - // Create a field with minimal attrs (no visibility or sort_order) - fieldMinimal, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: "field_without_defaults", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{}, // Empty attrs - no visibility or sort_order - }) - require.NoError(t, err) - createdFieldMinimal, appErr := th.App.CreateCPAField(rctx, fieldMinimal) - require.Nil(t, appErr) - require.NotNil(t, createdFieldMinimal) - - // Create another field to ensure we test list results with explicit values - fieldNormal, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: "normal_field", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{ - model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityAlways, - model.CustomProfileAttributesPropertyAttrsSortOrder: 5.0, - }, - }) - require.NoError(t, err) - createdFieldNormal, appErr := th.App.CreateCPAField(rctx, fieldNormal) - require.Nil(t, appErr) - require.NotNil(t, createdFieldNormal) - - // List all fields - fields, appErr := th.App.ListCPAFields(rctx) - require.Nil(t, appErr) - require.NotEmpty(t, fields) - - // Find our test fields and verify defaults - foundMinimal := false - foundNormal := false - for _, f := range fields { - if f.ID == createdFieldMinimal.ID { - foundMinimal = true - // Verify defaults are set for field created without them - require.Equal(t, model.CustomProfileAttributesVisibilityDefault, f.Attrs.Visibility, "visibility should have default value") - require.Equal(t, float64(0), f.Attrs.SortOrder, "sort_order should default to 0") - } - if f.ID == createdFieldNormal.ID { - foundNormal = true - // Verify createdFieldNormal are preserved - require.Equal(t, model.CustomProfileAttributesVisibilityAlways, f.Attrs.Visibility) - require.Equal(t, float64(5), f.Attrs.SortOrder) - } - } - require.True(t, foundMinimal, "should have found createdFieldMinimal in list") - require.True(t, foundNormal, "should have found createdFieldNormal in list") - }) -} - -func TestCreateCPAField(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - t.Run("should fail if the field is not valid", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{Name: celSafeName()}) - require.NoError(t, err) - - createdField, err := th.App.CreateCPAField(rctx, field) - require.Error(t, err) - require.Empty(t, createdField) - }) - - t.Run("should not be able to create a property field for a different feature", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: model.NewId(), - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) - - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - require.Equal(t, cpaID, createdField.GroupID) - }) - - t.Run("should correctly create a CPA field", func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityHidden}, - }) - require.NoError(t, err) - - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - require.NotZero(t, createdField.ID) - require.Equal(t, cpaID, createdField.GroupID) - require.Equal(t, model.CustomProfileAttributesVisibilityHidden, createdField.Attrs.Visibility) - - fetchedField, gErr := th.App.GetPropertyField(rctx, "", createdField.ID) - require.Nil(t, gErr) - require.Equal(t, field.Name, fetchedField.Name) - require.NotZero(t, fetchedField.CreateAt) - require.Equal(t, fetchedField.CreateAt, fetchedField.UpdateAt) - }) - - t.Run("should create CPA field with DeleteAt set to 0 even if input has non-zero DeleteAt", func(t *testing.T) { - // Create a CPAField with DeleteAt != 0 - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityHidden}, - }) - require.NoError(t, err) - - // Set DeleteAt to non-zero value before creation - field.DeleteAt = time.Now().UnixMilli() - require.NotZero(t, field.DeleteAt, "Pre-condition: field should have non-zero DeleteAt") - - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - require.NotZero(t, createdField.ID) - require.Equal(t, cpaID, createdField.GroupID) - - // Verify that DeleteAt has been reset to 0 - require.Zero(t, createdField.DeleteAt, "DeleteAt should be 0 after creation") - - // Double-check by fetching the field from the database - fetchedField, gErr := th.App.GetPropertyField(rctx, "", createdField.ID) - require.Nil(t, gErr) - require.Zero(t, fetchedField.DeleteAt, "DeleteAt should be 0 in database") - }) - - t.Run("CPA should honor the field limit", func(t *testing.T) { - th := Setup(t).InitBasic(t) - - t.Run("should not be able to create CPA fields above the limit", func(t *testing.T) { - // we create the rest of the fields required to reach the limit - for i := 1; i <= CustomProfileAttributesFieldLimit; i++ { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: fmt.Sprintf("f_%d_%s", i, model.NewId()), - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) - - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - require.NotZero(t, createdField.ID) - } - - // then, we create a last one that would exceed the limit - field := &model.CPAField{ - PropertyField: model.PropertyField{ - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - }, - } - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.NotNil(t, appErr) - require.Equal(t, http.StatusUnprocessableEntity, appErr.StatusCode) - require.Zero(t, createdField) - }) - - t.Run("deleted fields should not count for the limit", func(t *testing.T) { - // we retrieve the list of fields and check we've reached the limit - fields, appErr := th.App.ListCPAFields(rctx) - require.Nil(t, appErr) - require.Len(t, fields, CustomProfileAttributesFieldLimit) - - // then we delete one field - require.Nil(t, th.App.DeleteCPAField(rctx, fields[0].ID)) - - // creating a new one should work now - field := &model.CPAField{ - PropertyField: model.PropertyField{ - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - }, - } - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - require.NotZero(t, createdField.ID) - }) - }) -} - -func TestPatchCPAField(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - newField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityHidden}, - }) - require.NoError(t, err) - - createdField, appErr := th.App.CreateCPAField(rctx, newField) - require.Nil(t, appErr) - - patch := &model.PropertyFieldPatch{ - Name: new("patched_name"), - Attrs: new(model.StringInterface{model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityWhenSet}), - TargetID: new(model.NewId()), - TargetType: new(model.NewId()), - } - - t.Run("should fail if the field doesn't exist", func(t *testing.T) { - updatedField, appErr := th.App.PatchCPAField(rctx, model.NewId(), patch) - require.NotNil(t, appErr) - require.Empty(t, updatedField) - }) - - t.Run("should not allow to patch a field outside of CPA", func(t *testing.T) { - otherGroup, gErr := th.App.RegisterPropertyGroup(rctx, &model.PropertyGroup{ - Name: "test_patch_cpa_other_group_" + model.NewId(), - Version: model.PropertyGroupVersionV1, - }) - require.Nil(t, gErr) - - newField := &model.PropertyField{ - GroupID: otherGroup.ID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, - } - - field, err := th.App.CreatePropertyField(rctx, newField, false, "") - require.Nil(t, err) - - updatedField, uErr := th.App.PatchCPAField(rctx, field.ID, patch) - require.NotNil(t, uErr) - require.Equal(t, "app.custom_profile_attributes.property_field_not_found.app_error", uErr.Id) - require.Empty(t, updatedField) - }) - - t.Run("should correctly patch the CPA property field", func(t *testing.T) { - time.Sleep(10 * time.Millisecond) // ensure the UpdateAt is different than CreateAt - - updatedField, appErr := th.App.PatchCPAField(rctx, createdField.ID, patch) - require.Nil(t, appErr) - require.Equal(t, createdField.ID, updatedField.ID) - require.Equal(t, "patched_name", updatedField.Name) - require.Equal(t, model.CustomProfileAttributesVisibilityWhenSet, updatedField.Attrs.Visibility) - require.Empty(t, updatedField.TargetID, "CPA should not allow to patch the field's target ID") - require.Empty(t, updatedField.TargetType, "CPA should not allow to patch the field's target type") - require.Greater(t, updatedField.UpdateAt, createdField.UpdateAt) - }) - - t.Run("should preserve option IDs when patching select field options", func(t *testing.T) { - // Create a select field with options - selectField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: "select_field", - Type: model.PropertyFieldTypeSelect, - Attrs: map[string]any{ - model.PropertyFieldAttributeOptions: []any{ - map[string]any{ - "name": "Option 1", - "color": "#111111", - }, - map[string]any{ - "name": "Option 2", - "color": "#222222", - }, - }, - }, - }) - require.NoError(t, err) - - createdSelectField, appErr := th.App.CreateCPAField(rctx, selectField) - require.Nil(t, appErr) - - // Get the original option IDs - options := createdSelectField.Attrs.Options - require.Len(t, options, 2) - originalID1 := options[0].ID - originalID2 := options[1].ID - require.NotEmpty(t, originalID1) - require.NotEmpty(t, originalID2) - - // Patch the field with updated option names and colors - selectPatch := &model.PropertyFieldPatch{ - Attrs: new(model.StringInterface{ - model.PropertyFieldAttributeOptions: []any{ - map[string]any{ - "id": originalID1, - "name": "Updated Option 1", - "color": "#333333", - }, - map[string]any{ - "name": "New Option 1.5", - "color": "#353535", - }, - map[string]any{ - "id": originalID2, - "name": "Updated Option 2", - "color": "#444444", - }, - }, - }), - } - - updatedSelectField, appErr := th.App.PatchCPAField(rctx, createdSelectField.ID, selectPatch) - require.Nil(t, appErr) - - updatedOptions := updatedSelectField.Attrs.Options - require.Len(t, updatedOptions, 3) - - // Verify the options were updated while preserving IDs - require.Equal(t, originalID1, updatedOptions[0].ID) - require.Equal(t, "Updated Option 1", updatedOptions[0].Name) - require.Equal(t, "#333333", updatedOptions[0].Color) - require.Equal(t, originalID2, updatedOptions[2].ID) - require.Equal(t, "Updated Option 2", updatedOptions[2].Name) - require.Equal(t, "#444444", updatedOptions[2].Color) - - // Check the new option - require.Equal(t, "New Option 1.5", updatedOptions[1].Name) - require.Equal(t, "#353535", updatedOptions[1].Color) - }) - - t.Run("Should not delete the values of a field after patching it if the type has not changed", func(t *testing.T) { - // Create a select field with options - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: "select_field_with_values", - Type: model.PropertyFieldTypeSelect, - Attrs: model.StringInterface{ - model.PropertyFieldAttributeOptions: []any{ - map[string]any{ - "name": "Option 1", - "color": "#FF5733", - }, - map[string]any{ - "name": "Option 2", - "color": "#33FF57", - }, - }, - }, - }) - require.NoError(t, err) - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - - // Get the option IDs - options := createdField.Attrs.Options - require.Len(t, options, 2) - optionID := options[0].ID - require.NotEmpty(t, optionID) - - // Create values for this field using the first option - userID := model.NewId() - value, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(fmt.Sprintf(`"%s"`, optionID)), false) - require.Nil(t, appErr) - require.NotNil(t, value) - - // Patch the field without changing type (just update name and add a new option) - patch := &model.PropertyFieldPatch{ - Name: new("updated_select_field_name"), - Attrs: new(model.StringInterface{ - model.PropertyFieldAttributeOptions: []any{ - map[string]any{ - "id": optionID, // Keep the same ID for the first option - "name": "Updated Option 1", - "color": "#FF5733", - }, - map[string]any{ - "name": "Option 2", - "color": "#33FF57", - }, - map[string]any{ - "name": "Option 3", - "color": "#5733FF", - }, - }, - }), - } - updatedField, appErr := th.App.PatchCPAField(rctx, createdField.ID, patch) - require.Nil(t, appErr) - require.Equal(t, "updated_select_field_name", updatedField.Name) - require.Equal(t, model.PropertyFieldTypeSelect, updatedField.Type) - - // Verify values still exist - values, appErr := th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Len(t, values, 1) - require.Equal(t, json.RawMessage(fmt.Sprintf(`"%s"`, optionID)), values[0].Value) - }) - - t.Run("Should delete the values of a field after patching it if the type has changed", func(t *testing.T) { - // Create a select field with options - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: "select_field_with_type_change", - Type: model.PropertyFieldTypeSelect, - Attrs: model.StringInterface{ - model.PropertyFieldAttributeOptions: []any{ - map[string]any{ - "name": "Option A", - "color": "#FF5733", - }, - map[string]any{ - "name": "Option B", - "color": "#33FF57", - }, - }, - }, - }) - require.NoError(t, err) - createdField, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr) - - // Get the option IDs - options := createdField.Attrs.Options - require.Len(t, options, 2) - optionID := options[0].ID - require.NotEmpty(t, optionID) - - // Create values for this field - userID := model.NewId() - value, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(fmt.Sprintf(`"%s"`, optionID)), false) - require.Nil(t, appErr) - require.NotNil(t, value) - - // Verify value exists before type change - values, appErr := th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Len(t, values, 1) - - // Patch the field and change type from select to text - patch := &model.PropertyFieldPatch{ - Type: model.NewPointer(model.PropertyFieldTypeText), - } - updatedField, appErr := th.App.PatchCPAField(rctx, createdField.ID, patch) - require.Nil(t, appErr) - require.Equal(t, model.PropertyFieldTypeText, updatedField.Type) - - // Verify values have been deleted - values, appErr = th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Empty(t, values) - }) -} - -func TestDeleteCPAField(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - newField, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: celSafeName(), - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) - - createdField, appErr := th.App.CreateCPAField(rctx, newField) - require.Nil(t, appErr) - - for i := range 3 { - newValue := &model.PropertyValue{ - TargetID: model.NewId(), - TargetType: model.PropertyValueTargetTypeUser, - GroupID: cpaID, - FieldID: createdField.ID, - Value: json.RawMessage(fmt.Sprintf(`"Value %d"`, i)), - } - value, err := th.App.CreatePropertyValue(rctx, newValue) - require.Nil(t, err) - require.NotZero(t, value.ID) - } - - t.Run("should fail if the field doesn't exist", func(t *testing.T) { - err := th.App.DeleteCPAField(rctx, model.NewId()) - require.NotNil(t, err) - require.Equal(t, "app.custom_profile_attributes.property_field_not_found.app_error", err.Id) - }) - - t.Run("should not allow to delete a field outside of CPA", func(t *testing.T) { - otherGroup, gErr := th.App.RegisterPropertyGroup(rctx, &model.PropertyGroup{ - Name: "test_delete_cpa_other_group_" + model.NewId(), - Version: model.PropertyGroupVersionV1, - }) - require.Nil(t, gErr) - - newField := &model.PropertyField{ - GroupID: otherGroup.ID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, - } - field, err := th.App.CreatePropertyField(rctx, newField, false, "") - require.Nil(t, err) - - dErr := th.App.DeleteCPAField(rctx, field.ID) - require.NotNil(t, dErr) - require.Equal(t, "app.custom_profile_attributes.property_field_not_found.app_error", dErr.Id) - }) - - t.Run("should correctly delete the field", func(t *testing.T) { - // check that we have the associated values to the field prior deletion - opts := model.PropertyValueSearchOpts{PerPage: 10, FieldID: createdField.ID} - values, err := th.App.SearchPropertyValues(rctx, cpaID, opts) - require.Nil(t, err) - require.Len(t, values, 3) - - // delete the field - require.Nil(t, th.App.DeleteCPAField(rctx, createdField.ID)) - - // check that it is marked as deleted - fetchedField, err := th.App.GetPropertyField(rctx, "", createdField.ID) - require.Nil(t, err) - require.NotZero(t, fetchedField.DeleteAt) - - // ensure that the associated fields have been marked as deleted too - values, err = th.App.SearchPropertyValues(rctx, cpaID, opts) - require.Nil(t, err) - require.Len(t, values, 0) - - opts.IncludeDeleted = true - values, err = th.App.SearchPropertyValues(rctx, cpaID, opts) - require.Nil(t, err) - require.Len(t, values, 3) - for _, value := range values { - require.NotZero(t, value.DeleteAt) - } - }) -} - func TestGetCPAValue(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID rctx := th.emptyContextWithCallerID(anonymousCallerId) field := &model.PropertyField{ - GroupID: cpaID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, + GroupID: cpaID, + Name: "f_" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } createdField, err := th.App.CreatePropertyField(rctx, field, false, "") require.Nil(t, err) fieldID := createdField.ID t.Run("should fail if the value doesn't exist", func(t *testing.T) { - pv, appErr := th.App.GetCPAValue(rctx, model.NewId()) + pv, appErr := th.App.GetPropertyValue(rctx, cpaID, model.NewId()) require.NotNil(t, appErr) require.Nil(t, pv) }) @@ -826,7 +53,7 @@ func TestGetCPAValue(t *testing.T) { require.Nil(t, err) require.NotNil(t, created) - pv, appErr := th.App.GetCPAValue(rctx, created.ID) + pv, appErr := th.App.GetPropertyValue(rctx, cpaID, created.ID) require.NotNil(t, appErr) require.Nil(t, pv) }) @@ -842,16 +69,26 @@ func TestGetCPAValue(t *testing.T) { propertyValue, err := th.App.CreatePropertyValue(rctx, propertyValue) require.Nil(t, err) - pv, appErr := th.App.GetCPAValue(rctx, propertyValue.ID) + pv, appErr := th.App.GetPropertyValue(rctx, cpaID, propertyValue.ID) require.Nil(t, appErr) require.NotNil(t, pv) }) t.Run("should handle array values correctly", func(t *testing.T) { + optionIDs := []string{model.NewId(), model.NewId(), model.NewId()} arrayField := &model.PropertyField{ - GroupID: cpaID, - Name: model.NewId(), - Type: model.PropertyFieldTypeMultiselect, + GroupID: cpaID, + Name: "f_" + model.NewId(), + Type: model.PropertyFieldTypeMultiselect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionIDs[0], "name": "option1"}, + map[string]any{"id": optionIDs[1], "name": "option2"}, + map[string]any{"id": optionIDs[2], "name": "option3"}, + }, + }, } createdField, err := th.App.CreatePropertyField(rctx, arrayField, false, "") require.Nil(t, err) @@ -861,202 +98,17 @@ func TestGetCPAValue(t *testing.T) { TargetType: model.PropertyValueTargetTypeUser, GroupID: cpaID, FieldID: createdField.ID, - Value: json.RawMessage(`["option1", "option2", "option3"]`), + Value: json.RawMessage(fmt.Sprintf(`["%s", "%s", "%s"]`, optionIDs[0], optionIDs[1], optionIDs[2])), } propertyValue, err = th.App.CreatePropertyValue(rctx, propertyValue) require.Nil(t, err) - pv, appErr := th.App.GetCPAValue(rctx, propertyValue.ID) + pv, appErr := th.App.GetPropertyValue(rctx, cpaID, propertyValue.ID) require.Nil(t, appErr) require.NotNil(t, pv) var arrayValues []string require.NoError(t, json.Unmarshal(pv.Value, &arrayValues)) - require.Equal(t, []string{"option1", "option2", "option3"}, arrayValues) - }) -} - -func TestListCPAValues(t *testing.T) { - mainHelper.Parallel(t) - th := SetupConfig(t, func(cfg *model.Config) { - cfg.FeatureFlags.CustomProfileAttributes = true - }).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - userID := model.NewId() - - t.Run("should return empty list when user has no values", func(t *testing.T) { - values, appErr := th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Empty(t, values) - }) - - t.Run("should list all values for a user", func(t *testing.T) { - var expectedValues []json.RawMessage - - for i := 1; i <= CustomProfileAttributesFieldLimit; i++ { - field := &model.PropertyField{ - GroupID: cpaID, - Name: fmt.Sprintf("Field %d", i), - Type: model.PropertyFieldTypeText, - } - _, err := th.App.CreatePropertyField(rctx, field, false, "") - require.Nil(t, err) - - value := &model.PropertyValue{ - TargetID: userID, - TargetType: model.PropertyValueTargetTypeUser, - GroupID: cpaID, - FieldID: field.ID, - Value: json.RawMessage(fmt.Sprintf(`"Value %d"`, i)), - } - _, err = th.App.CreatePropertyValue(rctx, value) - require.Nil(t, err) - expectedValues = append(expectedValues, value.Value) - } - - // List values for original user - values, appErr := th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Len(t, values, CustomProfileAttributesFieldLimit) - - actualValues := make([]json.RawMessage, len(values)) - for i, value := range values { - require.Equal(t, userID, value.TargetID) - require.Equal(t, "user", value.TargetType) - require.Equal(t, cpaID, value.GroupID) - actualValues[i] = value.Value - } - require.ElementsMatch(t, expectedValues, actualValues) - }) -} - -func TestPatchCPAValue(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - t.Run("should fail if the field doesn't exist", func(t *testing.T) { - invalidFieldID := model.NewId() - _, appErr := th.App.PatchCPAValue(rctx, model.NewId(), invalidFieldID, json.RawMessage(`"fieldValue"`), true) - require.NotNil(t, appErr) - }) - - t.Run("should create value if new field value", func(t *testing.T) { - newField := &model.PropertyField{ - GroupID: cpaID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, - } - createdField, err := th.App.CreatePropertyField(rctx, newField, false, "") - require.Nil(t, err) - - userID := model.NewId() - patchedValue, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(`"test value"`), true) - require.Nil(t, appErr) - require.NotNil(t, patchedValue) - require.Equal(t, json.RawMessage(`"test value"`), patchedValue.Value) - require.Equal(t, userID, patchedValue.TargetID) - - t.Run("should correctly patch the CPA property value", func(t *testing.T) { - patch2, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(`"new patched value"`), true) - require.Nil(t, appErr) - require.NotNil(t, patch2) - require.Equal(t, patchedValue.ID, patch2.ID) - require.Equal(t, json.RawMessage(`"new patched value"`), patch2.Value) - require.Equal(t, userID, patch2.TargetID) - }) - }) - - t.Run("should fail if field is deleted", func(t *testing.T) { - newField := &model.PropertyField{ - GroupID: cpaID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, - } - createdField, err := th.App.CreatePropertyField(rctx, newField, false, "") - require.Nil(t, err) - err = th.App.DeletePropertyField(rctx, cpaID, createdField.ID, false, "") - require.Nil(t, err) - - userID := model.NewId() - patchedValue, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(`"test value"`), true) - require.NotNil(t, appErr) - require.Nil(t, patchedValue) - }) - - t.Run("should handle array values correctly", func(t *testing.T) { - optionsID := []string{model.NewId(), model.NewId(), model.NewId(), model.NewId()} - arrayField := &model.PropertyField{ - GroupID: cpaID, - Name: model.NewId(), - Type: model.PropertyFieldTypeMultiselect, - Attrs: model.StringInterface{ - "options": []map[string]any{ - {"id": optionsID[0], "name": "option1"}, - {"id": optionsID[1], "name": "option2"}, - {"id": optionsID[2], "name": "option3"}, - {"id": optionsID[3], "name": "option4"}, - }, - }, - } - createdField, err := th.App.CreatePropertyField(rctx, arrayField, false, "") - require.Nil(t, err) - - // Create a JSON array with option IDs (not names) - optionJSON := fmt.Sprintf(`["%s", "%s", "%s"]`, optionsID[0], optionsID[1], optionsID[2]) - - userID := model.NewId() - patchedValue, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(optionJSON), true) - require.Nil(t, appErr) - require.NotNil(t, patchedValue) - var arrayValues []string - require.NoError(t, json.Unmarshal(patchedValue.Value, &arrayValues)) - require.Equal(t, []string{optionsID[0], optionsID[1], optionsID[2]}, arrayValues) - require.Equal(t, userID, patchedValue.TargetID) - - // Update array values with valid option IDs - updatedOptionJSON := fmt.Sprintf(`["%s", "%s"]`, optionsID[1], optionsID[3]) - updatedValue, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(updatedOptionJSON), true) - require.Nil(t, appErr) - require.NotNil(t, updatedValue) - require.Equal(t, patchedValue.ID, updatedValue.ID) - arrayValues = nil - require.NoError(t, json.Unmarshal(updatedValue.Value, &arrayValues)) - require.Equal(t, []string{optionsID[1], optionsID[3]}, arrayValues) - require.Equal(t, userID, updatedValue.TargetID) - - t.Run("should fail if it tries to set a value that not valid for a field", func(t *testing.T) { - // Try to use an ID that doesn't exist in the options - invalidID := model.NewId() - invalidOptionJSON := fmt.Sprintf(`["%s", "%s"]`, optionsID[0], invalidID) - - invalidValue, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(invalidOptionJSON), true) - require.NotNil(t, appErr) - require.Nil(t, invalidValue) - require.Equal(t, "app.custom_profile_attributes.validate_value.app_error", appErr.Id) - - // Test with completely invalid JSON format - invalidJSON := `[not valid json]` - invalidValue, appErr = th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(invalidJSON), true) - require.NotNil(t, appErr) - require.Nil(t, invalidValue) - require.Equal(t, "app.custom_profile_attributes.validate_value.app_error", appErr.Id) - - // Test with wrong data type (sending string instead of array) - wrongTypeJSON := `"not an array"` - invalidValue, appErr = th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(wrongTypeJSON), true) - require.NotNil(t, appErr) - require.Nil(t, invalidValue) - require.Equal(t, "app.custom_profile_attributes.validate_value.app_error", appErr.Id) - }) + require.Equal(t, optionIDs, arrayValues) }) } @@ -1065,232 +117,176 @@ func TestDeleteCPAValues(t *testing.T) { th := SetupConfig(t, func(cfg *model.Config) { cfg.FeatureFlags.CustomProfileAttributes = true }).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID rctx := th.emptyContextWithCallerID(anonymousCallerId) userID := model.NewId() otherUserID := model.NewId() - // Create multiple fields and values for the user - var createdFields []*model.CPAField - for i := 1; i <= 3; i++ { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - GroupID: cpaID, - Name: fmt.Sprintf("field_%d", i), - Type: model.PropertyFieldTypeText, + listValues := func(targetID string) []*model.PropertyValue { + t.Helper() + values, appErr := th.App.SearchPropertyValues(rctx, cpaID, model.PropertyValueSearchOpts{ + TargetIDs: []string{targetID}, + TargetType: model.PropertyValueTargetTypeUser, + // Single-target search: at most one value per (target, field), so the field cap bounds the page. + PerPage: model.AccessControlGroupFieldLimit + 5, }) - require.NoError(t, err) - createdField, appErr := th.App.CreateCPAField(rctx, field) require.Nil(t, appErr) - createdFields = append(createdFields, createdField) - - // Create a value for this field - value, appErr := th.App.PatchCPAValue(rctx, userID, createdField.ID, json.RawMessage(fmt.Sprintf(`"Value %d"`, i)), false) - require.Nil(t, appErr) - require.NotNil(t, value) + return values } - // Verify values exist before deletion - values, appErr := th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Len(t, values, 3) + // Create multiple fields and a value per field for userID. + var createdFields []*model.PropertyField + for i := 1; i <= 3; i++ { + field := &model.PropertyField{ + GroupID: cpaID, + Name: fmt.Sprintf("field_%d", i), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + createdField, err := th.App.CreatePropertyField(rctx, field, false, "") + require.Nil(t, err) + createdFields = append(createdFields, createdField) + + value := &model.PropertyValue{ + TargetID: userID, + TargetType: model.PropertyValueTargetTypeUser, + GroupID: cpaID, + FieldID: createdField.ID, + Value: json.RawMessage(fmt.Sprintf(`"Value %d"`, i)), + } + _, err = th.App.CreatePropertyValue(rctx, value) + require.Nil(t, err) + } + + require.Len(t, listValues(userID), 3) - // Test deleting values for user t.Run("should delete all values for a user", func(t *testing.T) { - appErr := th.App.DeleteCPAValues(rctx, userID) + appErr := th.App.DeletePropertyValuesForTarget(rctx, cpaID, model.PropertyFieldObjectTypeUser, userID) require.Nil(t, appErr) - // Verify values are gone - values, appErr := th.App.ListCPAValues(rctx, userID) - require.Nil(t, appErr) - require.Empty(t, values) + require.Empty(t, listValues(userID)) }) t.Run("should handle deleting values for a user with no values", func(t *testing.T) { - appErr := th.App.DeleteCPAValues(rctx, otherUserID) + appErr := th.App.DeletePropertyValuesForTarget(rctx, cpaID, model.PropertyFieldObjectTypeUser, otherUserID) require.Nil(t, appErr) }) t.Run("should not affect values for other users", func(t *testing.T) { - // Create values for another user + // Create values for otherUserID. for _, field := range createdFields { - value, appErr := th.App.PatchCPAValue(rctx, otherUserID, field.ID, json.RawMessage(`"Other user value"`), false) - require.Nil(t, appErr) - require.NotNil(t, value) + value := &model.PropertyValue{ + TargetID: otherUserID, + TargetType: model.PropertyValueTargetTypeUser, + GroupID: cpaID, + FieldID: field.ID, + Value: json.RawMessage(`"Other user value"`), + } + _, err := th.App.CreatePropertyValue(rctx, value) + require.Nil(t, err) } - // Delete values for original user - appErr := th.App.DeleteCPAValues(rctx, userID) + appErr := th.App.DeletePropertyValuesForTarget(rctx, cpaID, model.PropertyFieldObjectTypeUser, userID) require.Nil(t, appErr) - // Verify other user's values still exist - values, appErr := th.App.ListCPAValues(rctx, otherUserID) - require.Nil(t, appErr) - require.Len(t, values, 3) + require.Len(t, listValues(otherUserID), 3) }) } -func TestCreateCPAField_RejectsInvalidName(t *testing.T) { +// TestCPAValueSyncLock exercises AccessControlHook.checkSyncLock end-to-end +// at the app layer: a write for a field with ldap= or saml= set only +// succeeds when the caller ID matches the field's sync source. Covering this +// at the app layer also asserts that the startup wiring in server.go +// (access_control group registration, AccessControlHook install, and +// CallerIDExtractor reading from request.CTX) is intact — something the +// properties-package tests cannot verify because they install the hook +// themselves. +func TestCPAValueSyncLock(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) - rctx := th.emptyContextWithCallerID(anonymousCallerId) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID - tests := []struct { - name string - fieldName string - wantErrID string - }{ - { - name: "space in name", - fieldName: "My Field", - wantErrID: "model.cpa_field.name.invalid_charset.app_error", - }, - { - name: "leading digit", - fieldName: "7department", - wantErrID: "model.cpa_field.name.invalid_charset.app_error", - }, - { - name: "reserved word in", - fieldName: "in", - wantErrID: "model.cpa_field.name.reserved_word.app_error", - }, - { - name: "reserved word true", - fieldName: "true", - wantErrID: "model.cpa_field.name.reserved_word.app_error", - }, - } + adminRctx := th.emptyContextWithCallerID(th.SystemAdminUser.Id) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: tt.fieldName, - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) - - _, appErr := th.App.CreateCPAField(rctx, field) - require.NotNil(t, appErr, "expected error for name %q", tt.fieldName) - require.Equal(t, tt.wantErrID, appErr.Id) - }) - } -} - -func TestCreateCPAField_AcceptsValidName(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - validNames := []string{"department", "_private", "A1", "a_b_c", "Department", "DEPT"} - for _, n := range validNames { - t.Run(n, func(t *testing.T) { - field, err := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: n, - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, err) - - created, appErr := th.App.CreateCPAField(rctx, field) - require.Nil(t, appErr, "unexpected error for name %q: %v", n, appErr) - require.NotEmpty(t, created.ID) - - _ = th.App.DeleteCPAField(rctx, created.ID) - }) - } -} - -func TestPatchCPAField_GrandfatherSkipsValidationOnUnchangedName(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - // Seed a field with an invalid CPA name directly via CreatePropertyField (bypassing CPA validation). - // This simulates a pre-existing legacy field whose name violates the new CEL rule. - legacyField, err := th.App.CreatePropertyField(rctx, &model.PropertyField{ - GroupID: cpaID, - Name: "My Legacy Field", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsVisibility: model.CustomProfileAttributesVisibilityWhenSet}, - }, false, "") - require.Nil(t, err) - defer func() { _ = th.App.DeleteCPAField(rctx, legacyField.ID) }() - - t.Run("patching only visibility leaves invalid name unchanged (grandfather passes)", func(t *testing.T) { - newVisibility := model.CustomProfileAttributesVisibilityAlways - patch := &model.PropertyFieldPatch{ - Attrs: &model.StringInterface{ - model.CustomProfileAttributesPropertyAttrsVisibility: newVisibility, + createField := func(name string, attrs model.CPAAttrs) *model.PropertyField { + t.Helper() + cpa := &model.CPAField{ + PropertyField: model.PropertyField{ + GroupID: cpaID, + Name: name, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), }, + Attrs: attrs, } - patched, appErr := th.App.PatchCPAField(rctx, legacyField.ID, patch) - require.Nil(t, appErr, "grandfather: patching non-name attrs on a legacy field must not trigger validation") - require.Equal(t, "My Legacy Field", patched.Name, "name must remain unchanged") - require.Equal(t, newVisibility, patched.Attrs.Visibility) - }) - - t.Run("patching name to another invalid value returns validation error", func(t *testing.T) { - stillInvalidName := "still invalid name" - patch := &model.PropertyFieldPatch{ - Name: new(stillInvalidName), - } - _, appErr := th.App.PatchCPAField(rctx, legacyField.ID, patch) - require.NotNil(t, appErr, "renaming to an invalid name must be rejected") - require.Equal(t, "model.cpa_field.name.invalid_charset.app_error", appErr.Id) - }) - - t.Run("patching name to a valid value succeeds", func(t *testing.T) { - validName := "my_legacy_field" - patch := &model.PropertyFieldPatch{ - Name: new(validName), - } - patched, appErr := th.App.PatchCPAField(rctx, legacyField.ID, patch) - require.Nil(t, appErr, "renaming to a valid CEL identifier must succeed") - require.Equal(t, validName, patched.Name) - }) -} - -// TestCreatePropertyField_BypassesCPANameValidation_ExpectedBehavior asserts the documented -// Option C bypass: the generic property-field App API does NOT enforce the CPA name regex -// on master. This is intentional and time-bounded. -// -// PR #36173's AttributeValidationHook will close the bypass at the property-service layer. -// Do NOT "fix" this test by adding CPA name validation in App.CreatePropertyField ahead of -// #36173 landing — doing so would conflict with @davidkrauser's diff. -// -// See spec.md §Out of Scope and the CPAAttrs godoc block in -// server/public/model/custom_profile_attributes.go (§Non-enforcement) for full context. -func TestCreatePropertyField_BypassesCPANameValidation_ExpectedBehavior(t *testing.T) { - mainHelper.Parallel(t) - th := Setup(t).InitBasic(t) - - cpaID, cErr := th.App.CpaGroupID() - require.Nil(t, cErr) - - rctx := th.emptyContextWithCallerID(anonymousCallerId) - - // "My Field" violates CPAFieldNamePattern — would be rejected by CreateCPAField. - // Via CreatePropertyField (the generic property API), it must succeed. - field := &model.PropertyField{ - GroupID: cpaID, - Name: "My Field", - Type: model.PropertyFieldTypeText, + // Sanitization/validation runs inside CreatePropertyField via the + // AccessControlAttributeValidationHook. + created, appErr := th.App.CreatePropertyField(adminRctx, cpa.ToPropertyField(), false, "") + require.Nil(t, appErr) + return created } - created, appErr := th.App.CreatePropertyField(rctx, field, false, "") - require.Nil(t, appErr, - "CreatePropertyField must NOT enforce the CPA name regex on master — "+ - "that enforcement belongs to PR #36173's AttributeValidationHook") - require.NotEmpty(t, created.ID) + ldapField := createField("ldap_attr_"+model.NewId(), model.CPAAttrs{LDAP: "mail"}) + samlField := createField("saml_attr_"+model.NewId(), model.CPAAttrs{SAML: "email"}) + plainField := createField("plain_attr_"+model.NewId(), model.CPAAttrs{}) - _ = th.App.DeleteCPAField(rctx, created.ID) + userID := model.NewId() + upsertAs := func(callerID string, field *model.PropertyField) *model.AppError { + t.Helper() + rctx := th.emptyContextWithCallerID(callerID) + _, appErr := th.App.UpsertPropertyValues(rctx, []*model.PropertyValue{{ + GroupID: cpaID, + TargetType: model.PropertyValueTargetTypeUser, + TargetID: userID, + FieldID: field.ID, + Value: json.RawMessage(`"value"`), + }}, model.PropertyFieldObjectTypeUser, userID, "") + return appErr + } + + requireSyncLock := func(appErr *model.AppError) { + t.Helper() + require.NotNil(t, appErr) + require.Equal(t, "app.property.sync_lock.app_error", appErr.Id) + } + + t.Run("anonymous caller is blocked on an LDAP-synced field", func(t *testing.T) { + requireSyncLock(upsertAs(anonymousCallerId, ldapField)) + }) + + t.Run("anonymous caller is blocked on a SAML-synced field", func(t *testing.T) { + requireSyncLock(upsertAs(anonymousCallerId, samlField)) + }) + + t.Run("anonymous caller is allowed on a non-synced field", func(t *testing.T) { + require.Nil(t, upsertAs(anonymousCallerId, plainField)) + }) + + t.Run("LDAP sync caller is allowed on an LDAP-synced field", func(t *testing.T) { + require.Nil(t, upsertAs(model.CallerIDLDAPSync, ldapField)) + }) + + t.Run("LDAP sync caller is blocked on a SAML-synced field", func(t *testing.T) { + requireSyncLock(upsertAs(model.CallerIDLDAPSync, samlField)) + }) + + t.Run("SAML sync caller is allowed on a SAML-synced field", func(t *testing.T) { + require.Nil(t, upsertAs(model.CallerIDSAMLSync, samlField)) + }) + + t.Run("SAML sync caller is blocked on an LDAP-synced field", func(t *testing.T) { + requireSyncLock(upsertAs(model.CallerIDSAMLSync, ldapField)) + }) } diff --git a/server/channels/app/migrations.go b/server/channels/app/migrations.go index 524d54ba784..2033abf976c 100644 --- a/server/channels/app/migrations.go +++ b/server/channels/app/migrations.go @@ -753,7 +753,7 @@ func (s *Server) doSetupContentFlaggingProperties() error { } if len(propertiesToUpdate) > 0 { - if _, _, err := s.propertyService.UpdatePropertyFields(nil, group.ID, propertiesToUpdate); err != nil { + if _, _, _, err := s.propertyService.UpdatePropertyFields(nil, group.ID, propertiesToUpdate); err != nil { // Another server may have won the race and updated these fields // concurrently (e.g. parallel tests sharing a database pool). // Both servers write the same expected values, so tolerate the @@ -854,7 +854,7 @@ func (s *Server) doSetupBoardsProperties() error { } if len(propertiesToUpdate) > 0 { - if _, _, err := s.propertyService.UpdatePropertyFields(nil, group.ID, propertiesToUpdate); err != nil { + if _, _, _, err := s.propertyService.UpdatePropertyFields(nil, group.ID, propertiesToUpdate); err != nil { // Another server may have won the race and updated these fields // concurrently (e.g. parallel tests sharing a database pool). // Both servers write the same expected values, so tolerate the diff --git a/server/channels/app/migrations_test.go b/server/channels/app/migrations_test.go index 91b4e823adf..a80a455a3a2 100644 --- a/server/channels/app/migrations_test.go +++ b/server/channels/app/migrations_test.go @@ -190,41 +190,54 @@ func TestCPADisplayNameBackfill_NoExistingFields(t *testing.T) { func TestCPADisplayNameBackfill_BackfillsMissing(t *testing.T) { th := Setup(t) + // LicenseCheckHook gates writes to the access_control group on an + // Enterprise license; the seed CreatePropertyField calls below would + // otherwise be rejected with app.property.license_error. + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) clearCPABackfillMarker(t, th) - // fieldA exercises the "display_name present as empty string in JSONB" case — the true - // idempotency boundary. - fieldABase, convErr := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: "department", - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, convErr) - fieldA, appErr := th.App.CreateCPAField(th.Context, fieldABase) + group, appErr := th.App.GetPropertyGroup(th.Context, model.AccessControlPropertyGroupName) require.Nil(t, appErr) - require.Equal(t, "", fieldA.Attrs.DisplayName, "seed invariant: fieldA must have empty display_name") - fieldBBase, convErr := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: "job_title", - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, convErr) - fieldBBase.Attrs.DisplayName = "Job Title" - fieldB, appErr := th.App.CreateCPAField(th.Context, fieldBBase) + // fieldA exercises the "display_name absent / empty in JSONB" case — the + // true idempotency boundary the migration is designed to fix. + fieldA, appErr := th.App.CreatePropertyField(th.Context, &model.PropertyField{ + GroupID: group.ID, + Name: "department", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + }, false, "") require.Nil(t, appErr) - require.Equal(t, "Job Title", fieldB.Attrs.DisplayName, "seed invariant: fieldB must have display_name set") + require.Empty(t, fieldA.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName], + "seed invariant: fieldA must have empty display_name") + + fieldB, appErr := th.App.CreatePropertyField(th.Context, &model.PropertyField{ + GroupID: group.ID, + Name: "job_title", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsDisplayName: "Job Title", + }, + }, false, "") + require.Nil(t, appErr) + require.Equal(t, "Job Title", fieldB.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName], + "seed invariant: fieldB must have display_name set") err := th.Server.doSetupCPADisplayNameBackfill(th.Context) require.NoError(t, err) - updatedFieldA, appErr := th.App.GetCPAField(th.Context, fieldA.ID) + updatedFieldA, appErr := th.App.GetPropertyField(th.Context, group.ID, fieldA.ID) require.Nil(t, appErr) - require.Equal(t, "department", updatedFieldA.Attrs.DisplayName, + require.Equal(t, "department", updatedFieldA.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName], "fieldA: display_name must be backfilled to field name") - updatedFieldB, appErr := th.App.GetCPAField(th.Context, fieldB.ID) + updatedFieldB, appErr := th.App.GetPropertyField(th.Context, group.ID, fieldB.ID) require.Nil(t, appErr) - require.Equal(t, "Job Title", updatedFieldB.Attrs.DisplayName, + require.Equal(t, "Job Title", updatedFieldB.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName], "fieldB: display_name must not be overwritten when already set") data, sysErr := th.Store.System().GetByName(cpaDisplayNameBackfillKey) @@ -235,15 +248,23 @@ func TestCPADisplayNameBackfill_BackfillsMissing(t *testing.T) { func TestCPADisplayNameBackfill_Idempotent(t *testing.T) { th := Setup(t) + // LicenseCheckHook gates writes to the access_control group on an + // Enterprise license; the seed CreatePropertyField call below would + // otherwise be rejected with app.property.license_error. + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) clearCPABackfillMarker(t, th) - fieldBase, convErr := model.NewCPAFieldFromPropertyField(&model.PropertyField{ - Name: "location", - Type: model.PropertyFieldTypeText, - }) - require.NoError(t, convErr) - seeded, appErr := th.App.CreateCPAField(th.Context, fieldBase) + group, appErr := th.App.GetPropertyGroup(th.Context, model.AccessControlPropertyGroupName) + require.Nil(t, appErr) + + seeded, appErr := th.App.CreatePropertyField(th.Context, &model.PropertyField{ + GroupID: group.ID, + Name: "location", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + }, false, "") require.Nil(t, appErr) err := th.Server.doSetupCPADisplayNameBackfill(th.Context) @@ -253,9 +274,9 @@ func TestCPADisplayNameBackfill_Idempotent(t *testing.T) { require.NoError(t, sysErr) require.Equal(t, "true", data1.Value) - updatedAfterFirst, appErr := th.App.GetCPAField(th.Context, seeded.ID) + updatedAfterFirst, appErr := th.App.GetPropertyField(th.Context, group.ID, seeded.ID) require.Nil(t, appErr) - require.Equal(t, "location", updatedAfterFirst.Attrs.DisplayName) + require.Equal(t, "location", updatedAfterFirst.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName]) // Snapshot UpdateAt before the second run so we can prove the second run is a no-op // at the DB-write level. PropertyField.UpdateAt is set to model.GetMillis() on every @@ -272,9 +293,9 @@ func TestCPADisplayNameBackfill_Idempotent(t *testing.T) { require.NoError(t, sysErr) require.Equal(t, "true", data2.Value) - updatedAfterSecond, appErr := th.App.GetCPAField(th.Context, seeded.ID) + updatedAfterSecond, appErr := th.App.GetPropertyField(th.Context, group.ID, seeded.ID) require.Nil(t, appErr) - require.Equal(t, "location", updatedAfterSecond.Attrs.DisplayName, + require.Equal(t, "location", updatedAfterSecond.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName], "second run must not change display_name") require.Equal(t, firstFieldUpdate, updatedAfterSecond.UpdateAt, @@ -283,21 +304,30 @@ func TestCPADisplayNameBackfill_Idempotent(t *testing.T) { func TestCPADisplayNameBackfill_BackfillsProtectedSourceOnlyField(t *testing.T) { th := Setup(t) + // LicenseCheckHook gates writes to the access_control group on an + // Enterprise license. The seed below bypasses Create-side hooks via a + // direct store insert, but the backfill migration calls UpdatePropertyFields + // (unhooked) which still runs the version-match check; the license is + // nevertheless required by other CPA paths exercised across the suite. + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) clearCPABackfillMarker(t, th) - groupID, appErr := th.App.CpaGroupID() + group, appErr := th.App.GetPropertyGroup(th.Context, model.AccessControlPropertyGroupName) require.Nil(t, appErr) + groupID := group.ID // Insert directly via the store so we bypass the property service's // access-control routing (which would reject creating a protected - // source_only field from a non-plugin caller). Type=text avoids the - // options-stripping branch in read access control, but the migration's - // correctness here doesn't depend on the field type. + // source_only field from a non-plugin caller). ObjectType/TargetType are + // required so the field is recognized as PSAv2 and matches the group's + // version when the migration's UpdatePropertyFields runs. field := &model.PropertyField{ - GroupID: groupID, - Name: "uas_employee_id", - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: "uas_employee_id", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, diff --git a/server/channels/app/plugin_api.go b/server/channels/app/plugin_api.go index 3ff20d667c4..fa6dad2bec4 100644 --- a/server/channels/app/plugin_api.go +++ b/server/channels/app/plugin_api.go @@ -1596,7 +1596,7 @@ func (api *PluginAPI) GetPropertyFields(groupID string, ids []string) ([]*model. } func (api *PluginAPI) UpdatePropertyField(groupID string, field *model.PropertyField) (*model.PropertyField, error) { - updatedField, appErr := api.app.UpdatePropertyField(api.psaPluginContext(), groupID, field, false, "") + updatedField, _, appErr := api.app.UpdatePropertyField(api.psaPluginContext(), groupID, field, false, "") if appErr != nil { return nil, appErr } @@ -1690,6 +1690,14 @@ func (api *PluginAPI) SearchPropertyValues(groupID string, opts model.PropertyVa } func (api *PluginAPI) RegisterPropertyGroup(name string) (*model.PropertyGroup, error) { + if name == model.DeprecatedCPAPropertyGroupName { + return nil, fmt.Errorf( + "the group name %q has been renamed to %q; use %q instead", + model.DeprecatedCPAPropertyGroupName, + model.AccessControlPropertyGroupName, + model.AccessControlPropertyGroupName, + ) + } group, appErr := api.app.RegisterPropertyGroup(api.psaPluginContext(), &model.PropertyGroup{ Name: name, Version: model.PropertyGroupVersionV1, @@ -1701,6 +1709,7 @@ func (api *PluginAPI) RegisterPropertyGroup(name string) (*model.PropertyGroup, } func (api *PluginAPI) GetPropertyGroup(name string) (*model.PropertyGroup, error) { + name = migrateDeprecatedPropertyGroupName(name) group, appErr := api.app.GetPropertyGroup(api.psaPluginContext(), name) if appErr != nil { return nil, appErr @@ -1708,6 +1717,15 @@ func (api *PluginAPI) GetPropertyGroup(name string) (*model.PropertyGroup, error return group, nil } +// migrateDeprecatedPropertyGroupName maps the deprecated "custom_profile_attributes" +// group name to the current "access_control" name for backward compatibility. +func migrateDeprecatedPropertyGroupName(name string) string { + if name == model.DeprecatedCPAPropertyGroupName { + return model.AccessControlPropertyGroupName + } + return name +} + func (api *PluginAPI) GetPropertyFieldByName(groupID, targetID, name string) (*model.PropertyField, error) { field, appErr := api.app.GetPropertyFieldByName(api.psaPluginContext(), groupID, targetID, name) if appErr != nil { @@ -1717,7 +1735,7 @@ func (api *PluginAPI) GetPropertyFieldByName(groupID, targetID, name string) (*m } func (api *PluginAPI) UpdatePropertyFields(groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) { - updatedFields, appErr := api.app.UpdatePropertyFields(api.psaPluginContext(), groupID, fields, false, "") + updatedFields, _, appErr := api.app.UpdatePropertyFields(api.psaPluginContext(), groupID, fields, false, "") if appErr != nil { return nil, appErr } diff --git a/server/channels/app/plugin_api_test.go b/server/channels/app/plugin_api_test.go index 4421513bb31..b8a66f15f5a 100644 --- a/server/channels/app/plugin_api_test.go +++ b/server/channels/app/plugin_api_test.go @@ -3864,3 +3864,70 @@ func TestPluginAPICreateChannelAnonymousURLs(t *testing.T) { assert.Equal(t, originalName, createdChannel.Name, "channel name should not be overridden") }) } + +func TestPluginAPIPropertyGroupDeprecatedName(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("RegisterPropertyGroup rejects deprecated name", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + api := th.SetupPluginAPI() + + // Register using the deprecated name must fail + _, err := api.RegisterPropertyGroup(model.DeprecatedCPAPropertyGroupName) + require.Error(t, err) + assert.Contains(t, err.Error(), "renamed") + + // Register using the canonical name should still work + group, err := api.RegisterPropertyGroup(model.AccessControlPropertyGroupName) + require.NoError(t, err) + require.NotNil(t, group) + assert.Equal(t, model.AccessControlPropertyGroupName, group.Name) + }) + + t.Run("GetPropertyGroup maps deprecated name to canonical name", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + api := th.SetupPluginAPI() + + // The access_control group is registered at server startup, so + // we can look it up directly. + canonical, err := api.GetPropertyGroup(model.AccessControlPropertyGroupName) + require.NoError(t, err) + require.NotNil(t, canonical) + + // Looking up by the deprecated name should return the same group + deprecated, err := api.GetPropertyGroup(model.DeprecatedCPAPropertyGroupName) + require.NoError(t, err) + require.NotNil(t, deprecated) + + assert.Equal(t, canonical.ID, deprecated.ID) + assert.Equal(t, model.AccessControlPropertyGroupName, deprecated.Name) + }) + + t.Run("other group names are not affected by the mapping", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + api := th.SetupPluginAPI() + + // Register a different group — no mapping should occur + group, err := api.RegisterPropertyGroup("my_plugin_group") + require.NoError(t, err) + require.NotNil(t, group) + assert.Equal(t, "my_plugin_group", group.Name) + + // Look it up + fetched, err := api.GetPropertyGroup("my_plugin_group") + require.NoError(t, err) + assert.Equal(t, group.ID, fetched.ID) + }) + + t.Run("GetPropertyGroup with nonexistent name returns error", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + api := th.SetupPluginAPI() + + _, err := api.GetPropertyGroup("no_such_group") + require.Error(t, err) + }) +} diff --git a/server/channels/app/plugin_properties_test.go b/server/channels/app/plugin_properties_test.go index d945e802823..67e742be59f 100644 --- a/server/channels/app/plugin_properties_test.go +++ b/server/channels/app/plugin_properties_test.go @@ -9,14 +9,16 @@ import ( "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" ) // cleanupCPAFields deletes all existing CPA fields to ensure a clean state func cleanupCPAFields(t *testing.T, th *TestHelper) { t.Helper() - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID fields, searchErr := th.App.Srv().Store().PropertyField().SearchPropertyFields(model.PropertyFieldSearchOpts{ GroupID: cpaID, @@ -33,6 +35,11 @@ func cleanupCPAFields(t *testing.T, th *TestHelper) { func TestPluginProperties(t *testing.T) { th := Setup(t).InitBasic(t) + // Subtests that exercise the access_control group require an + // Enterprise license because LicenseCheckHook gates that group. + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + t.Cleanup(func() { _ = th.App.Srv().RemoveLicense() }) + t.Run("test property field methods", func(t *testing.T) { groupName := model.NewId() tearDown, pluginIDs, activationErrors := SetAppEnvironmentWithPlugins(t, []string{` @@ -457,8 +464,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin-created CPA field gets source_plugin_id", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID tearDown, pluginIDs, activationErrors := SetAppEnvironmentWithPlugins(t, []string{` package main @@ -476,9 +484,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { // Create a CPA field field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "CPA Test Field", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "cpa_test_field", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } createdField, err := p.API.CreatePropertyField(field) @@ -521,8 +531,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin can update its own protected field", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID tearDown, pluginIDs, activationErrors := SetAppEnvironmentWithPlugins(t, []string{` package main @@ -540,9 +551,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { // Create a protected CPA field field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Protected Field", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "protected_field", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: map[string]any{ "protected": true, }, @@ -554,13 +567,13 @@ func TestPluginProperties(t *testing.T) { } // Try to update the protected field (should succeed since we created it) - createdField.Name = "Updated Protected Field" + createdField.Name = "updated_protected_field" updatedField, err := p.API.UpdatePropertyField("` + cpaID + `", createdField) if err != nil { return fmt.Errorf("failed to update own protected field: %w", err) } - if updatedField.Name != "Updated Protected Field" { + if updatedField.Name != "updated_protected_field" { return fmt.Errorf("field name not updated correctly") } @@ -585,8 +598,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin cannot update another plugin's protected field", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID // Both plugins in same environment tearDown, _, activationErrors := SetAppEnvironmentWithPlugins(t, []string{ @@ -607,9 +621,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { // Create a protected CPA field field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Plugin1 Protected Field", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "plugin1_protected_field", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: map[string]any{ "protected": true, }, @@ -650,7 +666,7 @@ func TestPluginProperties(t *testing.T) { var plugin1Field *model.PropertyField for _, field := range fields { - if field.Name == "Plugin1 Protected Field" { + if field.Name == "plugin1_protected_field" { plugin1Field = field break } @@ -685,8 +701,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin can delete its own protected field", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID tearDown, pluginIDs, activationErrors := SetAppEnvironmentWithPlugins(t, []string{` package main @@ -704,9 +721,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { // Create a protected CPA field field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Field To Delete", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "field_to_delete", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: map[string]any{ "protected": true, }, @@ -744,8 +763,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin cannot delete another plugin's protected field", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID // Both plugins in same environment tearDown, _, activationErrors := SetAppEnvironmentWithPlugins(t, []string{ @@ -765,9 +785,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Plugin1 Field To Keep", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "plugin1_field_to_keep", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: map[string]any{ "protected": true, }, @@ -808,7 +830,7 @@ func TestPluginProperties(t *testing.T) { var plugin1Field *model.PropertyField for _, field := range fields { - if field.Name == "Plugin1 Field To Keep" { + if field.Name == "plugin1_field_to_keep" { plugin1Field = field break } @@ -842,8 +864,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin can update values for its own protected field", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID tearDown, pluginIDs, activationErrors := SetAppEnvironmentWithPlugins(t, []string{` package main @@ -861,9 +884,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { // Create a protected CPA field field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Protected Field With Values", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "protected_field_with_values", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: map[string]any{ "protected": true, }, @@ -921,8 +946,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin cannot update values for another plugin's protected field", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID testTargetID := model.NewId() @@ -944,9 +970,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Plugin1 Field With Protected Values", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "plugin1_field_with_protected_values", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: map[string]any{ "protected": true, }, @@ -1001,7 +1029,7 @@ func TestPluginProperties(t *testing.T) { var plugin1Field *model.PropertyField for _, field := range fields { - if field.Name == "Plugin1 Field With Protected Values" { + if field.Name == "plugin1_field_with_protected_values" { plugin1Field = field break } @@ -1043,8 +1071,9 @@ func TestPluginProperties(t *testing.T) { t.Run("test plugin can modify non-protected CPA fields from other plugins", func(t *testing.T) { cleanupCPAFields(t, th) - cpaID, err := th.App.CpaGroupID() - require.Nil(t, err) + cpaGroup, groupErr := th.App.GetPropertyGroup(request.TestContext(t), model.AccessControlPropertyGroupName) + require.Nil(t, groupErr) + cpaID := cpaGroup.ID // Both plugins in same environment tearDown, _, activationErrors := SetAppEnvironmentWithPlugins(t, []string{ @@ -1064,9 +1093,11 @@ func TestPluginProperties(t *testing.T) { func (p *MyPlugin) OnActivate() error { field := &model.PropertyField{ - GroupID: "` + cpaID + `", - Name: "Non-Protected Field", - Type: model.PropertyFieldTypeText, + GroupID: "` + cpaID + `", + Name: "non_protected_field", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), // Note: protected is not set } @@ -1105,7 +1136,7 @@ func TestPluginProperties(t *testing.T) { var plugin1Field *model.PropertyField for _, field := range fields { - if field.Name == "Non-Protected Field" { + if field.Name == "non_protected_field" { plugin1Field = field break } @@ -1116,7 +1147,7 @@ func TestPluginProperties(t *testing.T) { } // Update it (should succeed since it's not protected) - plugin1Field.Name = "Modified By Plugin2" + plugin1Field.Name = "modified_by_plugin2" _, err = p.API.UpdatePropertyField("` + cpaID + `", plugin1Field) if err != nil { return fmt.Errorf("failed to update non-protected field: %w", err) @@ -1136,12 +1167,15 @@ func TestPluginProperties(t *testing.T) { require.NoError(t, activationErrors[1]) // Verify the field was actually updated - rctx := th.emptyContextWithCallerID(anonymousCallerId) - updatedFields, appErr := th.App.ListCPAFields(rctx) + updatedFields, appErr := th.App.SearchPropertyFields(request.TestContext(t), cpaID, model.PropertyFieldSearchOpts{ + GroupID: cpaID, + ObjectType: model.PropertyFieldObjectTypeUser, + PerPage: model.AccessControlGroupFieldLimit + 5, + }) require.Nil(t, appErr) var fieldWasUpdated bool for _, field := range updatedFields { - if field.Name == "Modified By Plugin2" { + if field.Name == "modified_by_plugin2" { fieldWasUpdated = true break } diff --git a/server/channels/app/properties/access_control.go b/server/channels/app/properties/access_control.go index 63fd0f6b608..944644a0ceb 100644 --- a/server/channels/app/properties/access_control.go +++ b/server/channels/app/properties/access_control.go @@ -21,18 +21,25 @@ package properties import ( "bytes" "encoding/json" + "errors" "fmt" "maps" "net/http" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/store" ) +var ( + ErrAccessDenied = errors.New("access denied") + ErrSyncLocked = errors.New("field is managed by external sync") + ErrInvalidAccessMode = errors.New("invalid access_mode") + ErrFieldNotFound = errors.New("property field not found") +) + const ( - // propertyAccessPaginationPageSize is the default page size for pagination when fetching property values - propertyAccessPaginationPageSize = 100 - // propertyAccessMaxPaginationIterations is the maximum number of pagination iterations before returning an error + propertyAccessPaginationPageSize = 100 propertyAccessMaxPaginationIterations = 10 ) @@ -40,87 +47,93 @@ const ( // Returns true if the plugin exists and is installed, false otherwise. type PluginChecker func(pluginID string) bool -// PropertyAccessService is a layer around PropertyService that enforces access -// control based on caller identity. All property operations go through this -// service to ensure consistent access control enforcement. -type PropertyAccessService struct { +// AccessControlHook implements the PropertyHook interface to enforce access +// control based on caller identity. It checks protected fields, plugin +// ownership, and access modes (public, source-only, shared-only). +// +// The hook only applies to groups whose IDs are in managedGroupIDs. Operations +// on other groups pass through without access control checks. +type AccessControlHook struct { propertyService *PropertyService pluginChecker PluginChecker + managedGroupIDs map[string]struct{} } -// NewPropertyAccessService creates a new PropertyAccessService. -// It receives the PropertyService to call private methods for database operations. -// The pluginChecker function is used to verify plugin installation status when checking access -// to protected fields. Pass nil if plugin checking is not needed (e.g., in tests). -func NewPropertyAccessService(ps *PropertyService, pluginChecker PluginChecker) *PropertyAccessService { - return &PropertyAccessService{ +// Compile-time check that AccessControlHook implements PropertyHook. +var _ PropertyHook = (*AccessControlHook)(nil) + +// NewAccessControlHook creates a new AccessControlHook. +// It receives the PropertyService to call private methods for database lookups +// needed during access control checks. The pluginChecker function is used to +// verify plugin installation status when checking access to protected fields. +// Pass nil for pluginChecker if plugin checking is not needed (e.g., in tests). +// managedGroupIDs lists the property group IDs that this hook enforces access +// control for. Operations on groups not in this list are passed through. +func NewAccessControlHook(ps *PropertyService, pluginChecker PluginChecker, managedGroupIDs ...string) *AccessControlHook { + ids := make(map[string]struct{}, len(managedGroupIDs)) + for _, id := range managedGroupIDs { + ids[id] = struct{}{} + } + return &AccessControlHook{ propertyService: ps, pluginChecker: pluginChecker, + managedGroupIDs: ids, } } -func (pas *PropertyAccessService) setPluginCheckerForTests(pluginChecker PluginChecker) { - pas.pluginChecker = pluginChecker +// isGroupManaged checks whether the given group ID is managed by this hook. +func (h *AccessControlHook) isGroupManaged(groupID string) bool { + _, ok := h.managedGroupIDs[groupID] + return ok } -// Property Field Methods +// Field Pre-Hooks -// isCallerPlugin checks whether the callerID corresponds to an installed plugin. -func (pas *PropertyAccessService) isCallerPlugin(callerID string) bool { - return callerID != "" && pas.pluginChecker != nil && pas.pluginChecker(callerID) -} - -// CreatePropertyField creates a new property field with access control. -// When the caller is an installed plugin, source_plugin_id is automatically set -// to the callerID and the protected attribute is allowed. -// When the caller is not a plugin, source_plugin_id and protected are rejected -// to prevent unauthorized field ownership claims. +// PreCreatePropertyField enforces access control on field creation. +// When the caller is an installed plugin, source_plugin_id is automatically set. +// When the caller is not a plugin, source_plugin_id and protected are rejected. // When linking to a source template, security attributes are validated and // inherited from the source. -func (pas *PropertyAccessService) CreatePropertyField(callerID string, field *model.PropertyField) (*model.PropertyField, error) { - if pas.isCallerPlugin(callerID) { - // Caller is a plugin — auto-set source_plugin_id +func (h *AccessControlHook) PreCreatePropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + if !h.isGroupManaged(field.GroupID) { + return field, nil + } + + callerID := h.extractCallerID(rctx) + + if h.isCallerPlugin(callerID) { if field.Attrs == nil { field.Attrs = make(model.StringInterface) } field.Attrs[model.PropertyAttrsSourcePluginID] = callerID } else { - // Non-plugin caller — reject source_plugin_id and protected - if pas.getSourcePluginID(field) != "" { - return nil, fmt.Errorf("CreatePropertyField: source_plugin_id can only be set by a plugin") + if h.getSourcePluginID(field) != "" { + return nil, fmt.Errorf("source_plugin_id can only be set by a plugin: %w", ErrAccessDenied) } if model.IsPropertyFieldProtected(field) { - return nil, fmt.Errorf("CreatePropertyField: protected can only be set by a plugin") + return nil, fmt.Errorf("protected can only be set by a plugin: %w", ErrAccessDenied) } } - // If linking to a source, validate and inherit security attributes if field.LinkedFieldID != nil && *field.LinkedFieldID != "" { - if err := pas.validateAndInheritLinkedFieldSecurity(callerID, field); err != nil { - return nil, fmt.Errorf("CreatePropertyField: %w", err) + if err := h.validateAndInheritLinkedFieldSecurity(callerID, field); err != nil { + return nil, fmt.Errorf("PreCreatePropertyField: %w", err) } } - // Validate access mode (after inheritance so protected flag is correct) if err := model.ValidatePropertyFieldAccessMode(field); err != nil { - return nil, fmt.Errorf("CreatePropertyField: %w", err) + return nil, fmt.Errorf("%s: %w", err.Error(), ErrInvalidAccessMode) } - result, err := pas.propertyService.createPropertyField(field) - if err != nil { - return nil, fmt.Errorf("CreatePropertyField: %w", err) - } - return result, nil + return field, nil } -// validateAndInheritLinkedFieldSecurity enforces that linked fields inherit the -// source template's security posture. If the source is protected, only the -// source plugin may create linked fields. The linked field's access_mode must -// match the source's — divergence is rejected to avoid a false sense of -// security (callers can always inspect the template directly). -// Inherits: Attrs[protected], Attrs[source_plugin_id], Attrs[access_mode]. -func (pas *PropertyAccessService) validateAndInheritLinkedFieldSecurity(callerID string, field *model.PropertyField) error { - source, err := pas.propertyService.getPropertyFieldFromMaster("", *field.LinkedFieldID) +// validateAndInheritLinkedFieldSecurity enforces that linked fields inherit +// the source template's security posture. If the source is protected, only +// the source plugin may create linked fields. Security attrs (protected, +// source_plugin_id, access_mode) are copied from the source onto the field. +func (h *AccessControlHook) validateAndInheritLinkedFieldSecurity(callerID string, field *model.PropertyField) error { + source, err := h.propertyService.getPropertyFieldFromMaster("", *field.LinkedFieldID) if err != nil { if store.IsErrNotFound(err) { return model.NewAppError( @@ -138,7 +151,7 @@ func (pas *PropertyAccessService) validateAndInheritLinkedFieldSecurity(callerID return nil } - sourcePluginID := pas.getSourcePluginID(source) + sourcePluginID := h.getSourcePluginID(source) if sourcePluginID == "" || callerID != sourcePluginID { return model.NewAppError( "CreatePropertyField", @@ -162,428 +175,318 @@ func (pas *PropertyAccessService) validateAndInheritLinkedFieldSecurity(callerID return nil } -// GetPropertyField retrieves a property field by group and field ID. -// Field details are filtered based on the caller's access permissions. -func (pas *PropertyAccessService) GetPropertyField(callerID string, groupID, id string) (*model.PropertyField, error) { - field, err := pas.propertyService.getPropertyField(groupID, id) - if err != nil { - return nil, fmt.Errorf("GetPropertyField: %w", err) - } - - return pas.applyFieldReadAccessControl(field, callerID), nil -} - -// GetPropertyFields retrieves multiple property fields by their IDs. -// Field details are filtered based on the caller's access permissions. -func (pas *PropertyAccessService) GetPropertyFields(callerID string, groupID string, ids []string) ([]*model.PropertyField, error) { - fields, err := pas.propertyService.getPropertyFields(groupID, ids) - if err != nil { - return nil, fmt.Errorf("GetPropertyFields: %w", err) - } - - return pas.applyFieldReadAccessControlToList(fields, callerID), nil -} - -// GetPropertyFieldByName retrieves a property field by name. -// Field details are filtered based on the caller's access permissions. -func (pas *PropertyAccessService) GetPropertyFieldByName(callerID string, groupID, targetID, name string) (*model.PropertyField, error) { - field, err := pas.propertyService.getPropertyFieldByName(groupID, targetID, name) - if err != nil { - return nil, fmt.Errorf("GetPropertyFieldByName: %w", err) - } - - return pas.applyFieldReadAccessControl(field, callerID), nil -} - -// CountActivePropertyFieldsForGroup counts active property fields for a group. -func (pas *PropertyAccessService) CountActivePropertyFieldsForGroup(groupID string) (int64, error) { - return pas.propertyService.countActivePropertyFieldsForGroup(groupID) -} - -// CountAllPropertyFieldsForGroup counts all property fields (including deleted) for a group. -func (pas *PropertyAccessService) CountAllPropertyFieldsForGroup(groupID string) (int64, error) { - return pas.propertyService.countAllPropertyFieldsForGroup(groupID) -} - -// CountActivePropertyFieldsForTarget counts active property fields for a specific target. -func (pas *PropertyAccessService) CountActivePropertyFieldsForTarget(groupID, targetType, targetID string) (int64, error) { - return pas.propertyService.countActivePropertyFieldsForTarget(groupID, targetType, targetID) -} - -// CountAllPropertyFieldsForTarget counts all property fields (including deleted) for a specific target. -func (pas *PropertyAccessService) CountAllPropertyFieldsForTarget(groupID, targetType, targetID string) (int64, error) { - return pas.propertyService.countAllPropertyFieldsForTarget(groupID, targetType, targetID) -} - -// SearchPropertyFields searches for property fields based on the given options. -// Field details are filtered based on the caller's access permissions. -func (pas *PropertyAccessService) SearchPropertyFields(callerID string, groupID string, opts model.PropertyFieldSearchOpts) ([]*model.PropertyField, error) { - fields, err := pas.propertyService.searchPropertyFields(groupID, opts) - if err != nil { - return nil, fmt.Errorf("SearchPropertyFields: %w", err) - } - - return pas.applyFieldReadAccessControlToList(fields, callerID), nil -} - -// UpdatePropertyField updates a property field. +// PreUpdatePropertyField enforces access control on field updates. // Checks write access and ensures source_plugin_id is not changed. -func (pas *PropertyAccessService) UpdatePropertyField(callerID string, groupID string, field *model.PropertyField) (*model.PropertyField, error) { - // Get existing field to check access - existingField, existsErr := pas.propertyService.getPropertyField(groupID, field.ID) - if existsErr != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", existsErr) +func (h *AccessControlHook) PreUpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) { + if !h.isGroupManaged(groupID) { + return field, nil } - // Check write access - if err := pas.checkFieldWriteAccess(existingField, callerID); err != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", err) - } + callerID := h.extractCallerID(rctx) - // Ensure source_plugin_id hasn't changed - if err := pas.ensureSourcePluginIDUnchanged(existingField, field); err != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", err) - } - - // Validate protected field update - if err := pas.validateProtectedFieldUpdate(field, callerID); err != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", err) - } - - // Validate access mode - if err := model.ValidatePropertyFieldAccessMode(field); err != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", err) - } - - result, err := pas.propertyService.updatePropertyField(groupID, field) + existingField, err := h.propertyService.getPropertyField(groupID, field.ID) if err != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", err) + return nil, err } - return result, nil + + if err := h.checkFieldWriteAccess(existingField, callerID); err != nil { + return nil, err + } + + if err := h.ensureSourcePluginIDUnchanged(existingField, field); err != nil { + return nil, err + } + + if err := h.validateProtectedFieldUpdate(field, callerID); err != nil { + return nil, err + } + + if err := model.ValidatePropertyFieldAccessMode(field); err != nil { + return nil, fmt.Errorf("%s: %w", err.Error(), ErrInvalidAccessMode) + } + + return field, nil } -// UpdatePropertyFields updates multiple property fields. -// Checks write access for all fields atomically before updating any. -func (pas *PropertyAccessService) UpdatePropertyFields(callerID string, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, error) { - if len(fields) == 0 { - return fields, nil, nil +// PreUpdatePropertyFields enforces access control on batch field updates. +// Checks write access for all fields atomically before allowing any updates. +func (h *AccessControlHook) PreUpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + if len(fields) == 0 || !h.isGroupManaged(groupID) { + return fields, nil } + callerID := h.extractCallerID(rctx) + // Get field IDs fieldIDs := make([]string, len(fields)) for i, field := range fields { fieldIDs[i] = field.ID } - // Fetch existing fields - existingFields, existsErr := pas.propertyService.getPropertyFields(groupID, fieldIDs) - if existsErr != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: %w", existsErr) + existingFields, err := h.propertyService.getPropertyFields(groupID, fieldIDs) + if err != nil { + return nil, err } - // Build map for easy lookup existingFieldMap := make(map[string]*model.PropertyField, len(existingFields)) for _, field := range existingFields { existingFieldMap[field.ID] = field } - // Check write access for all fields before updating any for _, field := range fields { existingField, exists := existingFieldMap[field.ID] if !exists { - return nil, nil, fmt.Errorf("field %s not found", field.ID) + return nil, fmt.Errorf("field %s: %w", field.ID, ErrFieldNotFound) } - // Check write access - if err := pas.checkFieldWriteAccess(existingField, callerID); err != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: field %s: %w", field.ID, err) + if err := h.checkFieldWriteAccess(existingField, callerID); err != nil { + return nil, fmt.Errorf("field %s: %w", field.ID, err) } - // Ensure source_plugin_id hasn't changed - if err := pas.ensureSourcePluginIDUnchanged(existingField, field); err != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: field %s: %w", field.ID, err) + if err := h.ensureSourcePluginIDUnchanged(existingField, field); err != nil { + return nil, fmt.Errorf("field %s: %w", field.ID, err) } - // Validate protected field update - if err := pas.validateProtectedFieldUpdate(field, callerID); err != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: field %s: %w", field.ID, err) + if err := h.validateProtectedFieldUpdate(field, callerID); err != nil { + return nil, fmt.Errorf("field %s: %w", field.ID, err) } - // Validate access mode if err := model.ValidatePropertyFieldAccessMode(field); err != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: field %s: %w", field.ID, err) + return nil, fmt.Errorf("field %s: %s: %w", field.ID, err.Error(), ErrInvalidAccessMode) } } - // All checks passed - proceed with update - requested, propagated, err := pas.propertyService.updatePropertyFields(groupID, fields) - if err != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: %w", err) - } - return requested, propagated, nil + return fields, nil } -// DeletePropertyField deletes a property field and all its values. -// Checks delete access before allowing deletion. -func (pas *PropertyAccessService) DeletePropertyField(callerID string, groupID, id string) error { - // Get existing field to check access - existingField, err := pas.propertyService.getPropertyField(groupID, id) - if err != nil { - return fmt.Errorf("DeletePropertyField: %w", err) - } - - // Check delete access - if err := pas.checkFieldDeleteAccess(existingField, callerID); err != nil { - return fmt.Errorf("DeletePropertyField: %w", err) - } - - if err := pas.propertyService.deletePropertyField(groupID, id); err != nil { - return fmt.Errorf("DeletePropertyField: %w", err) - } +// PreCountPropertyFields is a no-op — counts don't expose per-row metadata, +// so access control doesn't apply. License gating happens in LicenseCheckHook. +func (h *AccessControlHook) PreCountPropertyFields(_ request.CTX, _ string) error { return nil } -// Property Value Methods - -// CreatePropertyValue creates a new property value. -// Checks write access before allowing the creation. -func (pas *PropertyAccessService) CreatePropertyValue(callerID string, value *model.PropertyValue) (*model.PropertyValue, error) { - // Get the associated field to check access - field, err := pas.propertyService.getPropertyField(value.GroupID, value.FieldID) - if err != nil { - return nil, fmt.Errorf("CreatePropertyValue: %w", err) - } - - // Check write access - if err = pas.checkFieldWriteAccess(field, callerID); err != nil { - return nil, fmt.Errorf("CreatePropertyValue: %w", err) - } - - result, err := pas.propertyService.createPropertyValue(value) - if err != nil { - return nil, fmt.Errorf("CreatePropertyValue: %w", err) - } - return result, nil -} - -// CreatePropertyValues creates multiple property values. -// Checks write access for all fields atomically before creating any values. -func (pas *PropertyAccessService) CreatePropertyValues(callerID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { - fieldMap, err := pas.getFieldsForValues(values) - if err != nil { - return nil, fmt.Errorf("CreatePropertyValues: %w", err) - } - - // Check write access for all fields before creating any values - for _, value := range values { - field, exists := fieldMap[value.FieldID] - if !exists { - return nil, fmt.Errorf("CreatePropertyValues: field %s not found", value.FieldID) - } - - if err = pas.checkFieldWriteAccess(field, callerID); err != nil { - return nil, fmt.Errorf("CreatePropertyValues: field %s: %w", value.FieldID, err) - } - } - - // All checks passed - proceed with creation - result, err := pas.propertyService.createPropertyValues(values) - if err != nil { - return nil, fmt.Errorf("CreatePropertyValues: %w", err) - } - return result, nil -} - -// GetPropertyValue retrieves a property value by ID. -// Returns (nil, nil) if the value exists but the caller doesn't have access. -func (pas *PropertyAccessService) GetPropertyValue(callerID string, groupID, id string) (*model.PropertyValue, error) { - value, err := pas.propertyService.getPropertyValue(groupID, id) - if err != nil { - return nil, fmt.Errorf("GetPropertyValue: %w", err) - } - - // Apply access control filtering - filtered, err := pas.applyValueReadAccessControl([]*model.PropertyValue{value}, callerID) - if err != nil { - return nil, fmt.Errorf("GetPropertyValue: %w", err) - } - - // If the value was filtered out, return nil - if len(filtered) == 0 { - return nil, nil - } - - return filtered[0], nil -} - -// GetPropertyValues retrieves multiple property values by their IDs. -// Values the caller doesn't have access to are silently filtered out. -func (pas *PropertyAccessService) GetPropertyValues(callerID string, groupID string, ids []string) ([]*model.PropertyValue, error) { - values, err := pas.propertyService.getPropertyValues(groupID, ids) - if err != nil { - return nil, fmt.Errorf("GetPropertyValues: %w", err) - } - - // Apply access control filtering - filtered, err := pas.applyValueReadAccessControl(values, callerID) - if err != nil { - return nil, fmt.Errorf("GetPropertyValues: %w", err) - } - return filtered, nil -} - -// SearchPropertyValues searches for property values based on the given options. -// Values the caller doesn't have access to are silently filtered out. -func (pas *PropertyAccessService) SearchPropertyValues(callerID string, groupID string, opts model.PropertyValueSearchOpts) ([]*model.PropertyValue, error) { - values, err := pas.propertyService.searchPropertyValues(groupID, opts) - if err != nil { - return nil, fmt.Errorf("SearchPropertyValues: %w", err) - } - - // Apply access control filtering - filtered, err := pas.applyValueReadAccessControl(values, callerID) - if err != nil { - return nil, fmt.Errorf("SearchPropertyValues: %w", err) - } - return filtered, nil -} - -// UpdatePropertyValue updates a property value. -// Checks write access before allowing the update. -func (pas *PropertyAccessService) UpdatePropertyValue(callerID string, groupID string, value *model.PropertyValue) (*model.PropertyValue, error) { - // Get the associated field to check access - field, err := pas.propertyService.getPropertyField(groupID, value.FieldID) - if err != nil { - return nil, fmt.Errorf("UpdatePropertyValue: %w", err) - } - - // Check write access - if err = pas.checkFieldWriteAccess(field, callerID); err != nil { - return nil, fmt.Errorf("UpdatePropertyValue: %w", err) - } - - result, err := pas.propertyService.updatePropertyValue(groupID, value) - if err != nil { - return nil, fmt.Errorf("UpdatePropertyValue: %w", err) - } - return result, nil -} - -// UpdatePropertyValues updates multiple property values. -// Checks write access for all fields atomically before updating any values. -func (pas *PropertyAccessService) UpdatePropertyValues(callerID string, groupID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { - if len(values) == 0 { - return values, nil - } - - fieldMap, err := pas.getFieldsForValues(values) - if err != nil { - return nil, fmt.Errorf("UpdatePropertyValues: %w", err) - } - - // Check write access for all fields before updating any values - for _, value := range values { - field, exists := fieldMap[value.FieldID] - if !exists { - return nil, fmt.Errorf("UpdatePropertyValues: field %s not found", value.FieldID) - } - - if err = pas.checkFieldWriteAccess(field, callerID); err != nil { - return nil, fmt.Errorf("UpdatePropertyValues: field %s: %w", value.FieldID, err) - } - } - - // All checks passed - proceed with update - result, err := pas.propertyService.updatePropertyValues(groupID, values) - if err != nil { - return nil, fmt.Errorf("UpdatePropertyValues: %w", err) - } - return result, nil -} - -// UpsertPropertyValue creates or updates a property value. -// Checks write access before allowing the upsert. -func (pas *PropertyAccessService) UpsertPropertyValue(callerID string, value *model.PropertyValue) (*model.PropertyValue, error) { - // Get the associated field to check access - field, err := pas.propertyService.getPropertyField(value.GroupID, value.FieldID) - if err != nil { - return nil, fmt.Errorf("UpsertPropertyValue: %w", err) - } - - // Check write access (works for both create and update) - if err = pas.checkFieldWriteAccess(field, callerID); err != nil { - return nil, fmt.Errorf("UpsertPropertyValue: %w", err) - } - - result, err := pas.propertyService.upsertPropertyValue(value) - if err != nil { - return nil, fmt.Errorf("UpsertPropertyValue: %w", err) - } - return result, nil -} - -// UpsertPropertyValues creates or updates multiple property values. -// Checks write access for all fields atomically before upserting any values. -func (pas *PropertyAccessService) UpsertPropertyValues(callerID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { - if len(values) == 0 { - return values, nil - } - - fieldMap, err := pas.getFieldsForValues(values) - if err != nil { - return nil, fmt.Errorf("UpsertPropertyValues: %w", err) - } - - // Check write access for all fields before upserting any values - for _, value := range values { - field, exists := fieldMap[value.FieldID] - if !exists { - return nil, fmt.Errorf("UpsertPropertyValues: field %s not found", value.FieldID) - } - - if err = pas.checkFieldWriteAccess(field, callerID); err != nil { - return nil, fmt.Errorf("UpsertPropertyValues: field %s: %w", value.FieldID, err) - } - } - - // All checks passed - proceed with upsert - result, err := pas.propertyService.upsertPropertyValues(values) - if err != nil { - return nil, fmt.Errorf("UpsertPropertyValues: %w", err) - } - return result, nil -} - -// DeletePropertyValue deletes a property value. -// Checks write access before allowing deletion. -func (pas *PropertyAccessService) DeletePropertyValue(callerID string, groupID, id string) error { - // Get the value to find its field ID - value, err := pas.propertyService.getPropertyValue(groupID, id) - if err != nil { - // Value doesn't exist - return nil to match original behavior +// PreDeletePropertyField enforces access control on field deletion. +func (h *AccessControlHook) PreDeletePropertyField(rctx request.CTX, groupID string, id string) error { + if !h.isGroupManaged(groupID) { return nil } - // Get the associated field to check access - field, err := pas.propertyService.getPropertyField(groupID, value.FieldID) + callerID := h.extractCallerID(rctx) + + existingField, err := h.propertyService.getPropertyField(groupID, id) if err != nil { - return fmt.Errorf("DeletePropertyValue: %w", err) + return err } - // Check write access - if err := pas.checkFieldWriteAccess(field, callerID); err != nil { - return fmt.Errorf("DeletePropertyValue: %w", err) - } - - if err := pas.propertyService.deletePropertyValue(groupID, id); err != nil { - return fmt.Errorf("DeletePropertyValue: %w", err) - } - return nil + return h.checkFieldDeleteAccess(existingField, callerID) } -// DeletePropertyValuesForTarget deletes all property values for a specific target. -// Checks write access for all affected fields atomically before deleting. -func (pas *PropertyAccessService) DeletePropertyValuesForTarget(callerID string, groupID string, targetType string, targetID string) error { +// PostUpdatePropertyFields is a no-op for access control; cleanup of dependent +// values is handled by TypeChangeValueCleanupHook. +func (h *AccessControlHook) PostUpdatePropertyFields(_ request.CTX, _ string, _, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) { + return requested, propagated, nil, nil +} + +// Field Post-Hooks + +// PostGetPropertyField applies read access control to a single field. +func (h *AccessControlHook) PostGetPropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + if !h.isGroupManaged(field.GroupID) { + return field, nil + } + + callerID := h.extractCallerID(rctx) + return h.applyFieldReadAccessControl(field, callerID), nil +} + +// PostGetPropertyFields applies read access control to a list of fields. +// All fields in a batch share the same GroupID (enforced by the public API). +func (h *AccessControlHook) PostGetPropertyFields(rctx request.CTX, fields []*model.PropertyField) ([]*model.PropertyField, error) { + if len(fields) == 0 { + return fields, nil + } + + if !h.isGroupManaged(fields[0].GroupID) { + return fields, nil + } + + callerID := h.extractCallerID(rctx) + return h.applyFieldReadAccessControlToList(fields, callerID), nil +} + +// Value Pre-Hooks + +// PreCreatePropertyValue enforces write access and sync locking on the value's field before creation. +func (h *AccessControlHook) PreCreatePropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if !h.isGroupManaged(value.GroupID) { + return value, nil + } + + callerID := h.extractCallerID(rctx) + + field, err := h.propertyService.getPropertyField(value.GroupID, value.FieldID) + if err != nil { + return nil, err + } + + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return nil, err + } + + return value, nil +} + +// PreCreatePropertyValues enforces write access and sync locking for all fields atomically before creation. +// All values in a batch share the same GroupID (enforced by the public API). +func (h *AccessControlHook) PreCreatePropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 || !h.isGroupManaged(values[0].GroupID) { + return values, nil + } + + callerID := h.extractCallerID(rctx) + + fieldMap, err := h.getFieldsForValues(values) + if err != nil { + return nil, err + } + + for _, value := range values { + field, exists := fieldMap[value.FieldID] + if !exists { + return nil, fmt.Errorf("field %s: %w", value.FieldID, ErrFieldNotFound) + } + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return nil, fmt.Errorf("field %s: %w", value.FieldID, err) + } + } + + return values, nil +} + +// PreUpdatePropertyValue enforces write access and sync locking on the value's field before update. +func (h *AccessControlHook) PreUpdatePropertyValue(rctx request.CTX, groupID string, value *model.PropertyValue) (*model.PropertyValue, error) { + if !h.isGroupManaged(groupID) { + return value, nil + } + + callerID := h.extractCallerID(rctx) + + field, err := h.propertyService.getPropertyField(groupID, value.FieldID) + if err != nil { + return nil, err + } + + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return nil, err + } + + return value, nil +} + +// PreUpdatePropertyValues enforces write access and sync locking for all fields atomically before update. +// All values in a batch share the same GroupID (enforced by the public API). +func (h *AccessControlHook) PreUpdatePropertyValues(rctx request.CTX, groupID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 || !h.isGroupManaged(groupID) { + return values, nil + } + + callerID := h.extractCallerID(rctx) + + fieldMap, err := h.getFieldsForValues(values) + if err != nil { + return nil, err + } + + for _, value := range values { + field, exists := fieldMap[value.FieldID] + if !exists { + return nil, fmt.Errorf("field %s: %w", value.FieldID, ErrFieldNotFound) + } + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return nil, fmt.Errorf("field %s: %w", value.FieldID, err) + } + } + + return values, nil +} + +// PreUpsertPropertyValue enforces write access and sync locking on the value's field before upsert. +func (h *AccessControlHook) PreUpsertPropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if !h.isGroupManaged(value.GroupID) { + return value, nil + } + + callerID := h.extractCallerID(rctx) + + field, err := h.propertyService.getPropertyField(value.GroupID, value.FieldID) + if err != nil { + return nil, err + } + + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return nil, err + } + + return value, nil +} + +// PreUpsertPropertyValues enforces write access and sync locking for all fields atomically before upsert. +// All values in a batch share the same GroupID (enforced by the public API). +func (h *AccessControlHook) PreUpsertPropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 || !h.isGroupManaged(values[0].GroupID) { + return values, nil + } + + callerID := h.extractCallerID(rctx) + + fieldMap, err := h.getFieldsForValues(values) + if err != nil { + return nil, err + } + + for _, value := range values { + field, exists := fieldMap[value.FieldID] + if !exists { + return nil, fmt.Errorf("field %s: %w", value.FieldID, ErrFieldNotFound) + } + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return nil, fmt.Errorf("field %s: %w", value.FieldID, err) + } + } + + return values, nil +} + +// PreDeletePropertyValue enforces write access before deleting a value. +func (h *AccessControlHook) PreDeletePropertyValue(rctx request.CTX, groupID string, id string) error { + if !h.isGroupManaged(groupID) { + return nil + } + + callerID := h.extractCallerID(rctx) + + value, err := h.propertyService.getPropertyValue(groupID, id) + if err != nil { + return err + } + + field, err := h.propertyService.getPropertyField(groupID, value.FieldID) + if err != nil { + return err + } + + return h.checkValueWriteAccess(field, callerID) +} + +// PreDeletePropertyValuesForTarget enforces write access for all affected fields +// before deleting all values for a target. +func (h *AccessControlHook) PreDeletePropertyValuesForTarget(rctx request.CTX, groupID string, targetType string, targetID string) error { + if !h.isGroupManaged(groupID) { + return nil + } + + callerID := h.extractCallerID(rctx) + // Collect unique field IDs across all values without loading all values into memory fieldIDs := make(map[string]struct{}) var cursor model.PropertyValueSearchCursor @@ -592,7 +495,7 @@ func (pas *PropertyAccessService) DeletePropertyValuesForTarget(callerID string, for { iterations++ if iterations > propertyAccessMaxPaginationIterations { - return fmt.Errorf("DeletePropertyValuesForTarget: exceeded maximum pagination iterations (%d)", propertyAccessMaxPaginationIterations) + return fmt.Errorf("exceeded maximum pagination iterations (%d)", propertyAccessMaxPaginationIterations) } opts := model.PropertyValueSearchOpts{ @@ -605,22 +508,19 @@ func (pas *PropertyAccessService) DeletePropertyValuesForTarget(callerID string, opts.Cursor = cursor } - values, err := pas.propertyService.searchPropertyValues(groupID, opts) + values, err := h.propertyService.searchPropertyValues(groupID, opts) if err != nil { - return fmt.Errorf("DeletePropertyValuesForTarget: %w", err) + return err } - // Extract field IDs from this batch for _, value := range values { fieldIDs[value.FieldID] = struct{}{} } - // If we got fewer results than the page size, we're done if len(values) < propertyAccessPaginationPageSize { break } - // Update cursor for next page lastValue := values[len(values)-1] cursor = model.PropertyValueSearchCursor{ PropertyValueID: lastValue.ID, @@ -629,62 +529,97 @@ func (pas *PropertyAccessService) DeletePropertyValuesForTarget(callerID string, } if len(fieldIDs) == 0 { - // No values to delete - return nil to match original behavior return nil } - // Convert map to slice fieldIDSlice := make([]string, 0, len(fieldIDs)) for fieldID := range fieldIDs { fieldIDSlice = append(fieldIDSlice, fieldID) } - // Fetch all fields - fields, err := pas.propertyService.getPropertyFields(groupID, fieldIDSlice) + fields, err := h.propertyService.getPropertyFields(groupID, fieldIDSlice) if err != nil { - return fmt.Errorf("DeletePropertyValuesForTarget: %w", err) + return err } - // Check write access for all fields before deleting any values for _, field := range fields { - if err := pas.checkFieldWriteAccess(field, callerID); err != nil { - return fmt.Errorf("DeletePropertyValuesForTarget: field %s: %w", field.ID, err) + if err := h.checkValueWriteAccess(field, callerID); err != nil { + return fmt.Errorf("field %s: %w", field.ID, err) } } - // All checks passed - proceed with deletion - if err := pas.propertyService.deletePropertyValuesForTarget(groupID, targetType, targetID); err != nil { - return fmt.Errorf("DeletePropertyValuesForTarget: %w", err) - } return nil } -// DeletePropertyValuesForField deletes all property values for a specific field. -// Checks write access before allowing deletion. -func (pas *PropertyAccessService) DeletePropertyValuesForField(callerID string, groupID, fieldID string) error { - // Get the field to check access - field, err := pas.propertyService.getPropertyField(groupID, fieldID) - if err != nil { - // Field doesn't exist - return nil to match original behavior +// PreDeletePropertyValuesForField enforces write access before deleting all values for a field. +func (h *AccessControlHook) PreDeletePropertyValuesForField(rctx request.CTX, groupID string, fieldID string) error { + if !h.isGroupManaged(groupID) { return nil } - // Check write access - if err := pas.checkFieldWriteAccess(field, callerID); err != nil { - return fmt.Errorf("DeletePropertyValuesForField: %w", err) + callerID := h.extractCallerID(rctx) + + field, err := h.propertyService.getPropertyField(groupID, fieldID) + if err != nil { + return err } - if err := pas.propertyService.deletePropertyValuesForField(groupID, fieldID); err != nil { - return fmt.Errorf("DeletePropertyValuesForField: %w", err) + return h.checkValueWriteAccess(field, callerID) +} + +// Value Post-Hooks + +// PostGetPropertyValue applies read access control to a single value. +// Returns nil if the caller doesn't have access. +func (h *AccessControlHook) PostGetPropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if value == nil { + return nil, nil } - return nil + if !h.isGroupManaged(value.GroupID) { + return value, nil + } + + callerID := h.extractCallerID(rctx) + + filtered, err := h.applyValueReadAccessControl([]*model.PropertyValue{value}, callerID) + if err != nil { + return nil, err + } + + if len(filtered) == 0 { + return nil, nil + } + + return filtered[0], nil +} + +// PostGetPropertyValues applies read access control to a list of values. +// Values the caller doesn't have access to are silently filtered out. +// All values in a batch share the same GroupID (enforced by the public API). +func (h *AccessControlHook) PostGetPropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 || !h.isGroupManaged(values[0].GroupID) { + return values, nil + } + + callerID := h.extractCallerID(rctx) + + return h.applyValueReadAccessControl(values, callerID) } // Access Control Helper Methods +// extractCallerID gets the caller ID from a request context using the property service's extractor. +func (h *AccessControlHook) extractCallerID(rctx request.CTX) string { + return h.propertyService.extractCallerID(rctx) +} + +// isCallerPlugin checks whether the callerID corresponds to an installed plugin. +func (h *AccessControlHook) isCallerPlugin(callerID string) bool { + return callerID != "" && h.pluginChecker != nil && h.pluginChecker(callerID) +} + // getSourcePluginID extracts the source_plugin_id from a PropertyField's attrs. -// Returns empty string if not set. -func (pas *PropertyAccessService) getSourcePluginID(field *model.PropertyField) string { +func (h *AccessControlHook) getSourcePluginID(field *model.PropertyField) string { if field.Attrs == nil { return "" } @@ -692,57 +627,60 @@ func (pas *PropertyAccessService) getSourcePluginID(field *model.PropertyField) return sourcePluginID } -// checkUnrestrictedFieldReadAccess checks if the given caller can read a PropertyField without restrictions. -// Returns true if the caller has unrestricted read access (public field or source plugin). -// Returns an error if access requires filtering or should be denied entirely. -func (pas *PropertyAccessService) hasUnrestrictedFieldReadAccess(field *model.PropertyField, callerID string) bool { - accessMode := field.GetAccessMode() +// getAccessMode extracts the access_mode from a PropertyField's attrs. +func (h *AccessControlHook) getAccessMode(field *model.PropertyField) string { + if field.Attrs == nil { + return model.PropertyAccessModePublic + } + accessMode, ok := field.Attrs[model.PropertyAttrsAccessMode].(string) + if !ok { + return model.PropertyAccessModePublic + } + return accessMode +} + +// hasUnrestrictedFieldReadAccess checks if the given caller can read a PropertyField without restrictions. +// Returns true if the caller has unrestricted read access (public field or source plugin). +func (h *AccessControlHook) hasUnrestrictedFieldReadAccess(field *model.PropertyField, callerID string) bool { + accessMode := h.getAccessMode(field) - // Public fields are readable by everyone without restrictions if accessMode == model.PropertyAccessModePublic { return true } - // Source plugin always has unrestricted access to fields they created - sourcePluginID := pas.getSourcePluginID(field) + sourcePluginID := h.getSourcePluginID(field) if sourcePluginID != "" && sourcePluginID == callerID { return true } - // All other cases require filtering or access denial return false } // ensureSourcePluginIDUnchanged checks that the source_plugin_id attribute hasn't changed between fields. -// Used during field updates to ensure source_plugin_id is immutable. -// Returns nil if unchanged, or an error if source_plugin_id was modified. -func (pas *PropertyAccessService) ensureSourcePluginIDUnchanged(existingField, updatedField *model.PropertyField) error { - existingSourcePluginID := pas.getSourcePluginID(existingField) - updatedSourcePluginID := pas.getSourcePluginID(updatedField) +func (h *AccessControlHook) ensureSourcePluginIDUnchanged(existingField, updatedField *model.PropertyField) error { + existingSourcePluginID := h.getSourcePluginID(existingField) + updatedSourcePluginID := h.getSourcePluginID(updatedField) if existingSourcePluginID != updatedSourcePluginID { - return fmt.Errorf("source_plugin_id is immutable and cannot be changed from '%s' to '%s'", existingSourcePluginID, updatedSourcePluginID) + return fmt.Errorf("source_plugin_id is immutable and cannot be changed from '%s' to '%s': %w", existingSourcePluginID, updatedSourcePluginID, ErrAccessDenied) } return nil } // validateProtectedFieldUpdate validates that a field can be updated to protected=true. -// Prevents creating orphaned protected fields (protected=true but no source_plugin_id). -// Also ensures only the source plugin can set protected=true on fields with a source_plugin_id. -// Returns nil if the update is valid, or an error if it should be rejected. -func (pas *PropertyAccessService) validateProtectedFieldUpdate(updatedField *model.PropertyField, callerID string) error { +func (h *AccessControlHook) validateProtectedFieldUpdate(updatedField *model.PropertyField, callerID string) error { if !model.IsPropertyFieldProtected(updatedField) { return nil } - sourcePluginID := pas.getSourcePluginID(updatedField) + sourcePluginID := h.getSourcePluginID(updatedField) if sourcePluginID == "" { - return fmt.Errorf("cannot set protected=true on a field without a source_plugin_id") + return fmt.Errorf("cannot set protected=true on a field without a source_plugin_id: %w", ErrAccessDenied) } if sourcePluginID != callerID { - return fmt.Errorf("cannot set protected=true: only source plugin '%s' can modify this field", sourcePluginID) + return fmt.Errorf("cannot set protected=true: only source plugin '%s' can modify this field: %w", sourcePluginID, ErrAccessDenied) } return nil @@ -750,21 +688,18 @@ func (pas *PropertyAccessService) validateProtectedFieldUpdate(updatedField *mod // checkFieldWriteAccess checks if the given caller can modify a PropertyField. // IMPORTANT: Always pass the existing field fetched from the database, not a field provided by the caller. -// Returns nil if modification is allowed, or an error if denied. -func (pas *PropertyAccessService) checkFieldWriteAccess(field *model.PropertyField, callerID string) error { - // Check if field is protected +func (h *AccessControlHook) checkFieldWriteAccess(field *model.PropertyField, callerID string) error { if !model.IsPropertyFieldProtected(field) { return nil } - // Protected fields can only be modified by the source plugin - sourcePluginID := pas.getSourcePluginID(field) + sourcePluginID := h.getSourcePluginID(field) if sourcePluginID == "" { - return fmt.Errorf("field %s is protected, but has no associated source plugin", field.ID) + return fmt.Errorf("field %s is protected, but has no associated source plugin: %w", field.ID, ErrAccessDenied) } if sourcePluginID != callerID { - return fmt.Errorf("field %s is protected and can only be modified by source plugin '%s'", field.ID, sourcePluginID) + return fmt.Errorf("field %s is protected and can only be modified by source plugin '%s': %w", field.ID, sourcePluginID, ErrAccessDenied) } return nil @@ -772,37 +707,66 @@ func (pas *PropertyAccessService) checkFieldWriteAccess(field *model.PropertyFie // checkFieldDeleteAccess checks if the given caller can delete a PropertyField. // IMPORTANT: Always pass the existing field fetched from the database, not a field provided by the caller. -// Returns nil if deletion is allowed, or an error if denied. -func (pas *PropertyAccessService) checkFieldDeleteAccess(field *model.PropertyField, callerID string) error { - // Check if field is protected +func (h *AccessControlHook) checkFieldDeleteAccess(field *model.PropertyField, callerID string) error { if !model.IsPropertyFieldProtected(field) { return nil } - // Protected fields can only be deleted by the source plugin - sourcePluginID := pas.getSourcePluginID(field) + sourcePluginID := h.getSourcePluginID(field) if sourcePluginID == "" { - // Protected field with no source plugin - allow deletion return nil } - // Check if the source plugin is still installed - if pas.pluginChecker != nil && !pas.pluginChecker(sourcePluginID) { - // Plugin has been uninstalled - allow deletion of orphaned field + if h.pluginChecker != nil && !h.pluginChecker(sourcePluginID) { return nil } if sourcePluginID != callerID { - return fmt.Errorf("field %s is protected and can only be modified by source plugin '%s'", field.ID, sourcePluginID) + return fmt.Errorf("field %s is protected and can only be modified by source plugin '%s': %w", field.ID, sourcePluginID, ErrAccessDenied) } return nil } +// checkSyncLock checks whether the caller is allowed to write values for a +// synced field. Synced fields have an ldap or saml attr set, and only the +// corresponding sync service (identified by well-known caller IDs) may write +// their values. +func (h *AccessControlHook) checkSyncLock(field *model.PropertyField, callerID string) error { + syncSource := model.GetPropertyFieldSyncSource(field) + if syncSource == "" { + return nil + } + + // Map sync source to the expected caller ID + var expectedCallerID string + switch syncSource { + case "ldap": + expectedCallerID = model.CallerIDLDAPSync + case "saml": + expectedCallerID = model.CallerIDSAMLSync + default: + return fmt.Errorf("field %s has unknown sync source %q: %w", field.ID, syncSource, ErrInvalidFieldAttrs) + } + + if callerID != expectedCallerID { + return fmt.Errorf("field %s is managed by %s sync and cannot be modified by caller %q: %w", field.ID, syncSource, callerID, ErrSyncLocked) + } + + return nil +} + +// checkValueWriteAccess combines the protected-field write access check and +// the sync lock check for value write operations. +func (h *AccessControlHook) checkValueWriteAccess(field *model.PropertyField, callerID string) error { + if err := h.checkFieldWriteAccess(field, callerID); err != nil { + return err + } + return h.checkSyncLock(field, callerID) +} + // getCallerValuesForField retrieves all property values for the caller on a specific field. -// This is used internally for shared_only filtering. -// Returns an empty slice if callerID is empty or if there are no values. -func (pas *PropertyAccessService) getCallerValuesForField(groupID, fieldID, callerID string) ([]*model.PropertyValue, error) { +func (h *AccessControlHook) getCallerValuesForField(groupID, fieldID, callerID string) ([]*model.PropertyValue, error) { if callerID == "" { return []*model.PropertyValue{}, nil } @@ -814,7 +778,7 @@ func (pas *PropertyAccessService) getCallerValuesForField(groupID, fieldID, call for { iterations++ if iterations > propertyAccessMaxPaginationIterations { - return nil, fmt.Errorf("getCallerValuesForField: exceeded maximum pagination iterations (%d)", propertyAccessMaxPaginationIterations) + return nil, fmt.Errorf("exceeded maximum pagination iterations (%d)", propertyAccessMaxPaginationIterations) } opts := model.PropertyValueSearchOpts{ @@ -827,19 +791,17 @@ func (pas *PropertyAccessService) getCallerValuesForField(groupID, fieldID, call opts.Cursor = cursor } - values, err := pas.propertyService.searchPropertyValues(groupID, opts) + values, err := h.propertyService.searchPropertyValues(groupID, opts) if err != nil { return nil, fmt.Errorf("failed to get caller values for field: %w", err) } allValues = append(allValues, values...) - // If we got fewer results than the page size, we're done if len(values) < propertyAccessPaginationPageSize { break } - // Update cursor for next page lastValue := values[len(values)-1] cursor = model.PropertyValueSearchCursor{ PropertyValueID: lastValue.ID, @@ -851,10 +813,7 @@ func (pas *PropertyAccessService) getCallerValuesForField(groupID, fieldID, call } // extractOptionIDsFromValue parses a JSON value and extracts option IDs into a set. -// For select fields: returns a set with one option ID -// For multiselect fields: returns a set with multiple option IDs -// Returns nil if value is empty, or an error if field type is not select/multiselect. -func (pas *PropertyAccessService) extractOptionIDsFromValue(fieldType model.PropertyFieldType, value []byte) (map[string]struct{}, error) { +func (h *AccessControlHook) extractOptionIDsFromValue(fieldType model.PropertyFieldType, value []byte) (map[string]struct{}, error) { if len(value) == 0 { return nil, nil } @@ -889,8 +848,14 @@ func (pas *PropertyAccessService) extractOptionIDsFromValue(fieldType model.Prop return optionIDs, nil } -// copyPropertyField creates a deep copy of a PropertyField, including its Attrs map. -func (pas *PropertyAccessService) copyPropertyField(field *model.PropertyField) *model.PropertyField { +// copyPropertyField returns a copy of a PropertyField with a fresh Attrs map. +// The Attrs copy is shallow: nested slices/maps (notably Attrs["options"]) +// share backing storage with the original. That is safe today because +// filterSharedOnlyFieldOptions replaces Attrs["options"] wholesale rather +// than mutating in place. A future hook that mutates a nested value in the +// returned copy would also mutate the caller's original — deep-copy those +// entries if that changes. +func (h *AccessControlHook) copyPropertyField(field *model.PropertyField) *model.PropertyField { copied := *field copied.Attrs = make(model.StringInterface) if field.Attrs != nil { @@ -900,10 +865,8 @@ func (pas *PropertyAccessService) copyPropertyField(field *model.PropertyField) } // getCallerOptionIDsForField retrieves the caller's values for a field and extracts all option IDs. -// This is used for shared_only filtering to determine which options the caller has. -// Returns an empty set if callerID is empty, if there are no values, or on error. -func (pas *PropertyAccessService) getCallerOptionIDsForField(groupID, fieldID, callerID string, fieldType model.PropertyFieldType) (map[string]struct{}, error) { - callerValues, err := pas.getCallerValuesForField(groupID, fieldID, callerID) +func (h *AccessControlHook) getCallerOptionIDsForField(groupID, fieldID, callerID string, fieldType model.PropertyFieldType) (map[string]struct{}, error) { + callerValues, err := h.getCallerValuesForField(groupID, fieldID, callerID) if err != nil { return make(map[string]struct{}), err } @@ -912,10 +875,9 @@ func (pas *PropertyAccessService) getCallerOptionIDsForField(groupID, fieldID, c return make(map[string]struct{}), nil } - // Extract option IDs from caller's values callerOptionIDs := make(map[string]struct{}) for _, val := range callerValues { - optionIDs, err := pas.extractOptionIDsFromValue(fieldType, val.Value) + optionIDs, err := h.extractOptionIDsFromValue(fieldType, val.Value) if err == nil && optionIDs != nil { for optionID := range optionIDs { callerOptionIDs[optionID] = struct{}{} @@ -927,24 +889,18 @@ func (pas *PropertyAccessService) getCallerOptionIDsForField(groupID, fieldID, c } // filterSharedOnlyFieldOptions filters a field's options to only include those the caller has values for. -// Returns a new PropertyField with filtered options in the attrs. -// If the caller has no values, returns a field with empty options. -func (pas *PropertyAccessService) filterSharedOnlyFieldOptions(field *model.PropertyField, callerID string) *model.PropertyField { - // Only applies to select and multiselect fields +func (h *AccessControlHook) filterSharedOnlyFieldOptions(field *model.PropertyField, callerID string) *model.PropertyField { if field.Type != model.PropertyFieldTypeSelect && field.Type != model.PropertyFieldTypeMultiselect { return field } - // Get caller's option IDs for this field - callerOptionIDs, err := pas.getCallerOptionIDsForField(field.GroupID, field.ID, callerID, field.Type) + callerOptionIDs, err := h.getCallerOptionIDsForField(field.GroupID, field.ID, callerID, field.Type) if err != nil || len(callerOptionIDs) == 0 { - // If no values or error, return field with empty options - filteredField := pas.copyPropertyField(field) + filteredField := h.copyPropertyField(field) filteredField.Attrs[model.PropertyFieldAttributeOptions] = []any{} return filteredField } - // Get current options from field attrs if field.Attrs == nil { return field } @@ -953,13 +909,11 @@ func (pas *PropertyAccessService) filterSharedOnlyFieldOptions(field *model.Prop return field } - // Convert to slice of maps (generic option representation) optionsSlice, ok := optionsArr.([]any) if !ok { return field } - // Filter options filteredOptions := []any{} for _, opt := range optionsSlice { optMap, ok := opt.(map[string]any) @@ -975,8 +929,7 @@ func (pas *PropertyAccessService) filterSharedOnlyFieldOptions(field *model.Prop } } - // Create a new field with filtered options - filteredField := pas.copyPropertyField(field) + filteredField := h.copyPropertyField(field) filteredField.Attrs[model.PropertyFieldAttributeOptions] = filteredOptions return filteredField } @@ -990,25 +943,21 @@ func (pas *PropertyAccessService) filterSharedOnlyFieldOptions(field *model.Prop // The binary path is what protects scenarios like LDAP/SAML-synced text codenames whose // existence is itself controlled information: a caller who doesn't hold the same value // must not see the target's value through any read endpoint. -func (pas *PropertyAccessService) filterSharedOnlyValue(field *model.PropertyField, value *model.PropertyValue, callerID string) *model.PropertyValue { +func (h *AccessControlHook) filterSharedOnlyValue(field *model.PropertyField, value *model.PropertyValue, callerID string) *model.PropertyValue { if field.Type != model.PropertyFieldTypeSelect && field.Type != model.PropertyFieldTypeMultiselect { - return pas.filterSharedOnlyScalarValue(field, value, callerID) + return h.filterSharedOnlyScalarValue(field, value, callerID) } - // Get caller's option IDs for this field - callerOptionIDs, err := pas.getCallerOptionIDsForField(field.GroupID, field.ID, callerID, field.Type) + callerOptionIDs, err := h.getCallerOptionIDsForField(field.GroupID, field.ID, callerID, field.Type) if err != nil || len(callerOptionIDs) == 0 { - // No intersection possible return nil } - // Extract option IDs from target value - targetOptionIDs, err := pas.extractOptionIDsFromValue(field.Type, value.Value) + targetOptionIDs, err := h.extractOptionIDsFromValue(field.Type, value.Value) if err != nil || targetOptionIDs == nil || len(targetOptionIDs) == 0 { return nil } - // Find intersection intersection := []string{} for targetID := range targetOptionIDs { if _, exists := callerOptionIDs[targetID]; exists { @@ -1016,17 +965,14 @@ func (pas *PropertyAccessService) filterSharedOnlyValue(field *model.PropertyFie } } - // If no intersection, return nil if len(intersection) == 0 { return nil } - // Create filtered value based on field type filteredValue := *value switch field.Type { case model.PropertyFieldTypeSelect: - // For single-select, return the single matching value jsonValue, err := json.Marshal(intersection[0]) if err != nil { return nil @@ -1035,7 +981,6 @@ func (pas *PropertyAccessService) filterSharedOnlyValue(field *model.PropertyFie return &filteredValue case model.PropertyFieldTypeMultiselect: - // For multi-select, return the array of matching values jsonValue, err := json.Marshal(intersection) if err != nil { return nil @@ -1044,7 +989,6 @@ func (pas *PropertyAccessService) filterSharedOnlyValue(field *model.PropertyFie return &filteredValue default: - // Should never reach here due to check at function start return nil } } @@ -1053,12 +997,12 @@ func (pas *PropertyAccessService) filterSharedOnlyValue(field *model.PropertyFie // returns the value as-is if the caller's own stored value for the same field equals // the target's value, otherwise nil. Caller and target may legitimately store nothing, // in which case the value is hidden. -func (pas *PropertyAccessService) filterSharedOnlyScalarValue(field *model.PropertyField, value *model.PropertyValue, callerID string) *model.PropertyValue { +func (h *AccessControlHook) filterSharedOnlyScalarValue(field *model.PropertyField, value *model.PropertyValue, callerID string) *model.PropertyValue { if value == nil || len(value.Value) == 0 { return nil } - callerValues, err := pas.getCallerValuesForField(field.GroupID, field.ID, callerID) + callerValues, err := h.getCallerValuesForField(field.GroupID, field.ID, callerID) if err != nil || len(callerValues) == 0 { return nil } @@ -1079,23 +1023,19 @@ func (pas *PropertyAccessService) filterSharedOnlyScalarValue(field *model.Prope // - Source-only fields: returned with empty options if caller is not the source plugin // - Shared-only fields: returned with options filtered using filterSharedOnlyFieldOptions // - Unknown access modes: treated as source-only (secure default) -func (pas *PropertyAccessService) applyFieldReadAccessControl(field *model.PropertyField, callerID string) *model.PropertyField { - // Check if caller has unrestricted access (public field or source plugin for source_only) - if pas.hasUnrestrictedFieldReadAccess(field, callerID) { - // Unrestricted access - return as-is +func (h *AccessControlHook) applyFieldReadAccessControl(field *model.PropertyField, callerID string) *model.PropertyField { + if h.hasUnrestrictedFieldReadAccess(field, callerID) { return field } - // Access requires filtering - accessMode := field.GetAccessMode() + accessMode := h.getAccessMode(field) - // Shared-only fields: use existing helper to filter options if accessMode == model.PropertyAccessModeSharedOnly { - return pas.filterSharedOnlyFieldOptions(field, callerID) + return h.filterSharedOnlyFieldOptions(field, callerID) } // Source-only or unknown: return with empty options (secure default) - filteredField := pas.copyPropertyField(field) + filteredField := h.copyPropertyField(field) if field.Type == model.PropertyFieldTypeSelect || field.Type == model.PropertyFieldTypeMultiselect { filteredField.Attrs[model.PropertyFieldAttributeOptions] = []any{} } @@ -1103,29 +1043,25 @@ func (pas *PropertyAccessService) applyFieldReadAccessControl(field *model.Prope } // applyFieldReadAccessControlToList applies read access control to a list of fields. -// Returns a new list with each field's options filtered based on the caller's access permissions. -func (pas *PropertyAccessService) applyFieldReadAccessControlToList(fields []*model.PropertyField, callerID string) []*model.PropertyField { +func (h *AccessControlHook) applyFieldReadAccessControlToList(fields []*model.PropertyField, callerID string) []*model.PropertyField { if len(fields) == 0 { return fields } filtered := make([]*model.PropertyField, 0, len(fields)) for _, field := range fields { - filtered = append(filtered, pas.applyFieldReadAccessControl(field, callerID)) + filtered = append(filtered, h.applyFieldReadAccessControl(field, callerID)) } return filtered } // getFieldsForValues fetches all unique fields associated with the given values. -// Returns a map of fieldID -> PropertyField. -// Returns an error if any field cannot be fetched. -func (pas *PropertyAccessService) getFieldsForValues(values []*model.PropertyValue) (map[string]*model.PropertyField, error) { +func (h *AccessControlHook) getFieldsForValues(values []*model.PropertyValue) (map[string]*model.PropertyField, error) { if len(values) == 0 { return make(map[string]*model.PropertyField), nil } - // Get unique field IDs and group ID groupAndFieldIDs := make(map[string]map[string]struct{}) for _, value := range values { if groupAndFieldIDs[value.GroupID] == nil { @@ -1136,19 +1072,16 @@ func (pas *PropertyAccessService) getFieldsForValues(values []*model.PropertyVal fieldMap := make(map[string]*model.PropertyField) for groupID, fieldIDs := range groupAndFieldIDs { - // Convert field map to slice fieldIDSlice := make([]string, 0, len(fieldIDs)) for fieldID := range fieldIDs { fieldIDSlice = append(fieldIDSlice, fieldID) } - // Fetch all fields - fields, err := pas.propertyService.getPropertyFields(groupID, fieldIDSlice) + fields, err := h.propertyService.getPropertyFields(groupID, fieldIDSlice) if err != nil { return nil, fmt.Errorf("failed to fetch fields for values: %w", err) } - // Build map for easy lookup for _, field := range fields { fieldMap[field.ID] = field } @@ -1158,20 +1091,16 @@ func (pas *PropertyAccessService) getFieldsForValues(values []*model.PropertyVal } // applyValueReadAccessControl applies read access control to a list of values. -// Returns a new list containing only the values the caller can access, with shared_only values filtered. -// Values are silently filtered out if the caller doesn't have access. -func (pas *PropertyAccessService) applyValueReadAccessControl(values []*model.PropertyValue, callerID string) ([]*model.PropertyValue, error) { +func (h *AccessControlHook) applyValueReadAccessControl(values []*model.PropertyValue, callerID string) ([]*model.PropertyValue, error) { if len(values) == 0 { return values, nil } - // Fetch all associated fields - fieldMap, err := pas.getFieldsForValues(values) + fieldMap, err := h.getFieldsForValues(values) if err != nil { return nil, fmt.Errorf("applyValueReadAccessControl: %w", err) } - // Filter values based on field access filtered := make([]*model.PropertyValue, 0, len(values)) for _, value := range values { field, exists := fieldMap[value.FieldID] @@ -1179,19 +1108,15 @@ func (pas *PropertyAccessService) applyValueReadAccessControl(values []*model.Pr return nil, fmt.Errorf("applyValueReadAccessControl: field not found for value %s", value.ID) } - accessMode := field.GetAccessMode() + accessMode := h.getAccessMode(field) - // Check if caller can read this value - if pas.hasUnrestrictedFieldReadAccess(field, callerID) { - // Caller has unrestricted access (public or source plugin) - include as-is + if h.hasUnrestrictedFieldReadAccess(field, callerID) { filtered = append(filtered, value) } else if accessMode == model.PropertyAccessModeSharedOnly { - // Shared-only mode: apply filtering - filteredValue := pas.filterSharedOnlyValue(field, value, callerID) + filteredValue := h.filterSharedOnlyValue(field, value, callerID) if filteredValue != nil { filtered = append(filtered, filteredValue) } - // If filteredValue is nil, skip this value (no intersection) } // For source_only mode where caller is not the source, skip the value } diff --git a/server/channels/app/properties/access_control_attribute_validation.go b/server/channels/app/properties/access_control_attribute_validation.go new file mode 100644 index 00000000000..054715ee171 --- /dev/null +++ b/server/channels/app/properties/access_control_attribute_validation.go @@ -0,0 +1,514 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "unicode/utf8" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" +) + +var ( + ErrInvalidFieldAttrs = errors.New("invalid field attrs") + ErrInvalidValue = errors.New("invalid property value") + ErrAdminRequired = errors.New("admin privileges required") +) + +// PermissionChecker checks whether a user has a specific permission. +// This avoids a circular dependency between the properties and app packages. +type PermissionChecker func(userID string, permission *model.Permission) bool + +// AccessControlAttributeValidationHook validates and sanitizes property field attributes +// and values for managed property groups. It owns the full attr pipeline +// for these groups: +// +// - validates field Name against the CEL-safe identifier rules +// ([model.ValidateCPAFieldName]); on update this fires only when Name +// actually changes, so pre-existing fields with non-conforming names +// remain editable on all other attrs (lenient grandfather) +// - trims whitespace on string attrs +// - applies the visibility default when unset +// - clears attrs that don't apply to the field type (options on non-select, +// ldap/saml on non-text or admin-managed fields) +// - auto-assigns IDs to options that lack one and validates option shape +// - validates visibility, value_type, managed, display_name, and sort_order +// - validates property values for text fields against value_type +// constraints (email, url, phone) +// - enforces that managed="admin" can only be set by callers with +// PermissionManageSystem, and keeps PermissionValues in sync with the +// managed attribute +// +// The hook only applies to groups whose IDs are in managedGroupIDs. +type AccessControlAttributeValidationHook struct { + BasePropertyHook + propertyService *PropertyService + managedGroupIDs map[string]struct{} + permissionChecker PermissionChecker +} + +var _ PropertyHook = (*AccessControlAttributeValidationHook)(nil) + +// NewAccessControlAttributeValidationHook creates a hook that validates field attributes and +// values for the given property groups. +func NewAccessControlAttributeValidationHook(ps *PropertyService, permChecker PermissionChecker, managedGroupIDs ...string) *AccessControlAttributeValidationHook { + ids := make(map[string]struct{}, len(managedGroupIDs)) + for _, id := range managedGroupIDs { + ids[id] = struct{}{} + } + return &AccessControlAttributeValidationHook{ + propertyService: ps, + managedGroupIDs: ids, + permissionChecker: permChecker, + } +} + +func (h *AccessControlAttributeValidationHook) isGroupManaged(groupID string) bool { + _, ok := h.managedGroupIDs[groupID] + return ok +} + +// sanitizeAndValidateFieldAttrs trims string attrs, applies the visibility +// default, clears attrs that don't apply to the field type, validates each +// attr, and auto-IDs+validates options for select-shaped fields. Mutates +// field.Attrs in place. +func (h *AccessControlAttributeValidationHook) sanitizeAndValidateFieldAttrs(field *model.PropertyField) error { + if field.Attrs == nil { + field.Attrs = model.StringInterface{} + } + + for _, key := range trimmedFieldAttrKeys { + if v, ok := field.Attrs[key].(string); ok { + field.Attrs[key] = strings.TrimSpace(v) + } + } + + if v, _ := field.Attrs[model.PropertyFieldAttrVisibility].(string); v == "" { + field.Attrs[model.PropertyFieldAttrVisibility] = model.PropertyFieldVisibilityWhenSet + } + + // Type-based attr clearing: select-shaped fields keep options, only text + // supports external sync, and admin-managed fields can never be synced + // (mutual exclusivity). + isSelect := field.Type == model.PropertyFieldTypeSelect || field.Type == model.PropertyFieldTypeMultiselect + isText := field.Type == model.PropertyFieldTypeText + managed, _ := field.Attrs[model.PropertyFieldAttrManaged].(string) + + if !isSelect { + delete(field.Attrs, model.PropertyFieldAttributeOptions) + } + if !isText || managed == "admin" { + delete(field.Attrs, model.PropertyFieldAttrLDAP) + delete(field.Attrs, model.PropertyFieldAttrSAML) + } + + if err := model.ValidatePropertyFieldVisibility(field); err != nil { + return fmt.Errorf("%s: %w", err.Error(), ErrInvalidFieldAttrs) + } + if isText { + if vt, _ := field.Attrs[model.PropertyFieldAttrValueType].(string); vt != "" && !model.IsValidPropertyFieldValueType(vt) { + return fmt.Errorf("invalid value_type %q: %w", vt, ErrInvalidFieldAttrs) + } + } + if managed != "" && managed != "admin" { + return fmt.Errorf("invalid managed %q (must be empty or %q): %w", managed, "admin", ErrInvalidFieldAttrs) + } + if dn, _ := field.Attrs[model.PropertyFieldAttrDisplayName].(string); utf8.RuneCountInString(dn) > model.PropertyFieldNameMaxRunes { + return fmt.Errorf("display_name exceeds max length of %d runes: %w", model.PropertyFieldNameMaxRunes, ErrInvalidFieldAttrs) + } + if isSelect { + if err := h.sanitizeAndValidateOptions(field); err != nil { + return err + } + } + if err := model.ValidatePropertyFieldSortOrder(field); err != nil { + return fmt.Errorf("%s: %w", err.Error(), ErrInvalidFieldAttrs) + } + return nil +} + +// trimmedFieldAttrKeys lists the string-valued attrs the hook trims on the +// way in. Listed explicitly rather than iterating Attrs to avoid touching +// keys this hook doesn't own (e.g. plugin-set attrs). +var trimmedFieldAttrKeys = []string{ + model.PropertyFieldAttrVisibility, + model.PropertyFieldAttrValueType, + model.PropertyFieldAttrManaged, + model.PropertyFieldAttrLDAP, + model.PropertyFieldAttrSAML, + model.PropertyFieldAttrDisplayName, +} + +// sanitizeAndValidateOptions canonicalizes the options attr to the typed +// option slice, auto-assigns IDs to options without one, and validates the +// resulting shape. The JSON round-trip handles both the typed-slice form +// (when the request decoded into a wrapper struct) and the []map[string]any +// form (after a generic JSON decode or DB read). +func (h *AccessControlAttributeValidationHook) sanitizeAndValidateOptions(field *model.PropertyField) error { + rawOptions, ok := field.Attrs[model.PropertyFieldAttributeOptions] + if !ok || rawOptions == nil { + return nil + } + + data, err := json.Marshal(rawOptions) + if err != nil { + return fmt.Errorf("invalid options: %s: %w", err, ErrInvalidFieldAttrs) + } + var options model.PropertyOptions[*model.CustomProfileAttributesSelectOption] + if err := json.Unmarshal(data, &options); err != nil { + return fmt.Errorf("invalid options: %s: %w", err, ErrInvalidFieldAttrs) + } + + for i := range options { + if options[i].ID == "" { + options[i].ID = model.NewId() + } + } + if err := options.IsValid(); err != nil { + return fmt.Errorf("invalid options: %s: %w", err, ErrInvalidFieldAttrs) + } + + field.Attrs[model.PropertyFieldAttributeOptions] = options + return nil +} + +// enforceGroupPermissions pins schema-edit permissions for fields in +// managed groups and applies the managed=admin upgrade to PermissionValues: +// - PermissionField and PermissionOptions are always set to sysadmin so +// that only admins can modify field definitions and options. +// - When managed="admin", PermissionValues is set to sysadmin. This is +// gated on PermissionManageSystem; callers without an identifiable +// caller ID (e.g. internal callers with no session on rctx) are +// treated as non-admin and rejected. +// - Otherwise, PermissionValues is left as-is when set, and default-filled +// by ObjectType when nil (member for user fields, sysadmin for system +// and template). Caller pins are never downgraded. +func (h *AccessControlAttributeValidationHook) enforceGroupPermissions(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + sysadmin := model.PermissionLevelSysadmin + + if managed, _ := field.Attrs[model.PropertyFieldAttrManaged].(string); managed == "admin" { + // Verify the caller has admin privileges. Default-deny if the + // permission checker isn't wired up or if the caller is + // unidentifiable — we never silently promote to sysadmin. + if h.permissionChecker == nil { + return nil, fmt.Errorf("missing permission to set managed=admin: no permission checker configured: %w", ErrAdminRequired) + } + callerID := h.propertyService.extractCallerID(rctx) + if callerID == "" || !h.permissionChecker(callerID, model.PermissionManageSystem) { + return nil, fmt.Errorf("missing permission to set managed=admin: only system admins can set managed=admin: %w", ErrAdminRequired) + } + field.PermissionValues = &sysadmin + } else if field.PermissionValues == nil { + defaultLevel := defaultPermissionValuesForObjectType(field.ObjectType) + field.PermissionValues = &defaultLevel + } + + // Fields in managed groups always require sysadmin for field/options edits. + field.PermissionField = &sysadmin + field.PermissionOptions = &sysadmin + + return field, nil +} + +// defaultPermissionValuesForObjectType returns the PermissionValues level a +// field should default to when the caller doesn't pin one. User fields are +// member-writable so users can set their own values; system and template +// fields attach to admin-owned scopes and require sysadmin. +func defaultPermissionValuesForObjectType(objectType string) model.PermissionLevel { + switch objectType { + case model.PropertyFieldObjectTypeSystem, model.PropertyFieldObjectTypeTemplate: + return model.PermissionLevelSysadmin + default: + return model.PermissionLevelMember + } +} + +func (h *AccessControlAttributeValidationHook) PreCreatePropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + if !h.isGroupManaged(field.GroupID) { + return field, nil + } + + // Names in managed groups are referenced from ABAC policy expressions + // (user.attributes.), so they must satisfy the CEL grammar and + // avoid CEL reserved words. Returning the AppError directly preserves + // its specific i18n key through the HTTP layer's mapPropertyServiceError + // fallback (no sentinel wrap). + if appErr := model.ValidateCPAFieldName(field.Name); appErr != nil { + return nil, appErr + } + + if err := h.sanitizeAndValidateFieldAttrs(field); err != nil { + return nil, err + } + + return h.enforceGroupPermissions(rctx, field) +} + +func (h *AccessControlAttributeValidationHook) PreUpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) { + if !h.isGroupManaged(groupID) { + return field, nil + } + + // Lenient grandfather: only validate Name against CEL rules when it + // actually changes, so pre-existing fields whose names predate this + // validation remain editable on all other attrs. + existing, err := h.propertyService.getPropertyField(groupID, field.ID) + if err != nil { + return nil, err + } + if existing.Name != field.Name { + if appErr := model.ValidateCPAFieldName(field.Name); appErr != nil { + return nil, appErr + } + } + + if err := h.sanitizeAndValidateFieldAttrs(field); err != nil { + return nil, err + } + + return h.enforceGroupPermissions(rctx, field) +} + +func (h *AccessControlAttributeValidationHook) PreUpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + if len(fields) == 0 || !h.isGroupManaged(groupID) { + return fields, nil + } + + // Single batched read for the lenient-grandfather name check; a missing + // ID falls through to the store, which surfaces the not-found error. + fieldIDs := make([]string, len(fields)) + for i, f := range fields { + fieldIDs[i] = f.ID + } + existingFields, err := h.propertyService.getPropertyFields(groupID, fieldIDs) + if err != nil { + return nil, err + } + existingByID := make(map[string]*model.PropertyField, len(existingFields)) + for _, ex := range existingFields { + existingByID[ex.ID] = ex + } + + for i, field := range fields { + if existing, ok := existingByID[field.ID]; ok && existing.Name != field.Name { + if appErr := model.ValidateCPAFieldName(field.Name); appErr != nil { + return nil, fmt.Errorf("field %s: %w", field.ID, appErr) + } + } + + if err := h.sanitizeAndValidateFieldAttrs(field); err != nil { + return nil, fmt.Errorf("field %s: %w", field.ID, err) + } + + updated, err := h.enforceGroupPermissions(rctx, field) + if err != nil { + return nil, fmt.Errorf("field %s: %w", field.ID, err) + } + fields[i] = updated + } + + return fields, nil +} + +// extractOptionIDs extracts the set of valid option IDs from a +// select or multiselect PropertyField's attrs. Returns nil if the +// field has no options. +func extractOptionIDs(field *model.PropertyField) (map[string]struct{}, error) { + if field.Attrs == nil { + return nil, nil + } + + rawOptions, ok := field.Attrs[model.PropertyFieldAttributeOptions] + if !ok || rawOptions == nil { + return nil, nil + } + + data, err := json.Marshal(rawOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal options: %w", err) + } + + var options []struct { + ID string `json:"id"` + } + if err := json.Unmarshal(data, &options); err != nil { + return nil, fmt.Errorf("invalid options format: %w", err) + } + + ids := make(map[string]struct{}, len(options)) + for _, opt := range options { + if opt.ID != "" { + ids[opt.ID] = struct{}{} + } + } + return ids, nil +} + +// validateValueAgainstField checks a property value against field-type +// constraints: +// - text: max length, value_type format (email, url, phone) +// - select: option ID must exist in the field's options +// - multiselect: all option IDs must exist +// - user: value must be a valid Mattermost ID +// - multiuser: all values must be valid Mattermost IDs +func (h *AccessControlAttributeValidationHook) validateValueAgainstField(field *model.PropertyField, value *model.PropertyValue) error { + switch field.Type { + case model.PropertyFieldTypeText: + var str string + if err := json.Unmarshal(value.Value, &str); err != nil { + return fmt.Errorf("expected string value: %w", err) + } + if len(strings.TrimSpace(str)) > model.PropertyFieldValueTypeTextMaxLength { + return fmt.Errorf("text value exceeds maximum length of %d characters", model.PropertyFieldValueTypeTextMaxLength) + } + + valueType := model.GetPropertyFieldValueType(field) + if valueType == "" { + return nil + } + return model.ValidatePropertyValueForValueType(valueType, value.Value) + + case model.PropertyFieldTypeSelect: + var str string + if err := json.Unmarshal(value.Value, &str); err != nil { + return fmt.Errorf("expected string value for select field: %w", err) + } + if str == "" { + return nil + } + optionIDs, err := extractOptionIDs(field) + if err != nil { + return fmt.Errorf("failed to extract options: %w", err) + } + if _, ok := optionIDs[str]; !ok { + return fmt.Errorf("option %q does not exist", str) + } + + case model.PropertyFieldTypeMultiselect: + var values []string + if err := json.Unmarshal(value.Value, &values); err != nil { + return fmt.Errorf("expected string array value for multiselect field: %w", err) + } + optionIDs, err := extractOptionIDs(field) + if err != nil { + return fmt.Errorf("failed to extract options: %w", err) + } + for _, v := range values { + if _, ok := optionIDs[v]; !ok { + return fmt.Errorf("option %q does not exist", v) + } + } + + case model.PropertyFieldTypeUser: + var str string + if err := json.Unmarshal(value.Value, &str); err != nil { + return fmt.Errorf("expected string value for user field: %w", err) + } + if str != "" && !model.IsValidId(str) { + return fmt.Errorf("invalid user id") + } + + case model.PropertyFieldTypeMultiuser: + var values []string + if err := json.Unmarshal(value.Value, &values); err != nil { + return fmt.Errorf("expected string array value for multiuser field: %w", err) + } + for _, v := range values { + if !model.IsValidId(v) { + return fmt.Errorf("invalid user id: %s", v) + } + } + } + + return nil +} + +func (h *AccessControlAttributeValidationHook) validateValues(values []*model.PropertyValue) error { + if len(values) == 0 { + return nil + } + + groupID := values[0].GroupID + if !h.isGroupManaged(groupID) { + return nil + } + + // Collect unique field IDs + fieldIDSet := make(map[string]struct{}) + for _, v := range values { + fieldIDSet[v.FieldID] = struct{}{} + } + fieldIDs := make([]string, 0, len(fieldIDSet)) + for id := range fieldIDSet { + fieldIDs = append(fieldIDs, id) + } + + fields, err := h.propertyService.getPropertyFields(groupID, fieldIDs) + if err != nil { + return fmt.Errorf("failed to fetch fields for validation: %w", err) + } + + fieldMap := make(map[string]*model.PropertyField, len(fields)) + for _, f := range fields { + fieldMap[f.ID] = f + } + + for _, value := range values { + field, ok := fieldMap[value.FieldID] + if !ok { + return fmt.Errorf("field %s: %w", value.FieldID, ErrFieldNotFound) + } + if err := h.validateValueAgainstField(field, value); err != nil { + return fmt.Errorf("field %s: %s: %w", value.FieldID, err.Error(), ErrInvalidValue) + } + } + + return nil +} + +func (h *AccessControlAttributeValidationHook) PreUpsertPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if err := h.validateValues([]*model.PropertyValue{value}); err != nil { + return nil, err + } + return value, nil +} + +func (h *AccessControlAttributeValidationHook) PreUpsertPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if err := h.validateValues(values); err != nil { + return nil, err + } + return values, nil +} + +func (h *AccessControlAttributeValidationHook) PreCreatePropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if err := h.validateValues([]*model.PropertyValue{value}); err != nil { + return nil, err + } + return value, nil +} + +func (h *AccessControlAttributeValidationHook) PreCreatePropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if err := h.validateValues(values); err != nil { + return nil, err + } + return values, nil +} + +func (h *AccessControlAttributeValidationHook) PreUpdatePropertyValue(_ request.CTX, _ string, value *model.PropertyValue) (*model.PropertyValue, error) { + if err := h.validateValues([]*model.PropertyValue{value}); err != nil { + return nil, err + } + return value, nil +} + +func (h *AccessControlAttributeValidationHook) PreUpdatePropertyValues(_ request.CTX, _ string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if err := h.validateValues(values); err != nil { + return nil, err + } + return values, nil +} diff --git a/server/channels/app/properties/access_control_attribute_validation_test.go b/server/channels/app/properties/access_control_attribute_validation_test.go new file mode 100644 index 00000000000..01b80ab63b9 --- /dev/null +++ b/server/channels/app/properties/access_control_attribute_validation_test.go @@ -0,0 +1,1093 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAccessControlAttributeValidationHook(t *testing.T) { + th := Setup(t) + + group, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_attr_validation", Version: model.PropertyGroupVersionV2}) + require.NoError(t, err) + + hook := NewAccessControlAttributeValidationHook(th.service, nil, group.ID) + th.service.AddHook(hook) + + t.Run("allows valid visibility on create", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrVisibility: "always"}, + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.NotEmpty(t, created.ID) + }) + + t.Run("rejects invalid visibility on create", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrVisibility: "public"}, + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "visibility") + }) + + t.Run("rejects non-numeric sort_order on create", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrSortOrder: "not_a_number"}, + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "sort_order") + }) + + t.Run("allows numeric sort_order on create", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrSortOrder: float64(1.5)}, + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.NotEmpty(t, created.ID) + }) + + t.Run("rejects invalid visibility on update", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + field.Attrs = model.StringInterface{model.PropertyFieldAttrVisibility: "bad"} + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + assert.Contains(t, updateErr.Error(), "visibility") + }) + + t.Run("skips validation for unmanaged groups", func(t *testing.T) { + otherGroup, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_other_group", Version: model.PropertyGroupVersionV2}) + require.NoError(t, groupErr) + + field := &model.PropertyField{ + GroupID: otherGroup.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrVisibility: "invalid_but_ignored"}, + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.NotEmpty(t, created.ID) + }) + + t.Run("validates value_type on upsert — rejects invalid email", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "email_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttrValueType: "email", + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"not-an-email"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "email") + }) + + t.Run("validates value_type on upsert — accepts valid email", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "email_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttrValueType: "email", + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"test@example.com"`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("skips value_type validation for non-text fields", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "date_field_" + model.NewId(), + Type: model.PropertyFieldTypeDate, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"2024-01-01"`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("allows empty value even with value_type", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "email_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttrValueType: "email", + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`""`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + // Select field validation tests + + t.Run("select — accepts valid option ID", func(t *testing.T) { + optionID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "select_field_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID, "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"` + optionID + `"`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("select — rejects non-existent option ID", func(t *testing.T) { + optionID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "select_field_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID, "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"` + model.NewId() + `"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "does not exist") + }) + + t.Run("select — allows empty string value", func(t *testing.T) { + optionID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "select_field_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID, "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`""`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + // Multiselect field validation tests + + t.Run("multiselect — accepts valid option IDs", func(t *testing.T) { + optionID1 := model.NewId() + optionID2 := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiselect_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiselect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID1, "name": "Option 1"}, + map[string]any{"id": optionID2, "name": "Option 2"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`["` + optionID1 + `","` + optionID2 + `"]`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("multiselect — rejects if any option ID is invalid", func(t *testing.T) { + optionID1 := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiselect_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiselect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID1, "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`["` + optionID1 + `","` + model.NewId() + `"]`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "does not exist") + }) + + t.Run("multiselect — accepts empty array", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiselect_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiselect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": model.NewId(), "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`[]`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + // User field validation tests + + t.Run("user — accepts valid user ID", func(t *testing.T) { + userID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "user_field_" + model.NewId(), + Type: model.PropertyFieldTypeUser, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"` + userID + `"`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("user — rejects invalid user ID format", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "user_field_" + model.NewId(), + Type: model.PropertyFieldTypeUser, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"not-a-valid-id"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "invalid user id") + }) + + t.Run("user — allows empty string", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "user_field_" + model.NewId(), + Type: model.PropertyFieldTypeUser, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`""`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + // Multiuser field validation tests + + t.Run("multiuser — accepts valid user IDs", func(t *testing.T) { + userID1 := model.NewId() + userID2 := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiuser_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiuser, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`["` + userID1 + `","` + userID2 + `"]`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("multiuser — rejects if any user ID is invalid", func(t *testing.T) { + validID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiuser_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiuser, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`["` + validID + `","bad-id"]`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "invalid user id") + }) + + t.Run("multiuser — accepts empty array", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiuser_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiuser, + TargetType: "system", + ObjectType: "user", + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`[]`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + // Edge case: select with wrong JSON type + + t.Run("select — rejects non-string JSON value", func(t *testing.T) { + optionID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "select_field_" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID, "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`123`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "expected string value") + }) + + t.Run("multiselect — rejects non-array JSON value", func(t *testing.T) { + optionID := model.NewId() + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "multiselect_field_" + model.NewId(), + Type: model.PropertyFieldTypeMultiselect, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": optionID, "name": "Option 1"}, + }, + }, + }) + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"not-an-array"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "expected string array") + }) + + t.Run("upsert with unknown field id returns ErrFieldNotFound", func(t *testing.T) { + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: model.NewId(), + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"anything"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.ErrorIs(t, upsertErr, ErrFieldNotFound) + var resultsMismatchErr *store.ErrResultsMismatch + assert.ErrorAs(t, upsertErr, &resultsMismatchErr, "original store error should remain in chain") + }) + + // Group permission enforcement tests + // + // These tests run with the hook configured with a nil permissionChecker + // (see the Setup block at the top of this test function). In that + // configuration, managed="admin" is default-denied since there is no + // way to verify the caller's admin status. The "allowed" side of the + // authorization matrix is covered in TestAccessControlAttributeValidationHookManagedAuthorization. + + t.Run("create field with managed=admin is rejected when no permission checker is configured", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "managed=admin") + }) + + t.Run("create field without managed sets PermissionValues to member", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{}, + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + require.NotNil(t, created.PermissionValues) + assert.Equal(t, model.PermissionLevelMember, *created.PermissionValues) + require.NotNil(t, created.PermissionField) + assert.Equal(t, model.PermissionLevelSysadmin, *created.PermissionField) + require.NotNil(t, created.PermissionOptions) + assert.Equal(t, model.PermissionLevelSysadmin, *created.PermissionOptions) + }) + + t.Run("update field to managed=admin is rejected when no permission checker is configured", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{}, + }) + + field.Attrs = model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + } + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + assert.Contains(t, updateErr.Error(), "managed=admin") + }) + + t.Run("update field to remove managed sets PermissionValues to member", func(t *testing.T) { + member := model.PermissionLevelMember + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + PermissionValues: &member, + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + }) + + field.Attrs = model.StringInterface{} + updated, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.NoError(t, updateErr) + require.NotNil(t, updated.PermissionValues) + assert.Equal(t, model.PermissionLevelMember, *updated.PermissionValues) + }) + + t.Run("sanitization on create: defaults visibility to when_set", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.Equal(t, model.CustomProfileAttributesVisibilityWhenSet, created.Attrs[model.CustomProfileAttributesPropertyAttrsVisibility]) + }) + + t.Run("sanitization on create: trims display_name and rejects when too long", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsDisplayName: " Department Head ", + }, + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.Equal(t, "Department Head", created.Attrs[model.CustomProfileAttributesPropertyAttrsDisplayName]) + + // Build a 256-rune string — exceeds the 255-rune cap (PropertyFieldNameMaxRunes). + tooLong := strings.Repeat("a", model.PropertyFieldNameMaxRunes+1) + bad := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.CustomProfileAttributesPropertyAttrsDisplayName: tooLong}, + } + _, badErr := th.service.CreatePropertyField(th.Context, bad) + require.Error(t, badErr) + assert.Contains(t, badErr.Error(), "display_name") + }) + + t.Run("sanitization on update: rejects display_name longer than max", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + field.Attrs = model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsDisplayName: strings.Repeat("a", model.PropertyFieldNameMaxRunes+1), + } + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + assert.Contains(t, updateErr.Error(), "display_name") + }) + + t.Run("sanitization on update: rejects unknown value_type on text field", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + field.Attrs = model.StringInterface{model.PropertyFieldAttrValueType: "wat"} + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + assert.Contains(t, updateErr.Error(), "value_type") + }) + + t.Run("sanitization on update: rejects unknown managed value", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + field.Attrs = model.StringInterface{model.PropertyFieldAttrManaged: "kinda"} + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + assert.Contains(t, updateErr.Error(), "managed") + }) + + t.Run("name validation on create: rejects non-CEL identifier", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "Has Space", + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.Error(t, createErr) + var appErr *model.AppError + require.ErrorAs(t, createErr, &appErr) + assert.Equal(t, "model.cpa_field.name.invalid_charset.app_error", appErr.Id) + }) + + t.Run("name validation on create: rejects CEL reserved word", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "for", + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.Error(t, createErr) + var appErr *model.AppError + require.ErrorAs(t, createErr, &appErr) + assert.Equal(t, "model.cpa_field.name.reserved_word.app_error", appErr.Id) + }) + + t.Run("name validation on create: accepts CEL-safe identifier", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: group.ID, + Name: "department_head", + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.Equal(t, "department_head", created.Name) + }) + + t.Run("name validation on update: lenient grandfather lets non-conforming name through when unchanged", func(t *testing.T) { + // Direct store insert bypasses the hook so we can seed a name that + // would fail current validation, simulating a field that predates it. + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "legacy name", + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + // Patch a different attr without touching Name — should succeed. + field.Attrs = model.StringInterface{model.PropertyFieldAttrVisibility: "always"} + updated, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.NoError(t, updateErr) + assert.Equal(t, "legacy name", updated.Name) + }) + + t.Run("name validation on update: rejects rename to non-CEL identifier", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "good_name_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + field.Name = "Bad Name" + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + var appErr *model.AppError + require.ErrorAs(t, updateErr, &appErr) + assert.Equal(t, "model.cpa_field.name.invalid_charset.app_error", appErr.Id) + }) + + t.Run("name validation on update: rejects rename to CEL reserved word", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "good_name_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + field.Name = "in" + _, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.Error(t, updateErr) + var appErr *model.AppError + require.ErrorAs(t, updateErr, &appErr) + assert.Equal(t, "model.cpa_field.name.reserved_word.app_error", appErr.Id) + }) + + t.Run("name validation on update: accepts rename to CEL-safe identifier", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "old_name_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + newName := "new_name_" + model.NewId() + field.Name = newName + updated, _, updateErr := th.service.UpdatePropertyField(th.Context, group.ID, field) + require.NoError(t, updateErr) + assert.Equal(t, newName, updated.Name) + }) + + t.Run("name validation on batch update: lenient grandfather applies per-field", func(t *testing.T) { + grandfathered := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "still legacy", + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + renamable := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "rename_src_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + // Touch grandfathered without renaming; rename renamable to a CEL-safe + // name. Both should be accepted. + grandfathered.Attrs = model.StringInterface{model.PropertyFieldAttrVisibility: "hidden"} + newName := "rename_dst_" + model.NewId() + renamable.Name = newName + _, _, _, updateErr := th.service.UpdatePropertyFields(th.Context, group.ID, []*model.PropertyField{grandfathered, renamable}) + require.NoError(t, updateErr) + }) + + t.Run("name validation on batch update: one bad rename rejects the batch", func(t *testing.T) { + ok := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "ok_src_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + bad := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "bad_src_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + ok.Name = "ok_dst_" + model.NewId() + bad.Name = "for" // CEL reserved word + _, _, _, updateErr := th.service.UpdatePropertyFields(th.Context, group.ID, []*model.PropertyField{ok, bad}) + require.Error(t, updateErr) + var appErr *model.AppError + require.ErrorAs(t, updateErr, &appErr) + assert.Equal(t, "model.cpa_field.name.reserved_word.app_error", appErr.Id) + }) + + t.Run("text — rejects value exceeding max length", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "text_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + // Create a string longer than PropertyFieldValueTypeTextMaxLength (64) + longValue := make([]byte, 0, 70) + for range 70 { + longValue = append(longValue, 'a') + } + + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"` + string(longValue) + `"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "maximum length") + }) +} + +func TestAccessControlAttributeValidationHookManagedAuthorization(t *testing.T) { + th := Setup(t) + + group, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_managed_auth", Version: model.PropertyGroupVersionV2}) + require.NoError(t, err) + + adminUserID := model.NewId() + regularUserID := model.NewId() + + permChecker := func(userID string, perm *model.Permission) bool { + return userID == adminUserID && perm.Id == model.PermissionManageSystem.Id + } + + hook := NewAccessControlAttributeValidationHook(th.service, permChecker, group.ID) + th.service.AddHook(hook) + + t.Run("admin can create field with managed=admin", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, adminUserID) + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + } + created, createErr := th.service.CreatePropertyField(rctx, field) + require.NoError(t, createErr) + require.NotNil(t, created.PermissionValues) + assert.Equal(t, model.PermissionLevelSysadmin, *created.PermissionValues) + }) + + t.Run("non-admin is blocked from creating field with managed=admin", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, regularUserID) + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + } + _, createErr := th.service.CreatePropertyField(rctx, field) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "permission") + }) + + t.Run("non-admin can create field without managed attr", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, regularUserID) + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{}, + } + created, createErr := th.service.CreatePropertyField(rctx, field) + require.NoError(t, createErr) + require.NotNil(t, created.PermissionValues) + assert.Equal(t, model.PermissionLevelMember, *created.PermissionValues) + }) + + t.Run("non-admin is blocked from updating field to managed=admin", func(t *testing.T) { + // Create field as admin + adminRctx := RequestContextWithCallerID(th.Context, adminUserID) + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{}, + } + created, createErr := th.service.CreatePropertyField(adminRctx, field) + require.NoError(t, createErr) + + // Try to update as non-admin + rctx := RequestContextWithCallerID(th.Context, regularUserID) + created.Attrs = model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + } + _, _, updateErr := th.service.UpdatePropertyField(rctx, group.ID, created) + require.Error(t, updateErr) + assert.Contains(t, updateErr.Error(), "permission") + }) + + t.Run("admin can update field to managed=admin", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, adminUserID) + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{}, + } + created, createErr := th.service.CreatePropertyField(rctx, field) + require.NoError(t, createErr) + + created.Attrs = model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + } + updated, _, updateErr := th.service.UpdatePropertyField(rctx, group.ID, created) + require.NoError(t, updateErr) + require.NotNil(t, updated.PermissionValues) + assert.Equal(t, model.PermissionLevelSysadmin, *updated.PermissionValues) + }) + + t.Run("managed check skipped for unmanaged groups", func(t *testing.T) { + otherGroup, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_other_managed", Version: model.PropertyGroupVersionV2}) + require.NoError(t, groupErr) + + rctx := RequestContextWithCallerID(th.Context, regularUserID) + field := &model.PropertyField{ + GroupID: otherGroup.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + } + // Should succeed because the hook doesn't apply to this group + created, createErr := th.service.CreatePropertyField(rctx, field) + require.NoError(t, createErr) + // PermissionValues should NOT be set by the hook for unmanaged groups + assert.Nil(t, created.PermissionValues) + }) + + t.Run("empty caller ID is rejected (default-deny for unidentified callers)", func(t *testing.T) { + // th.Context has no caller ID set. The hook must treat this as + // non-admin and block managed=admin rather than silently + // promoting to sysadmin. + field := &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{ + model.CustomProfileAttributesPropertyAttrsManaged: "admin", + }, + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "managed=admin") + }) +} diff --git a/server/channels/app/properties/access_control_field_test.go b/server/channels/app/properties/access_control_field_test.go index 6ceec86ce78..fbb4e597605 100644 --- a/server/channels/app/properties/access_control_field_test.go +++ b/server/channels/app/properties/access_control_field_test.go @@ -35,7 +35,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, model.PropertyFieldAttributeOptions: []any{ @@ -77,7 +78,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -102,7 +104,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field-2", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -127,7 +130,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field-3", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -152,7 +156,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field-4", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -177,7 +182,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-field", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -226,7 +232,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-field-2", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -251,7 +258,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-field-source", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsSourcePluginID: pluginID1, @@ -281,14 +289,15 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { }) t.Run("non-CPA group routes directly to PropertyService without filtering", func(t *testing.T) { - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_routing_read", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_routing_read", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) field := &model.PropertyField{ GroupID: nonCpaGroup.ID, Name: "routing-test-non-cpa-source-only", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -315,7 +324,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "no-attrs-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: nil, } created, err := th.service.CreatePropertyField(rctxAnon, field) @@ -332,7 +342,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "empty-access-mode-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{}, } created, err := th.service.CreatePropertyField(rctxAnon, field) @@ -349,7 +360,8 @@ func TestGetPropertyFieldReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "invalid-access-mode-field", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: "invalid-mode", model.PropertyFieldAttributeOptions: []any{ @@ -382,7 +394,8 @@ func TestGetPropertyFieldsReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -394,7 +407,8 @@ func TestGetPropertyFieldsReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -410,7 +424,8 @@ func TestGetPropertyFieldsReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-field", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -486,7 +501,8 @@ func TestSearchPropertyFieldsReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-search-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -498,7 +514,8 @@ func TestSearchPropertyFieldsReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-search-field", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -514,7 +531,8 @@ func TestSearchPropertyFieldsReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-search-field", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -583,7 +601,6 @@ func TestGetPropertyFieldByNameReadAccess(t *testing.T) { pluginID := "plugin-1" userID := model.NewId() - targetID := model.NewId() rctxPlugin := RequestContextWithCallerID(th.Context, pluginID) rctxUser := RequestContextWithCallerID(th.Context, userID) @@ -593,8 +610,8 @@ func TestGetPropertyFieldByNameReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "byname-source-only", Type: model.PropertyFieldTypeSelect, - TargetType: "user", - TargetID: targetID, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -607,12 +624,12 @@ func TestGetPropertyFieldByNameReadAccess(t *testing.T) { require.NoError(t, err) // Source plugin can see options - retrieved, err := th.service.GetPropertyFieldByName(rctxPlugin, th.CPAGroupID, targetID, created.Name) + retrieved, err := th.service.GetPropertyFieldByName(rctxPlugin, th.CPAGroupID, "", created.Name) require.NoError(t, err) assert.Len(t, retrieved.Attrs[model.PropertyFieldAttributeOptions].([]any), 1) // User sees empty options - retrieved, err = th.service.GetPropertyFieldByName(rctxUser, th.CPAGroupID, targetID, created.Name) + retrieved, err = th.service.GetPropertyFieldByName(rctxUser, th.CPAGroupID, "", created.Name) require.NoError(t, err) assert.Empty(t, retrieved.Attrs[model.PropertyFieldAttributeOptions].([]any)) }) @@ -634,9 +651,11 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { t.Run("non-plugin caller can create field without source_plugin_id", func(t *testing.T) { field := &model.PropertyField{ - GroupID: th.CPAGroupID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, + GroupID: th.CPAGroupID, + Name: model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxUser1, field) @@ -654,6 +673,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsSourcePluginID: "plugin-1", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(RequestContextWithCallerID(th.Context, "user-id-123"), field) @@ -670,6 +691,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxUser1, field) @@ -686,6 +709,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsSourcePluginID: "plugin-1", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxAnon, field) @@ -702,6 +727,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxAnon, field) @@ -718,6 +745,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsSourcePluginID: "", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxAnon, field) @@ -729,9 +758,11 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { t.Run("plugin caller auto-sets source_plugin_id", func(t *testing.T) { field := &model.PropertyField{ - GroupID: th.CPAGroupID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, + GroupID: th.CPAGroupID, + Name: model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) @@ -748,6 +779,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsSourcePluginID: "malicious-plugin", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) @@ -761,7 +794,8 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: model.NewId(), Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, @@ -780,13 +814,15 @@ func TestCreatePropertyField_AccessControl(t *testing.T) { }) t.Run("non-CPA group routes directly to PropertyService without setting source_plugin_id", func(t *testing.T) { - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_create", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_create", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) field := &model.PropertyField{ - GroupID: nonCpaGroup.ID, - Name: model.NewId(), - Type: model.PropertyFieldTypeText, + GroupID: nonCpaGroup.ID, + Name: model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } rctx := RequestContextWithCallerID(th.Context, "plugin-2") @@ -811,16 +847,18 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { t.Run("allows update of unprotected field", func(t *testing.T) { field := &model.PropertyField{ - GroupID: th.CPAGroupID, - Name: "Original Name", - Type: model.PropertyFieldTypeText, + GroupID: th.CPAGroupID, + Name: "Original Name", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) created.Name = "Updated Name" - updated, err := th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, created) require.NoError(t, err) assert.Equal(t, "Updated Name", updated.Name) }) @@ -833,13 +871,15 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) created.Name = "Updated Protected Field" - updated, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) require.NoError(t, err) assert.Equal(t, "Updated Protected Field", updated.Name) }) @@ -852,13 +892,15 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) created.Name = "Attempted Update" - updated, err := th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, created) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "protected") @@ -873,13 +915,15 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) created.Name = "Attempted Update" - updated, err := th.service.UpdatePropertyField(rctxAnon, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxAnon, th.CPAGroupID, created) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "protected") @@ -887,10 +931,12 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { t.Run("prevents changing source_plugin_id", func(t *testing.T) { field := &model.PropertyField{ - GroupID: th.CPAGroupID, - Name: "Field", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{}, + GroupID: th.CPAGroupID, + Name: "Field", + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{}, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) @@ -898,7 +944,7 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { // Try to change source_plugin_id created.Attrs[model.PropertyAttrsSourcePluginID] = "plugin-2" - updated, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "immutable") @@ -906,10 +952,12 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { t.Run("prevents setting protected=true without source_plugin_id", func(t *testing.T) { field := &model.PropertyField{ - GroupID: th.CPAGroupID, - Name: "Field Without Source Plugin", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{}, + GroupID: th.CPAGroupID, + Name: "Field Without Source Plugin", + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{}, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxAnon, field) @@ -917,7 +965,7 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { // Try to set protected=true without having a source_plugin_id created.Attrs[model.PropertyAttrsProtected] = true - updated, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "cannot set protected=true") @@ -926,10 +974,12 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { t.Run("prevents non-source plugin from setting protected=true", func(t *testing.T) { field := &model.PropertyField{ - GroupID: th.CPAGroupID, - Name: "Field With Source Plugin", - Type: model.PropertyFieldTypeText, - Attrs: model.StringInterface{}, + GroupID: th.CPAGroupID, + Name: "Field With Source Plugin", + Type: model.PropertyFieldTypeText, + Attrs: model.StringInterface{}, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } // Create field via plugin-1 (sets source_plugin_id automatically) @@ -939,20 +989,20 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { // Try to set protected=true by a different plugin (plugin-2) created.Attrs[model.PropertyAttrsProtected] = true - updated, err := th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, created) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "cannot set protected=true") assert.Contains(t, err.Error(), "plugin-1") // Verify the source plugin (plugin-1) CAN set protected=true - updated, err = th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) + updated, _, err = th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) require.NoError(t, err) assert.True(t, model.IsPropertyFieldProtected(updated)) }) t.Run("non-CPA group routes directly to PropertyService without access control", func(t *testing.T) { - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_update", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_update", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) field := &model.PropertyField{ @@ -963,6 +1013,8 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { model.PropertyAttrsProtected: true, model.PropertyAttrsSourcePluginID: "plugin-1", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) @@ -970,7 +1022,7 @@ func TestUpdatePropertyField_WriteAccessControl(t *testing.T) { // Update with different plugin - should be allowed (no access control) created.Name = "Updated by Plugin2" - updated, err := th.service.UpdatePropertyField(rctxPlugin2, nonCpaGroup.ID, created) + updated, _, err := th.service.UpdatePropertyField(rctxPlugin2, nonCpaGroup.ID, created) require.NoError(t, err) assert.NotNil(t, updated) assert.Equal(t, "Updated by Plugin2", updated.Name) @@ -989,8 +1041,8 @@ func TestUpdatePropertyFields_BulkWriteAccessControl(t *testing.T) { rctxPlugin2 := RequestContextWithCallerID(th.Context, "plugin-2") t.Run("allows bulk update of unprotected fields", func(t *testing.T) { - field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field1", Type: model.PropertyFieldTypeText} - field2 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field2", Type: model.PropertyFieldTypeText} + field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field1", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} + field2 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field2", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created1, err := th.service.CreatePropertyField(rctxPlugin1, field1) require.NoError(t, err) @@ -1000,14 +1052,14 @@ func TestUpdatePropertyFields_BulkWriteAccessControl(t *testing.T) { created1.Name = "Updated Field1" created2.Name = "Updated Field2" - updated, _, err := th.service.UpdatePropertyFields(rctxPlugin2, th.CPAGroupID, []*model.PropertyField{created1, created2}) + updated, _, _, err := th.service.UpdatePropertyFields(rctxPlugin2, th.CPAGroupID, []*model.PropertyField{created1, created2}) require.NoError(t, err) assert.Len(t, updated, 2) }) t.Run("fails atomically when one protected field in batch", func(t *testing.T) { // Create unprotected field - field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Unprotected", Type: model.PropertyFieldTypeText} + field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Unprotected", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created1, err := th.service.CreatePropertyField(rctxPlugin1, field1) require.NoError(t, err) @@ -1019,6 +1071,8 @@ func TestUpdatePropertyFields_BulkWriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created2, err := th.service.CreatePropertyField(rctxPlugin1, field2) require.NoError(t, err) @@ -1027,7 +1081,7 @@ func TestUpdatePropertyFields_BulkWriteAccessControl(t *testing.T) { created1.Name = "Updated Unprotected" created2.Name = "Updated Protected" - updated, _, err := th.service.UpdatePropertyFields(rctxPlugin2, th.CPAGroupID, []*model.PropertyField{created1, created2}) + updated, _, _, err := th.service.UpdatePropertyFields(rctxPlugin2, th.CPAGroupID, []*model.PropertyField{created1, created2}) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "protected") @@ -1046,8 +1100,8 @@ func TestUpdatePropertyFields_BulkWriteAccessControl(t *testing.T) { rctxAnon := RequestContextWithCallerID(th.Context, "") // Create two unprotected fields without source_plugin_id - field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field1", Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{}} - field2 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field2", Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{}} + field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field1", Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{}, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} + field2 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field2", Type: model.PropertyFieldTypeText, Attrs: model.StringInterface{}, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created1, err := th.service.CreatePropertyField(rctxAnon, field1) require.NoError(t, err) @@ -1058,7 +1112,7 @@ func TestUpdatePropertyFields_BulkWriteAccessControl(t *testing.T) { created1.Name = "Updated Field1" created2.Attrs[model.PropertyAttrsProtected] = true - updated, _, err := th.service.UpdatePropertyFields(rctxPlugin1, th.CPAGroupID, []*model.PropertyField{created1, created2}) + updated, _, _, err := th.service.UpdatePropertyFields(rctxPlugin1, th.CPAGroupID, []*model.PropertyField{created1, created2}) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "cannot set protected=true") @@ -1084,7 +1138,7 @@ func TestDeletePropertyField_WriteAccessControl(t *testing.T) { rctxPlugin2 := RequestContextWithCallerID(th.Context, "plugin-2") t.Run("allows deletion of unprotected field", func(t *testing.T) { - field := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Unprotected", Type: model.PropertyFieldTypeText} + field := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Unprotected", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -1100,6 +1154,8 @@ func TestDeletePropertyField_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -1120,6 +1176,8 @@ func TestDeletePropertyField_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -1130,7 +1188,7 @@ func TestDeletePropertyField_WriteAccessControl(t *testing.T) { }) t.Run("non-CPA group routes directly to PropertyService without access control", func(t *testing.T) { - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_delete", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_delete", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) field := &model.PropertyField{ @@ -1141,6 +1199,8 @@ func TestDeletePropertyField_WriteAccessControl(t *testing.T) { model.PropertyAttrsProtected: true, model.PropertyAttrsSourcePluginID: "plugin-1", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) @@ -1169,6 +1229,8 @@ func TestDeletePropertyField_OrphanedFieldDeletion(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(RequestContextWithCallerID(th.Context, "removed-plugin"), field) require.NoError(t, err) @@ -1194,6 +1256,8 @@ func TestDeletePropertyField_OrphanedFieldDeletion(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(RequestContextWithCallerID(th.Context, "installed-plugin"), field) require.NoError(t, err) @@ -1220,6 +1284,8 @@ func TestDeletePropertyField_OrphanedFieldDeletion(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(RequestContextWithCallerID(th.Context, "removed-plugin"), field) require.NoError(t, err) @@ -1230,7 +1296,7 @@ func TestDeletePropertyField_OrphanedFieldDeletion(t *testing.T) { }) created.Name = "Updated Orphaned Field" - updated, err := th.service.UpdatePropertyField(RequestContextWithCallerID(th.Context, "admin-user"), th.CPAGroupID, created) + updated, _, err := th.service.UpdatePropertyField(RequestContextWithCallerID(th.Context, "admin-user"), th.CPAGroupID, created) require.Error(t, err) assert.Nil(t, updated) assert.Contains(t, err.Error(), "protected") diff --git a/server/channels/app/properties/access_control_value_test.go b/server/channels/app/properties/access_control_value_test.go index 8b40e179b03..aed0746c55f 100644 --- a/server/channels/app/properties/access_control_value_test.go +++ b/server/channels/app/properties/access_control_value_test.go @@ -24,7 +24,7 @@ func TestCreatePropertyValue_WriteAccessControl(t *testing.T) { rctxPlugin2 := RequestContextWithCallerID(th.Context, "plugin-2") t.Run("allows creating value for public field", func(t *testing.T) { - field := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Public", Type: model.PropertyFieldTypeText} + field := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Public", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -50,6 +50,8 @@ func TestCreatePropertyValue_WriteAccessControl(t *testing.T) { model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -75,6 +77,8 @@ func TestCreatePropertyValue_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -94,7 +98,7 @@ func TestCreatePropertyValue_WriteAccessControl(t *testing.T) { }) t.Run("non-CPA group routes directly to PropertyService without access control", func(t *testing.T) { - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_value_create", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_value_create", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) field := &model.PropertyField{ @@ -105,6 +109,8 @@ func TestCreatePropertyValue_WriteAccessControl(t *testing.T) { model.PropertyAttrsProtected: true, model.PropertyAttrsSourcePluginID: "plugin-1", }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) @@ -137,7 +143,7 @@ func TestDeletePropertyValue_WriteAccessControl(t *testing.T) { rctxPlugin2 := RequestContextWithCallerID(th.Context, "plugin-2") t.Run("allows deleting value for public field", func(t *testing.T) { - field := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Public", Type: model.PropertyFieldTypeText} + field := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Public", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -163,6 +169,8 @@ func TestDeletePropertyValue_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxPlugin1, field) require.NoError(t, err) @@ -195,8 +203,8 @@ func TestDeletePropertyValuesForTarget_WriteAccessControl(t *testing.T) { rctxPlugin2 := RequestContextWithCallerID(th.Context, "plugin-2") t.Run("allows deleting all values when caller has write access to all fields", func(t *testing.T) { - field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field1", Type: model.PropertyFieldTypeText} - field2 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field2", Type: model.PropertyFieldTypeText} + field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field1", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} + field2 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Field2", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created1, err := th.service.CreatePropertyField(rctxPlugin1, field1) require.NoError(t, err) @@ -218,7 +226,7 @@ func TestDeletePropertyValuesForTarget_WriteAccessControl(t *testing.T) { t.Run("fails atomically when caller lacks access to one field", func(t *testing.T) { // Create public field - field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Public", Type: model.PropertyFieldTypeText} + field1 := &model.PropertyField{GroupID: th.CPAGroupID, Name: "Public", Type: model.PropertyFieldTypeText, ObjectType: model.PropertyFieldObjectTypeUser, TargetType: string(model.PropertyFieldTargetLevelSystem)} created1, err := th.service.CreatePropertyField(rctxPlugin1, field1) require.NoError(t, err) @@ -230,6 +238,8 @@ func TestDeletePropertyValuesForTarget_WriteAccessControl(t *testing.T) { Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created2, err := th.service.CreatePropertyField(rctxPlugin1, field2) require.NoError(t, err) @@ -281,7 +291,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -334,7 +345,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -370,7 +382,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -414,7 +427,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-single-select", Type: model.PropertyFieldTypeSelect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -487,7 +501,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-multi-select", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -567,7 +582,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-text", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -637,7 +653,8 @@ func TestGetPropertyValueReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-only-no-values", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -669,14 +686,15 @@ func TestGetPropertyValueReadAccess(t *testing.T) { }) t.Run("non-CPA group routes directly to PropertyService without filtering", func(t *testing.T) { - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_value_read", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_value_read", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) field := &model.PropertyField{ GroupID: nonCpaGroup.ID, Name: "non-cpa-value-source-only", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -733,7 +751,8 @@ func TestGetPropertyValuesReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-bulk", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -746,7 +765,8 @@ func TestGetPropertyValuesReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field-bulk", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -822,7 +842,8 @@ func TestSearchPropertyValuesReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-search", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -835,7 +856,8 @@ func TestSearchPropertyValuesReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "source-only-field-search", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -891,7 +913,8 @@ func TestSearchPropertyValuesReadAccess(t *testing.T) { GroupID: th.CPAGroupID, Name: "shared-field-search", Type: model.PropertyFieldTypeMultiselect, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSharedOnly, model.PropertyAttrsProtected: true, @@ -965,7 +988,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -977,7 +1001,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -1018,7 +1043,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "protected-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -1031,7 +1057,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "protected-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -1073,7 +1100,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-batch", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -1085,7 +1113,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "protected-field-batch", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -1134,7 +1163,7 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { t.Run("rejects values across multiple groups", func(t *testing.T) { // Register a second group - group2, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_group_create_values_2", Version: model.PropertyGroupVersionV1}) + group2, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_group_create_values_2", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) // Create fields in both groups @@ -1142,7 +1171,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "field-group1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -1154,7 +1184,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: group2.ID, Name: "field-group2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -1194,7 +1225,7 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { t.Run("rejects mixed groups before checking access control", func(t *testing.T) { // Register a third group - group3, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_group_create_values_3", Version: model.PropertyGroupVersionV1}) + group3, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_group_create_values_3", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) // Create public field in CPA group @@ -1202,7 +1233,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-multigroup", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModePublic, }, @@ -1215,7 +1247,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: group3.ID, Name: "protected-field-multigroup", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, model.PropertyAttrsProtected: true, @@ -1256,7 +1289,7 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { t.Run("non-CPA group routes directly to PropertyService without access control", func(t *testing.T) { // Register a non-CPA group - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_bulk", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_bulk", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) // Create two fields in non-CPA group @@ -1264,13 +1297,15 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: nonCpaGroup.ID, Name: "non-cpa-bulk-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } field2 := &model.PropertyField{ GroupID: nonCpaGroup.ID, Name: "non-cpa-bulk-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created1, err := th.service.CreatePropertyField(rctx1, field1) @@ -1304,7 +1339,7 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { t.Run("mixed CPA and non-CPA groups are rejected before access control", func(t *testing.T) { // Register a non-CPA group - nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_mixed", Version: model.PropertyGroupVersionV1}) + nonCpaGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "other_group_mixed", Version: model.PropertyGroupVersionV2}) require.NoError(t, err) // Create protected field in CPA group via plugin API @@ -1312,7 +1347,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "cpa-protected-mixed", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1325,7 +1361,8 @@ func TestCreatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: nonCpaGroup.ID, Name: "non-cpa-field-mixed", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } nonCpaField, err = th.service.CreatePropertyField(rctx1, nonCpaField) require.NoError(t, err) @@ -1370,7 +1407,8 @@ func TestUpdatePropertyValue_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "protected-field-for-update", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1403,7 +1441,8 @@ func TestUpdatePropertyValue_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "protected-field-for-update-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1436,7 +1475,8 @@ func TestUpdatePropertyValue_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-for-update", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxAnon, field) require.NoError(t, err) @@ -1480,7 +1520,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-update-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1489,7 +1530,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-update-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1533,7 +1575,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-update-fail-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1542,7 +1585,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-update-fail-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1593,7 +1637,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "mixed-update-protected-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1602,7 +1647,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "mixed-update-public-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } createdProtected, err := th.service.CreatePropertyField(rctx1, protectedField) @@ -1660,7 +1706,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "multi-owner-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1669,7 +1716,8 @@ func TestUpdatePropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "multi-owner-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1756,7 +1804,8 @@ func TestUpsertPropertyValue_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "upsert-protected-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1790,7 +1839,8 @@ func TestUpsertPropertyValue_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "upsert-protected-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1832,7 +1882,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-upsert-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1841,7 +1892,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-upsert-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1880,7 +1932,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-upsert-fail-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1889,7 +1942,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "bulk-upsert-fail-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1937,7 +1991,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "mixed-protected-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -1946,7 +2001,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "mixed-public-field", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } createdProtected, err := th.service.CreatePropertyField(rctx1, protectedField) @@ -1999,7 +2055,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "upsert-multi-owner-field-1", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -2008,7 +2065,8 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "upsert-multi-owner-field-2", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -2091,6 +2149,146 @@ func TestUpsertPropertyValues_WriteAccessControl(t *testing.T) { }) } +func TestUpsertPropertyValue_SyncLock(t *testing.T) { + th := Setup(t) + + group, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_sync_lock", Version: model.PropertyGroupVersionV1}) + require.NoError(t, err) + + hook := NewAccessControlHook(th.service, nil, group.ID) + th.service.AddHook(hook) + + ldapField := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "ldap_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrLDAP: "cn"}, + }) + + samlField := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "saml_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + Attrs: model.StringInterface{model.PropertyFieldAttrSAML: "displayName"}, + }) + + nonSyncedField := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: group.ID, + Name: "normal_field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + }) + + targetID := model.NewId() + + t.Run("blocks upsert on LDAP-synced field without caller ID", func(t *testing.T) { + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: ldapField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"test"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "ldap sync") + }) + + t.Run("allows LDAP sync service to upsert LDAP-synced field", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, model.CallerIDLDAPSync) + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: ldapField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"John Doe"`), + } + result, upsertErr := th.service.UpsertPropertyValue(rctx, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("blocks SAML sync service from writing LDAP-synced field", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, model.CallerIDSAMLSync) + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: ldapField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"wrong caller"`), + } + _, upsertErr := th.service.UpsertPropertyValue(rctx, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "ldap sync") + }) + + t.Run("allows SAML sync service to upsert SAML-synced field", func(t *testing.T) { + rctx := RequestContextWithCallerID(th.Context, model.CallerIDSAMLSync) + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: samlField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"Jane Doe"`), + } + result, upsertErr := th.service.UpsertPropertyValue(rctx, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("blocks regular user from writing SAML-synced field", func(t *testing.T) { + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: samlField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"sneaky"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "saml sync") + }) + + t.Run("allows regular user to upsert non-synced field", func(t *testing.T) { + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: nonSyncedField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"hello"`), + } + result, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.NoError(t, upsertErr) + assert.NotEmpty(t, result.ID) + }) + + t.Run("sync lock applies to batch upsert", func(t *testing.T) { + values := []*model.PropertyValue{ + { + GroupID: group.ID, + FieldID: ldapField.ID, + TargetID: targetID, + TargetType: "user", + Value: json.RawMessage(`"batch test"`), + }, + } + _, upsertErr := th.service.UpsertPropertyValues(th.Context, values) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "ldap sync") + + // Same batch with the right caller should succeed + rctx := RequestContextWithCallerID(th.Context, model.CallerIDLDAPSync) + results, upsertErr := th.service.UpsertPropertyValues(rctx, values) + require.NoError(t, upsertErr) + assert.Len(t, results, 1) + }) +} + func TestDeletePropertyValuesForField_WriteAccessControl(t *testing.T) { th := Setup(t).RegisterCPAPropertyGroup(t) th.service.setPluginCheckerForTests(func(pluginID string) bool { return pluginID == "plugin-1" }) @@ -2105,7 +2303,8 @@ func TestDeletePropertyValuesForField_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "field-delete-values", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -2154,7 +2353,8 @@ func TestDeletePropertyValuesForField_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "field-delete-values-fail", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), Attrs: model.StringInterface{ model.PropertyAttrsProtected: true, }, @@ -2193,7 +2393,8 @@ func TestDeletePropertyValuesForField_WriteAccessControl(t *testing.T) { GroupID: th.CPAGroupID, Name: "public-field-delete-values", Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, err := th.service.CreatePropertyField(rctxAnon, field) require.NoError(t, err) diff --git a/server/channels/app/properties/field_limit.go b/server/channels/app/properties/field_limit.go new file mode 100644 index 00000000000..491aa8ac234 --- /dev/null +++ b/server/channels/app/properties/field_limit.go @@ -0,0 +1,87 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "errors" + "fmt" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" +) + +var ( + ErrFieldLimitReached = errors.New("per-object-type field limit reached") + ErrGroupFieldLimitReached = errors.New("group field limit reached") +) + +// FieldLimitConfig defines limits for a specific property group. +type FieldLimitConfig struct { + // PerObjectType maps ObjectType values to their maximum field count. + // For example: {"user": 20} means at most 20 fields with ObjectType="user". + PerObjectType map[string]int64 + + // GlobalLimit is the maximum total number of fields across the entire group, + // regardless of ObjectType. Zero means no global limit. + GlobalLimit int64 +} + +// FieldLimitHook enforces per-group field creation limits. It checks both +// per-object-type limits and global group limits before allowing a field +// to be created. The hook only applies to groups that have been configured +// with limits. +type FieldLimitHook struct { + BasePropertyHook + propertyService *PropertyService + limits map[string]*FieldLimitConfig // groupID -> config +} + +var _ PropertyHook = (*FieldLimitHook)(nil) + +// NewFieldLimitHook creates a hook that enforces field limits. Call +// AddGroupLimit to configure limits for specific groups. +func NewFieldLimitHook(ps *PropertyService) *FieldLimitHook { + return &FieldLimitHook{ + propertyService: ps, + limits: make(map[string]*FieldLimitConfig), + } +} + +// AddGroupLimit registers a limit configuration for the given group ID. +func (h *FieldLimitHook) AddGroupLimit(groupID string, config *FieldLimitConfig) { + h.limits[groupID] = config +} + +func (h *FieldLimitHook) PreCreatePropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + config, ok := h.limits[field.GroupID] + if !ok { + return field, nil + } + + // Check per-object-type limit + if field.ObjectType != "" { + if limit, hasLimit := config.PerObjectType[field.ObjectType]; hasLimit { + count, err := h.propertyService.countActivePropertyFieldsForGroupObjectType(field.GroupID, field.ObjectType) + if err != nil { + return nil, fmt.Errorf("failed to count fields: %w", err) + } + if count >= limit { + return nil, fmt.Errorf("limit_reached: field limit of %d reached for object type %q: %w", limit, field.ObjectType, ErrFieldLimitReached) + } + } + } + + // Check global group limit + if config.GlobalLimit > 0 { + count, err := h.propertyService.countActivePropertyFieldsForGroup(field.GroupID) + if err != nil { + return nil, fmt.Errorf("failed to count group fields: %w", err) + } + if count >= config.GlobalLimit { + return nil, fmt.Errorf("group_limit_reached: global field limit of %d reached for group: %w", config.GlobalLimit, ErrGroupFieldLimitReached) + } + } + + return field, nil +} diff --git a/server/channels/app/properties/field_limit_test.go b/server/channels/app/properties/field_limit_test.go new file mode 100644 index 00000000000..ba6d59907b4 --- /dev/null +++ b/server/channels/app/properties/field_limit_test.go @@ -0,0 +1,84 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFieldLimitHook(t *testing.T) { + th := Setup(t) + + group, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_field_limit", Version: model.PropertyGroupVersionV2}) + require.NoError(t, err) + + hook := NewFieldLimitHook(th.service) + hook.AddGroupLimit(group.ID, &FieldLimitConfig{ + PerObjectType: map[string]int64{ + "user": 3, + }, + GlobalLimit: 5, + }) + th.service.AddHook(hook) + + makeField := func(objectType string) *model.PropertyField { + return &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: objectType, + } + } + + t.Run("allows fields up to per-object-type limit", func(t *testing.T) { + for range 3 { + _, createErr := th.service.CreatePropertyField(th.Context, makeField("user")) + require.NoError(t, createErr) + } + }) + + t.Run("rejects field at per-object-type limit", func(t *testing.T) { + _, createErr := th.service.CreatePropertyField(th.Context, makeField("user")) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "limit_reached") + }) + + t.Run("allows fields for different object type", func(t *testing.T) { + _, createErr := th.service.CreatePropertyField(th.Context, makeField("post")) + require.NoError(t, createErr) + }) + + t.Run("rejects at global limit", func(t *testing.T) { + // We have 3 user + 1 post = 4 fields. One more should succeed. + _, createErr := th.service.CreatePropertyField(th.Context, makeField("post")) + require.NoError(t, createErr) + + // Now at 5, should hit global limit + _, createErr = th.service.CreatePropertyField(th.Context, makeField("post")) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "group_limit_reached") + }) + + t.Run("skips limit check for unregistered groups", func(t *testing.T) { + otherGroup, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_no_limits", Version: model.PropertyGroupVersionV2}) + require.NoError(t, groupErr) + + for range 10 { + field := &model.PropertyField{ + GroupID: otherGroup.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + _, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + } + }) +} diff --git a/server/channels/app/properties/helper_test.go b/server/channels/app/properties/helper_test.go index 35825db767f..fe95d9c3c80 100644 --- a/server/channels/app/properties/helper_test.go +++ b/server/channels/app/properties/helper_test.go @@ -48,10 +48,6 @@ func setupTestHelper(s store.Store, tb testing.TB) *TestHelper { }) require.NoError(tb, err) - // Create and wire the PropertyAccessService - pas := NewPropertyAccessService(service, nil) - service.SetPropertyAccessService(pas) - tb.Cleanup(func() { s.Close() }) @@ -69,12 +65,29 @@ func RequestContextWithCallerID(rctx request.CTX, callerID string) request.CTX { return rctx.WithContext(ctx) } +// setPluginCheckerForTests sets the plugin checker on the AccessControlHook for testing. +func (ps *PropertyService) setPluginCheckerForTests(pluginChecker PluginChecker) { + for _, hook := range ps.hooks { + if ach, ok := hook.(*AccessControlHook); ok { + ach.setPluginCheckerForTests(pluginChecker) + } + } +} + +func (h *AccessControlHook) setPluginCheckerForTests(pluginChecker PluginChecker) { + h.pluginChecker = pluginChecker +} + func (th *TestHelper) RegisterCPAPropertyGroup(tb testing.TB) *TestHelper { // Register the CPA group so requiresAccessControl can always look it up - group, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: model.CustomProfileAttributesPropertyGroupName, Version: model.PropertyGroupVersionV1}) + group, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: model.AccessControlPropertyGroupName, Version: model.PropertyGroupVersionV2}) require.NoError(tb, groupErr) th.CPAGroupID = group.ID + // Create and register the access control hook now that the group ID is known + hook := NewAccessControlHook(th.service, nil, group.ID) + th.service.AddHook(hook) + return th } diff --git a/server/channels/app/properties/hooks.go b/server/channels/app/properties/hooks.go new file mode 100644 index 00000000000..9e62c5956df --- /dev/null +++ b/server/channels/app/properties/hooks.go @@ -0,0 +1,455 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/public/shared/request" +) + +// errNilHookResult is returned when a pre-hook returns a nil result without an +// error. This catches buggy hook implementations early rather than letting a +// nil propagate into the store layer. +var ( + errNilHookResult = errors.New("property hook returned nil result") + errFieldCardinalityBroken = errors.New("PostGetPropertyFields hook returned fewer fields than it received") +) + +// PropertyHook defines an interface for hooks that run before and after property +// service operations. Hooks can inspect and modify inputs (pre-hooks) or filter +// outputs (post-hooks). A pre-hook returns an error to block the operation; a +// post-hook returns an error to suppress the result. Returning nil means the +// hook has no objection and the operation may proceed. +// +// Pre-hooks receive the operation's input parameters and may return modified +// versions. Post-hooks receive the operation's results and may return filtered +// or modified versions. +// +// Multiple hooks are called in registration order. Each hook receives the +// output of the previous hook (or the original input for the first hook). +type PropertyHook interface { + // Field pre-hooks (write operations) + + PreCreatePropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) + PreUpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) + PreUpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) + PreDeletePropertyField(rctx request.CTX, groupID string, id string) error + + // PostUpdatePropertyFields runs after a successful field update (including + // the linked-field propagation pass). It receives the pre-update state of + // the requested fields (parallel to requested), the post-update requested + // fields, and the post-update propagated fields. Hooks may transform attrs + // on either bucket (e.g. redact information for the caller); the + // dispatcher enforces cardinality preservation on both buckets so a buggy + // hook that drops fields surfaces an error rather than silently truncating + // the broadcast. Returns the IDs of fields whose dependent property values + // were cleared as a side effect (e.g. type-change cleanup); the caller + // publishes the corresponding WS events. Errors are best-effort: the + // dispatcher logs and continues, the update is not rolled back. + PostUpdatePropertyFields(rctx request.CTX, groupID string, prev, requested, propagated []*model.PropertyField) (newRequested, newPropagated []*model.PropertyField, clearedFieldIDs []string, err error) + + // Field pre-hook for count operations. Count operations return only a + // scalar so there is no post-hook — access control applied to per-row + // data does not apply, but license/group-level gating still does. + // Return an error to block the count. + PreCountPropertyFields(rctx request.CTX, groupID string) error + + // Field post-hooks (read operations) + // + // PostGetPropertyField is called after retrieving a single field (by ID or by name). + // Implementations must return a non-nil field; returning nil is treated as a + // hook bug and the dispatcher surfaces errNilHookResult. To block a caller + // from seeing a field, return a sentinel error instead. + PostGetPropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) + // PostGetPropertyFields is called after retrieving multiple fields (by IDs or search). + // Implementations must preserve slice length — the dispatcher enforces this and will + // return an error if a hook returns fewer fields than it received. + PostGetPropertyFields(rctx request.CTX, fields []*model.PropertyField) ([]*model.PropertyField, error) + + // Value pre-hooks (write operations) + + PreCreatePropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) + PreCreatePropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) + PreUpdatePropertyValue(rctx request.CTX, groupID string, value *model.PropertyValue) (*model.PropertyValue, error) + PreUpdatePropertyValues(rctx request.CTX, groupID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) + PreUpsertPropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) + PreUpsertPropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) + PreDeletePropertyValue(rctx request.CTX, groupID string, id string) error + PreDeletePropertyValuesForTarget(rctx request.CTX, groupID string, targetType string, targetID string) error + PreDeletePropertyValuesForField(rctx request.CTX, groupID string, fieldID string) error + + // Value post-hooks (read operations) + // + // PostGetPropertyValue is called after retrieving a single value. + // Return nil value to indicate the value is not accessible. + PostGetPropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) + // PostGetPropertyValues is called after retrieving multiple values (by IDs or search). + // Implementations may remove entries from the returned slice. + PostGetPropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) +} + +// BasePropertyHook provides default passthrough implementations for every +// PropertyHook method. Embed it in concrete hooks to only override the +// methods you care about. +type BasePropertyHook struct{} + +func (BasePropertyHook) PreCreatePropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + return field, nil +} +func (BasePropertyHook) PreUpdatePropertyField(_ request.CTX, _ string, field *model.PropertyField) (*model.PropertyField, error) { + return field, nil +} +func (BasePropertyHook) PreUpdatePropertyFields(_ request.CTX, _ string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + return fields, nil +} +func (BasePropertyHook) PreDeletePropertyField(_ request.CTX, _ string, _ string) error { + return nil +} +func (BasePropertyHook) PostUpdatePropertyFields(_ request.CTX, _ string, _, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) { + return requested, propagated, nil, nil +} +func (BasePropertyHook) PreCountPropertyFields(_ request.CTX, _ string) error { + return nil +} +func (BasePropertyHook) PostGetPropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + return field, nil +} +func (BasePropertyHook) PostGetPropertyFields(_ request.CTX, fields []*model.PropertyField) ([]*model.PropertyField, error) { + return fields, nil +} +func (BasePropertyHook) PreCreatePropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, nil +} +func (BasePropertyHook) PreCreatePropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + return values, nil +} +func (BasePropertyHook) PreUpdatePropertyValue(_ request.CTX, _ string, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, nil +} +func (BasePropertyHook) PreUpdatePropertyValues(_ request.CTX, _ string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + return values, nil +} +func (BasePropertyHook) PreUpsertPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, nil +} +func (BasePropertyHook) PreUpsertPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + return values, nil +} +func (BasePropertyHook) PreDeletePropertyValue(_ request.CTX, _ string, _ string) error { + return nil +} +func (BasePropertyHook) PreDeletePropertyValuesForTarget(_ request.CTX, _ string, _ string, _ string) error { + return nil +} +func (BasePropertyHook) PreDeletePropertyValuesForField(_ request.CTX, _ string, _ string) error { + return nil +} +func (BasePropertyHook) PostGetPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, nil +} +func (BasePropertyHook) PostGetPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + return values, nil +} + +// AddHook registers a hook with the property service. Hooks are called in +// registration order for each operation. +func (ps *PropertyService) AddHook(hook PropertyHook) { + ps.hooks = append(ps.hooks, hook) +} + +// runPreCreatePropertyField runs all registered pre-hooks for CreatePropertyField. +func (ps *PropertyService) runPreCreatePropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + var err error + for _, hook := range ps.hooks { + field, err = hook.PreCreatePropertyField(rctx, field) + if err != nil { + return nil, err + } + if field == nil { + return nil, errNilHookResult + } + } + return field, nil +} + +// runPreUpdatePropertyField runs all registered pre-hooks for UpdatePropertyField. +func (ps *PropertyService) runPreUpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) { + var err error + for _, hook := range ps.hooks { + field, err = hook.PreUpdatePropertyField(rctx, groupID, field) + if err != nil { + return nil, err + } + if field == nil { + return nil, errNilHookResult + } + } + return field, nil +} + +// runPreUpdatePropertyFields runs all registered pre-hooks for UpdatePropertyFields. +func (ps *PropertyService) runPreUpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + var err error + for _, hook := range ps.hooks { + fields, err = hook.PreUpdatePropertyFields(rctx, groupID, fields) + if err != nil { + return nil, err + } + if fields == nil { + return nil, errNilHookResult + } + } + return fields, nil +} + +// runPostUpdatePropertyFields runs all registered post-hooks for +// UpdatePropertyFields. Each hook may transform the requested and propagated +// buckets in place (e.g. redaction); the dispatcher chains the transformed +// slices through subsequent hooks and enforces cardinality preservation on +// both buckets so a buggy hook that drops fields surfaces an error rather +// than silently truncating the broadcast. The cleared field IDs returned by +// each hook are deduped into a single slice. Best-effort: hook errors and +// cardinality violations are logged and skipped (the offending hook's +// transform is dropped for the chain, but the update itself is not rolled +// back). +func (ps *PropertyService) runPostUpdatePropertyFields(rctx request.CTX, groupID string, prev, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string) { + seen := map[string]struct{}{} + var cleared []string + for _, hook := range ps.hooks { + newRequested, newPropagated, ids, err := hook.PostUpdatePropertyFields(rctx, groupID, prev, requested, propagated) + if err != nil { + rctx.Logger().Error("PostUpdatePropertyFields hook failed", + mlog.String("group_id", groupID), + mlog.Err(err), + ) + continue + } + if len(newRequested) != len(requested) || len(newPropagated) != len(propagated) { + rctx.Logger().Error("PostUpdatePropertyFields hook returned wrong-length slice", + mlog.String("group_id", groupID), + mlog.Err(errFieldCardinalityBroken), + ) + continue + } + requested = newRequested + propagated = newPropagated + for _, id := range ids { + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + cleared = append(cleared, id) + } + } + return requested, propagated, cleared +} + +// runPreDeletePropertyField runs all registered pre-hooks for DeletePropertyField. +func (ps *PropertyService) runPreDeletePropertyField(rctx request.CTX, groupID string, id string) error { + for _, hook := range ps.hooks { + if err := hook.PreDeletePropertyField(rctx, groupID, id); err != nil { + return err + } + } + return nil +} + +// runPreCountPropertyFields runs all registered pre-hooks for the public +// CountProperty* methods. +func (ps *PropertyService) runPreCountPropertyFields(rctx request.CTX, groupID string) error { + for _, hook := range ps.hooks { + if err := hook.PreCountPropertyFields(rctx, groupID); err != nil { + return err + } + } + return nil +} + +// runPostGetPropertyField runs all registered post-hooks for single field retrieval. +func (ps *PropertyService) runPostGetPropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + if field == nil { + return nil, nil + } + var err error + for _, hook := range ps.hooks { + field, err = hook.PostGetPropertyField(rctx, field) + if err != nil { + return nil, err + } + if field == nil { + return nil, errNilHookResult + } + } + return field, nil +} + +// runPostGetPropertyFields runs all registered post-hooks for multi-field retrieval. +// It enforces that hooks preserve slice length — a hook that drops fields is a bug. +func (ps *PropertyService) runPostGetPropertyFields(rctx request.CTX, fields []*model.PropertyField) ([]*model.PropertyField, error) { + var err error + for _, hook := range ps.hooks { + n := len(fields) + fields, err = hook.PostGetPropertyFields(rctx, fields) + if err != nil { + return nil, err + } + if len(fields) != n { + return nil, errFieldCardinalityBroken + } + } + return fields, nil +} + +// runPreCreatePropertyValue runs all registered pre-hooks for CreatePropertyValue. +func (ps *PropertyService) runPreCreatePropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + value, err = hook.PreCreatePropertyValue(rctx, value) + if err != nil { + return nil, err + } + if value == nil { + return nil, errNilHookResult + } + } + return value, nil +} + +// runPreCreatePropertyValues runs all registered pre-hooks for CreatePropertyValues. +func (ps *PropertyService) runPreCreatePropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + values, err = hook.PreCreatePropertyValues(rctx, values) + if err != nil { + return nil, err + } + if values == nil { + return nil, errNilHookResult + } + } + return values, nil +} + +// runPreUpdatePropertyValue runs all registered pre-hooks for UpdatePropertyValue. +func (ps *PropertyService) runPreUpdatePropertyValue(rctx request.CTX, groupID string, value *model.PropertyValue) (*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + value, err = hook.PreUpdatePropertyValue(rctx, groupID, value) + if err != nil { + return nil, err + } + if value == nil { + return nil, errNilHookResult + } + } + return value, nil +} + +// runPreUpdatePropertyValues runs all registered pre-hooks for UpdatePropertyValues. +func (ps *PropertyService) runPreUpdatePropertyValues(rctx request.CTX, groupID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + values, err = hook.PreUpdatePropertyValues(rctx, groupID, values) + if err != nil { + return nil, err + } + if values == nil { + return nil, errNilHookResult + } + } + return values, nil +} + +// runPreUpsertPropertyValue runs all registered pre-hooks for UpsertPropertyValue. +func (ps *PropertyService) runPreUpsertPropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + value, err = hook.PreUpsertPropertyValue(rctx, value) + if err != nil { + return nil, err + } + if value == nil { + return nil, errNilHookResult + } + } + return value, nil +} + +// runPreUpsertPropertyValues runs all registered pre-hooks for UpsertPropertyValues. +func (ps *PropertyService) runPreUpsertPropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + values, err = hook.PreUpsertPropertyValues(rctx, values) + if err != nil { + return nil, err + } + if values == nil { + return nil, errNilHookResult + } + } + return values, nil +} + +// runPreDeletePropertyValue runs all registered pre-hooks for DeletePropertyValue. +func (ps *PropertyService) runPreDeletePropertyValue(rctx request.CTX, groupID string, id string) error { + for _, hook := range ps.hooks { + if err := hook.PreDeletePropertyValue(rctx, groupID, id); err != nil { + return err + } + } + return nil +} + +// runPreDeletePropertyValuesForTarget runs all registered pre-hooks for DeletePropertyValuesForTarget. +func (ps *PropertyService) runPreDeletePropertyValuesForTarget(rctx request.CTX, groupID string, targetType string, targetID string) error { + for _, hook := range ps.hooks { + if err := hook.PreDeletePropertyValuesForTarget(rctx, groupID, targetType, targetID); err != nil { + return err + } + } + return nil +} + +// runPreDeletePropertyValuesForField runs all registered pre-hooks for DeletePropertyValuesForField. +func (ps *PropertyService) runPreDeletePropertyValuesForField(rctx request.CTX, groupID string, fieldID string) error { + for _, hook := range ps.hooks { + if err := hook.PreDeletePropertyValuesForField(rctx, groupID, fieldID); err != nil { + return err + } + } + return nil +} + +// runPostGetPropertyValue runs all registered post-hooks for single value retrieval. +func (ps *PropertyService) runPostGetPropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if value == nil { + return nil, nil + } + var err error + for _, hook := range ps.hooks { + value, err = hook.PostGetPropertyValue(rctx, value) + if err != nil { + return nil, err + } + if value == nil { + return nil, nil + } + } + return value, nil +} + +// runPostGetPropertyValues runs all registered post-hooks for multi-value retrieval. +func (ps *PropertyService) runPostGetPropertyValues(rctx request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + var err error + for _, hook := range ps.hooks { + values, err = hook.PostGetPropertyValues(rctx, values) + if err != nil { + return nil, err + } + } + return values, nil +} diff --git a/server/channels/app/properties/hooks_test.go b/server/channels/app/properties/hooks_test.go new file mode 100644 index 00000000000..efa52e1f778 --- /dev/null +++ b/server/channels/app/properties/hooks_test.go @@ -0,0 +1,637 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "fmt" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testHook is a configurable PropertyHook implementation for testing hook +// registration, ordering, chaining, and blocking behavior. It embeds +// BasePropertyHook for default passthrough behavior and only overrides +// methods where a test-specific function is set. +type testHook struct { + BasePropertyHook + preCreateFieldFn func(*model.PropertyField) (*model.PropertyField, error) + preUpdateFieldFn func(string, *model.PropertyField) (*model.PropertyField, error) + preUpdateFieldsFn func(string, []*model.PropertyField) ([]*model.PropertyField, error) + preDeleteFieldFn func(string, string) error + postUpdateFieldsFn func(string, []*model.PropertyField, []*model.PropertyField, []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) + postGetFieldFn func(*model.PropertyField) (*model.PropertyField, error) + postGetFieldsFn func([]*model.PropertyField) ([]*model.PropertyField, error) + preUpsertValueFn func(*model.PropertyValue) (*model.PropertyValue, error) + preUpsertValuesFn func([]*model.PropertyValue) ([]*model.PropertyValue, error) + postGetValueFn func(*model.PropertyValue) (*model.PropertyValue, error) + postGetValuesFn func([]*model.PropertyValue) ([]*model.PropertyValue, error) +} + +func (h *testHook) PreCreatePropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + if h.preCreateFieldFn != nil { + return h.preCreateFieldFn(field) + } + return field, nil +} + +func (h *testHook) PreUpdatePropertyField(_ request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) { + if h.preUpdateFieldFn != nil { + return h.preUpdateFieldFn(groupID, field) + } + return field, nil +} + +func (h *testHook) PreUpdatePropertyFields(_ request.CTX, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + if h.preUpdateFieldsFn != nil { + return h.preUpdateFieldsFn(groupID, fields) + } + return fields, nil +} + +func (h *testHook) PreDeletePropertyField(_ request.CTX, groupID string, id string) error { + if h.preDeleteFieldFn != nil { + return h.preDeleteFieldFn(groupID, id) + } + return nil +} + +func (h *testHook) PostUpdatePropertyFields(_ request.CTX, groupID string, prev, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) { + if h.postUpdateFieldsFn != nil { + return h.postUpdateFieldsFn(groupID, prev, requested, propagated) + } + return requested, propagated, nil, nil +} + +func (h *testHook) PostGetPropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + if h.postGetFieldFn != nil { + return h.postGetFieldFn(field) + } + return field, nil +} + +func (h *testHook) PostGetPropertyFields(_ request.CTX, fields []*model.PropertyField) ([]*model.PropertyField, error) { + if h.postGetFieldsFn != nil { + return h.postGetFieldsFn(fields) + } + return fields, nil +} + +func (h *testHook) PreUpsertPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if h.preUpsertValueFn != nil { + return h.preUpsertValueFn(value) + } + return value, nil +} + +func (h *testHook) PreUpsertPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if h.preUpsertValuesFn != nil { + return h.preUpsertValuesFn(values) + } + return values, nil +} + +func (h *testHook) PostGetPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if h.postGetValueFn != nil { + return h.postGetValueFn(value) + } + return value, nil +} + +func (h *testHook) PostGetPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if h.postGetValuesFn != nil { + return h.postGetValuesFn(values) + } + return values, nil +} + +func TestHookRegistration(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + + t.Run("service starts with no hooks", func(t *testing.T) { + service, err := New(ServiceConfig{ + PropertyGroupStore: th.dbStore.PropertyGroup(), + PropertyFieldStore: th.dbStore.PropertyField(), + PropertyValueStore: th.dbStore.PropertyValue(), + }) + require.NoError(t, err) + assert.Empty(t, service.hooks) + }) + + t.Run("AddHook appends hooks in order", func(t *testing.T) { + service, err := New(ServiceConfig{ + PropertyGroupStore: th.dbStore.PropertyGroup(), + PropertyFieldStore: th.dbStore.PropertyField(), + PropertyValueStore: th.dbStore.PropertyValue(), + }) + require.NoError(t, err) + + hook1 := &testHook{} + hook2 := &testHook{} + service.AddHook(hook1) + service.AddHook(hook2) + assert.Len(t, service.hooks, 2) + }) +} + +func TestPreHookBlocking(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + groupID := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1).ID + + t.Run("pre-hook error blocks CreatePropertyField", func(t *testing.T) { + hook := &testHook{ + preCreateFieldFn: func(field *model.PropertyField) (*model.PropertyField, error) { + return nil, fmt.Errorf("blocked by hook") + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + field := &model.PropertyField{ + GroupID: groupID, + Name: "blocked-field-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + } + _, err := th.service.CreatePropertyField(rctx, field) + require.Error(t, err) + assert.Contains(t, err.Error(), "blocked by hook") + }) + + t.Run("pre-hook error blocks DeletePropertyField", func(t *testing.T) { + hook := &testHook{ + preDeleteFieldFn: func(gid string, id string) error { + return fmt.Errorf("delete blocked") + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + err := th.service.DeletePropertyField(rctx, groupID, model.NewId()) + require.Error(t, err) + assert.Contains(t, err.Error(), "delete blocked") + }) + + t.Run("pre-hook error blocks UpsertPropertyValue", func(t *testing.T) { + hook := &testHook{ + preUpsertValueFn: func(value *model.PropertyValue) (*model.PropertyValue, error) { + return nil, fmt.Errorf("upsert blocked") + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + value := &model.PropertyValue{ + GroupID: groupID, + FieldID: model.NewId(), + } + _, err := th.service.UpsertPropertyValue(rctx, value) + require.Error(t, err) + assert.Contains(t, err.Error(), "upsert blocked") + }) +} + +func TestPreHookInputModification(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + groupID := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1).ID + + t.Run("pre-hook modifies field before creation", func(t *testing.T) { + hook := &testHook{ + preCreateFieldFn: func(field *model.PropertyField) (*model.PropertyField, error) { + // Modify the field name + field.Name = "modified-" + field.Name + return field, nil + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + field := &model.PropertyField{ + GroupID: groupID, + Name: "original", + Type: model.PropertyFieldTypeText, + TargetType: "user", + } + result, err := th.service.CreatePropertyField(rctx, field) + require.NoError(t, err) + assert.Equal(t, "modified-original", result.Name) + }) +} + +func TestPostHookFiltering(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + groupID := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1).ID + + t.Run("post-hook returning nil field without error surfaces errNilHookResult", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "nil-return-field-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + + hook := &testHook{ + postGetFieldFn: func(f *model.PropertyField) (*model.PropertyField, error) { + return nil, nil + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + result, err := th.service.GetPropertyField(rctx, groupID, field.ID) + require.ErrorIs(t, err, errNilHookResult) + assert.Nil(t, result) + }) + + t.Run("post-hook that drops fields from list returns error", func(t *testing.T) { + field1 := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "keep-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + field2 := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "remove-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + + hook := &testHook{ + postGetFieldsFn: func(fields []*model.PropertyField) ([]*model.PropertyField, error) { + filtered := []*model.PropertyField{} + for _, f := range fields { + if f.ID == field1.ID { + filtered = append(filtered, f) + } + } + return filtered, nil + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + _, err := th.service.GetPropertyFields(rctx, groupID, []string{field1.ID, field2.ID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "fewer fields") + }) +} + +func TestMultipleHooksChaining(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + groupID := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1).ID + + t.Run("multiple pre-hooks chain modifications in order", func(t *testing.T) { + order := []string{} + + hook1 := &testHook{ + preCreateFieldFn: func(field *model.PropertyField) (*model.PropertyField, error) { + order = append(order, "hook1") + field.Name = field.Name + "-h1" + return field, nil + }, + } + hook2 := &testHook{ + preCreateFieldFn: func(field *model.PropertyField) (*model.PropertyField, error) { + order = append(order, "hook2") + field.Name = field.Name + "-h2" + return field, nil + }, + } + th.service.AddHook(hook1) + th.service.AddHook(hook2) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-2] }() + + field := &model.PropertyField{ + GroupID: groupID, + Name: "base", + Type: model.PropertyFieldTypeText, + TargetType: "user", + } + result, err := th.service.CreatePropertyField(rctx, field) + require.NoError(t, err) + assert.Equal(t, "base-h1-h2", result.Name) + assert.Equal(t, []string{"hook1", "hook2"}, order) + }) + + t.Run("first hook error prevents second hook from running", func(t *testing.T) { + hook2Called := false + + hook1 := &testHook{ + preCreateFieldFn: func(field *model.PropertyField) (*model.PropertyField, error) { + return nil, fmt.Errorf("hook1 blocked") + }, + } + hook2 := &testHook{ + preCreateFieldFn: func(field *model.PropertyField) (*model.PropertyField, error) { + hook2Called = true + return field, nil + }, + } + th.service.AddHook(hook1) + th.service.AddHook(hook2) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-2] }() + + field := &model.PropertyField{ + GroupID: groupID, + Name: "should-fail-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + } + _, err := th.service.CreatePropertyField(rctx, field) + require.Error(t, err) + assert.False(t, hook2Called, "second hook should not have been called") + }) + + t.Run("multiple post-hooks chain in order", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "chain-post-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + Attrs: model.StringInterface{"step": "0"}, + }) + + hook1 := &testHook{ + postGetFieldFn: func(f *model.PropertyField) (*model.PropertyField, error) { + if f.Attrs == nil { + f.Attrs = make(model.StringInterface) + } + f.Attrs["hook1"] = true + return f, nil + }, + } + hook2 := &testHook{ + postGetFieldFn: func(f *model.PropertyField) (*model.PropertyField, error) { + if f.Attrs == nil { + f.Attrs = make(model.StringInterface) + } + f.Attrs["hook2"] = true + return f, nil + }, + } + th.service.AddHook(hook1) + th.service.AddHook(hook2) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-2] }() + + result, err := th.service.GetPropertyField(rctx, groupID, field.ID) + require.NoError(t, err) + assert.Equal(t, true, result.Attrs["hook1"]) + assert.Equal(t, true, result.Attrs["hook2"]) + }) +} + +func TestAccessControlHookGroupScoping(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + + th.service.setPluginCheckerForTests(func(pluginID string) bool { + return pluginID == "plugin-1" + }) + + rctxPlugin1 := RequestContextWithCallerID(th.Context, "plugin-1") + rctxPlugin2 := RequestContextWithCallerID(th.Context, "plugin-2") + + t.Run("access control enforced for managed group (CPA)", func(t *testing.T) { + // Create a protected field in the CPA group via the source plugin + field := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "protected-managed-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + created, err := th.service.CreatePropertyField(rctxPlugin1, field) + require.NoError(t, err) + assert.Equal(t, "plugin-1", created.Attrs[model.PropertyAttrsSourcePluginID]) + + // Another plugin should NOT be able to update it (protected) + created.Attrs[model.PropertyAttrsProtected] = true + updated, _, err := th.service.UpdatePropertyField(rctxPlugin1, th.CPAGroupID, created) + require.NoError(t, err) + + updated.Name = "attempt-update" + _, _, err = th.service.UpdatePropertyField(rctxPlugin2, th.CPAGroupID, updated) + require.Error(t, err) + assert.Contains(t, err.Error(), "protected") + }) + + t.Run("access control NOT enforced for unmanaged group", func(t *testing.T) { + unmanagedGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "unmanaged_scoping_test", Version: model.PropertyGroupVersionV1}) + require.NoError(t, err) + + // Create a protected field in an unmanaged group + field := &model.PropertyField{ + GroupID: unmanagedGroup.ID, + Name: "protected-unmanaged-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + Attrs: model.StringInterface{ + model.PropertyAttrsProtected: true, + model.PropertyAttrsSourcePluginID: "plugin-1", + }, + } + created, err := th.service.CreatePropertyField(rctxPlugin1, field) + require.NoError(t, err) + + // Another plugin CAN update it (no access control for this group) + created.Name = "updated-by-plugin2" + updated, _, err := th.service.UpdatePropertyField(rctxPlugin2, unmanagedGroup.ID, created) + require.NoError(t, err) + assert.Equal(t, "updated-by-plugin2", updated.Name) + }) + + t.Run("read filtering applied for managed group", func(t *testing.T) { + // Create a source-only protected field in the CPA group + field := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "source-only-managed-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, + model.PropertyAttrsProtected: true, + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": "opt1", "value": "Option 1"}, + map[string]any{"id": "opt2", "value": "Option 2"}, + }, + }, + } + created, err := th.service.CreatePropertyField(rctxPlugin1, field) + require.NoError(t, err) + + // Source plugin sees all options + result, err := th.service.GetPropertyField(rctxPlugin1, th.CPAGroupID, created.ID) + require.NoError(t, err) + opts := result.Attrs[model.PropertyFieldAttributeOptions].([]any) + assert.Len(t, opts, 2) + + // Other caller sees empty options + result2, err := th.service.GetPropertyField(rctx, th.CPAGroupID, created.ID) + require.NoError(t, err) + opts2 := result2.Attrs[model.PropertyFieldAttributeOptions].([]any) + assert.Len(t, opts2, 0) + }) + + t.Run("read filtering NOT applied for unmanaged group", func(t *testing.T) { + unmanagedGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "unmanaged_read_test", Version: model.PropertyGroupVersionV1}) + require.NoError(t, err) + + // Create a source-only field in an unmanaged group + field := &model.PropertyField{ + GroupID: unmanagedGroup.ID, + Name: "source-only-unmanaged-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + TargetType: "user", + Attrs: model.StringInterface{ + model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, + model.PropertyAttrsSourcePluginID: "plugin-1", + model.PropertyFieldAttributeOptions: []any{ + map[string]any{"id": "opt1", "value": "Option 1"}, + map[string]any{"id": "opt2", "value": "Option 2"}, + }, + }, + } + created, err := th.service.CreatePropertyField(rctxPlugin1, field) + require.NoError(t, err) + + // Non-source caller sees ALL options (no filtering for unmanaged groups) + result, err := th.service.GetPropertyField(rctx, unmanagedGroup.ID, created.ID) + require.NoError(t, err) + opts := result.Attrs[model.PropertyFieldAttributeOptions].([]any) + assert.Len(t, opts, 2) + }) +} + +func TestPreUpdatePropertyFieldsHook(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + groupID := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1).ID + + t.Run("pre-hook error blocks batch UpdatePropertyFields", func(t *testing.T) { + hook := &testHook{ + preUpdateFieldsFn: func(gid string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + return nil, fmt.Errorf("batch update blocked") + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "batch-block-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + field.Name = "updated" + _, _, _, err := th.service.UpdatePropertyFields(rctx, groupID, []*model.PropertyField{field}) + require.Error(t, err) + assert.Contains(t, err.Error(), "batch update blocked") + }) + + t.Run("pre-hook modifies fields in batch update", func(t *testing.T) { + hook := &testHook{ + preUpdateFieldsFn: func(gid string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + for _, f := range fields { + f.Name = "modified-" + f.Name + } + return fields, nil + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + field1 := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "batch-a-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + field2 := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "batch-b-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + + field1.Name = "a" + field2.Name = "b" + results, _, _, err := th.service.UpdatePropertyFields(rctx, groupID, []*model.PropertyField{field1, field2}) + require.NoError(t, err) + require.Len(t, results, 2) + assert.Equal(t, "modified-a", results[0].Name) + assert.Equal(t, "modified-b", results[1].Name) + }) +} + +func TestPostUpdatePropertyFieldsHook(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + rctx := th.Context + groupID := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1).ID + + t.Run("post-hook transforms requested attrs and surfaces cleared IDs", func(t *testing.T) { + hook := &testHook{ + postUpdateFieldsFn: func(_ string, _, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) { + for _, f := range requested { + if f.Attrs == nil { + f.Attrs = model.StringInterface{} + } + f.Attrs["redacted"] = true + } + ids := make([]string, 0, len(requested)) + for _, f := range requested { + ids = append(ids, f.ID) + } + return requested, propagated, ids, nil + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "post-transform-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + field.Name = "post-transform-renamed-" + model.NewId() + + results, _, cleared, err := th.service.UpdatePropertyFields(rctx, groupID, []*model.PropertyField{field}) + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, true, results[0].Attrs["redacted"], "post-hook attr transform must reach caller") + assert.Equal(t, []string{field.ID}, cleared, "cleared IDs from post-hook must be surfaced") + }) + + t.Run("post-hook returning wrong-length requested slice is skipped", func(t *testing.T) { + hook := &testHook{ + postUpdateFieldsFn: func(_ string, _, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) { + // Drop a field — cardinality guard must reject this transform. + return requested[:0], propagated, nil, nil + }, + } + th.service.AddHook(hook) + defer func() { th.service.hooks = th.service.hooks[:len(th.service.hooks)-1] }() + + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: groupID, + Name: "post-cardinality-" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "user", + }) + field.Name = "post-cardinality-renamed-" + model.NewId() + + results, _, _, err := th.service.UpdatePropertyFields(rctx, groupID, []*model.PropertyField{field}) + require.NoError(t, err) + assert.Len(t, results, 1, "wrong-length transform must be discarded; original requested must survive") + }) +} diff --git a/server/channels/app/properties/license_check.go b/server/channels/app/properties/license_check.go new file mode 100644 index 00000000000..6ede3a3cdd0 --- /dev/null +++ b/server/channels/app/properties/license_check.go @@ -0,0 +1,148 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" +) + +var ErrLicenseRequired = errors.New("license_error: an Enterprise license is required") + +// LicenseProvider is a function that returns the current license. +type LicenseProvider func() *model.License + +// LicenseCheckHook enforces license requirements for property operations on +// specific groups. Operations on groups without a license requirement pass +// through without checks. +type LicenseCheckHook struct { + BasePropertyHook + licenseProvider LicenseProvider + managedGroupIDs map[string]struct{} +} + +var _ PropertyHook = (*LicenseCheckHook)(nil) + +// NewLicenseCheckHook creates a hook that requires an Enterprise license for +// all field and value operations on the given property groups. +func NewLicenseCheckHook(licenseProvider LicenseProvider, managedGroupIDs ...string) *LicenseCheckHook { + ids := make(map[string]struct{}, len(managedGroupIDs)) + for _, id := range managedGroupIDs { + ids[id] = struct{}{} + } + return &LicenseCheckHook{ + licenseProvider: licenseProvider, + managedGroupIDs: ids, + } +} + +// requireLicense returns ErrLicenseRequired when groupID is in the managed set +// and no Enterprise license is active. Unmanaged groups and licensed calls +// return nil. +func (h *LicenseCheckHook) requireLicense(groupID string) error { + if _, managed := h.managedGroupIDs[groupID]; !managed { + return nil + } + if !model.MinimumEnterpriseLicense(h.licenseProvider()) { + return ErrLicenseRequired + } + return nil +} + +// Field pre-hooks + +func (h *LicenseCheckHook) PreCreatePropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + return field, h.requireLicense(field.GroupID) +} + +func (h *LicenseCheckHook) PreUpdatePropertyField(_ request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) { + return field, h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreUpdatePropertyFields(_ request.CTX, groupID string, fields []*model.PropertyField) ([]*model.PropertyField, error) { + return fields, h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreDeletePropertyField(_ request.CTX, groupID string, _ string) error { + return h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreCountPropertyFields(_ request.CTX, groupID string) error { + return h.requireLicense(groupID) +} + +// Field post-hooks + +func (h *LicenseCheckHook) PostGetPropertyField(_ request.CTX, field *model.PropertyField) (*model.PropertyField, error) { + return field, h.requireLicense(field.GroupID) +} + +func (h *LicenseCheckHook) PostGetPropertyFields(_ request.CTX, fields []*model.PropertyField) ([]*model.PropertyField, error) { + if len(fields) == 0 { + return fields, nil + } + return fields, h.requireLicense(fields[0].GroupID) +} + +// Value pre-hooks + +func (h *LicenseCheckHook) PreCreatePropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, h.requireLicense(value.GroupID) +} + +func (h *LicenseCheckHook) PreCreatePropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 { + return values, nil + } + return values, h.requireLicense(values[0].GroupID) +} + +func (h *LicenseCheckHook) PreUpdatePropertyValue(_ request.CTX, groupID string, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreUpdatePropertyValues(_ request.CTX, groupID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + return values, h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreUpsertPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + return value, h.requireLicense(value.GroupID) +} + +func (h *LicenseCheckHook) PreUpsertPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 { + return values, nil + } + return values, h.requireLicense(values[0].GroupID) +} + +func (h *LicenseCheckHook) PreDeletePropertyValue(_ request.CTX, groupID string, _ string) error { + return h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreDeletePropertyValuesForTarget(_ request.CTX, groupID string, _ string, _ string) error { + return h.requireLicense(groupID) +} + +func (h *LicenseCheckHook) PreDeletePropertyValuesForField(_ request.CTX, groupID string, _ string) error { + return h.requireLicense(groupID) +} + +// Value post-hooks + +func (h *LicenseCheckHook) PostGetPropertyValue(_ request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { + if value == nil { + return value, nil + } + return value, h.requireLicense(value.GroupID) +} + +func (h *LicenseCheckHook) PostGetPropertyValues(_ request.CTX, values []*model.PropertyValue) ([]*model.PropertyValue, error) { + if len(values) == 0 { + return values, nil + } + return values, h.requireLicense(values[0].GroupID) +} diff --git a/server/channels/app/properties/license_check_test.go b/server/channels/app/properties/license_check_test.go new file mode 100644 index 00000000000..a47254b6c70 --- /dev/null +++ b/server/channels/app/properties/license_check_test.go @@ -0,0 +1,140 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "encoding/json" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLicenseCheckHook(t *testing.T) { + th := Setup(t) + + group, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_license_check", Version: model.PropertyGroupVersionV2}) + require.NoError(t, err) + + var currentLicense *model.License + hook := NewLicenseCheckHook(func() *model.License { + return currentLicense + }, group.ID) + th.service.AddHook(hook) + + enterpriseLicense := model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise) + + makeField := func() *model.PropertyField { + return &model.PropertyField{ + GroupID: group.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + } + + t.Run("blocks field create without license", func(t *testing.T) { + currentLicense = nil + _, createErr := th.service.CreatePropertyField(th.Context, makeField()) + require.Error(t, createErr) + assert.Contains(t, createErr.Error(), "license_error") + }) + + t.Run("allows field create with license, blocks read after license loss", func(t *testing.T) { + currentLicense = enterpriseLicense + created, createErr := th.service.CreatePropertyField(th.Context, makeField()) + require.NoError(t, createErr) + assert.NotEmpty(t, created.ID) + + currentLicense = nil + _, getErr := th.service.GetPropertyField(th.Context, group.ID, created.ID) + require.Error(t, getErr) + assert.Contains(t, getErr.Error(), "license_error") + }) + + t.Run("blocks value upsert without license", func(t *testing.T) { + currentLicense = enterpriseLicense + field := th.CreatePropertyFieldDirect(t, makeField()) + + currentLicense = nil + value := &model.PropertyValue{ + GroupID: group.ID, + FieldID: field.ID, + TargetID: model.NewId(), + TargetType: "user", + Value: json.RawMessage(`"hello"`), + } + _, upsertErr := th.service.UpsertPropertyValue(th.Context, value) + require.Error(t, upsertErr) + assert.Contains(t, upsertErr.Error(), "license_error") + }) + + t.Run("allows operations on unmanaged groups without license", func(t *testing.T) { + currentLicense = nil + otherGroup, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "test_no_license_needed", Version: model.PropertyGroupVersionV2}) + require.NoError(t, groupErr) + + field := &model.PropertyField{ + GroupID: otherGroup.ID, + Name: "field_" + model.NewId(), + Type: model.PropertyFieldTypeText, + TargetType: "system", + ObjectType: "user", + } + created, createErr := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, createErr) + assert.NotEmpty(t, created.ID) + }) + + countCalls := []struct { + name string + call func(groupID string) error + }{ + {"CountActivePropertyFieldsForGroup", func(id string) error { + _, err := th.service.CountActivePropertyFieldsForGroup(th.Context, id) + return err + }}, + {"CountAllPropertyFieldsForGroup", func(id string) error { + _, err := th.service.CountAllPropertyFieldsForGroup(th.Context, id) + return err + }}, + {"CountActivePropertyFieldsForTarget", func(id string) error { + _, err := th.service.CountActivePropertyFieldsForTarget(th.Context, id, "user", model.NewId()) + return err + }}, + {"CountAllPropertyFieldsForTarget", func(id string) error { + _, err := th.service.CountAllPropertyFieldsForTarget(th.Context, id, "user", model.NewId()) + return err + }}, + } + + t.Run("blocks field counts without license on managed group", func(t *testing.T) { + currentLicense = enterpriseLicense + th.CreatePropertyFieldDirect(t, makeField()) + currentLicense = nil + for _, c := range countCalls { + err := c.call(group.ID) + require.Error(t, err, c.name) + assert.Contains(t, err.Error(), "license_error", c.name) + } + }) + + t.Run("allows field counts with license on managed group", func(t *testing.T) { + currentLicense = enterpriseLicense + for _, c := range countCalls { + require.NoError(t, c.call(group.ID), c.name) + } + }) + + t.Run("allows field counts without license on unmanaged group", func(t *testing.T) { + currentLicense = nil + otherGroup, groupErr := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "count_no_license_needed", Version: model.PropertyGroupVersionV2}) + require.NoError(t, groupErr) + for _, c := range countCalls { + require.NoError(t, c.call(otherGroup.ID), c.name) + } + }) +} diff --git a/server/channels/app/properties/migrations.go b/server/channels/app/properties/migrations.go index 78acc86574e..c522303f5f1 100644 --- a/server/channels/app/properties/migrations.go +++ b/server/channels/app/properties/migrations.go @@ -30,7 +30,7 @@ import ( // Returns the number of fields that were backfilled and the number that were // skipped, so the caller can log a summary. func (ps *PropertyService) MigrateBackfillCPADisplayName(rctx request.CTX) (backfilled int, skipped int, err error) { - group, err := ps.Group(model.CustomProfileAttributesPropertyGroupName) + group, err := ps.Group(model.AccessControlPropertyGroupName) if err != nil { return 0, 0, fmt.Errorf("MigrateBackfillCPADisplayName: failed to get CPA property group: %w", err) } @@ -74,7 +74,7 @@ func (ps *PropertyService) MigrateBackfillCPADisplayName(rctx request.CTX) (back // Use the unexported updatePropertyFields for the same reason as // searchPropertyFields above: the AC layer rejects writes from the // system to fields owned by a source plugin. - if _, _, updateErr := ps.updatePropertyFields(groupID, fieldsToUpdate); updateErr != nil { + if _, _, _, updateErr := ps.updatePropertyFields(rctx, groupID, fieldsToUpdate); updateErr != nil { return 0, 0, fmt.Errorf("MigrateBackfillCPADisplayName: failed to update CPA fields: %w", updateErr) } } diff --git a/server/channels/app/properties/property_field.go b/server/channels/app/properties/property_field.go index befc9ff1e9a..83e70179dc6 100644 --- a/server/channels/app/properties/property_field.go +++ b/server/channels/app/properties/property_field.go @@ -5,6 +5,7 @@ package properties import ( "context" + "errors" "fmt" "net/http" "reflect" @@ -43,10 +44,8 @@ func (ps *PropertyService) createPropertyField(field *model.PropertyField) (*mod return nil, err } - // FIXME: Legacy properties (PSAv1) skip conflict check, but - // template fields still need it because they can have linked - // dependents. - if field.IsPSAv1() && field.ObjectType != model.PropertyFieldObjectTypeTemplate { + // Legacy properties (PSAv1) skip the conflict check. + if field.IsPSAv1() { return ps.fieldStore.Create(field) } @@ -182,7 +181,15 @@ func (ps *PropertyService) getPropertyFieldFromMaster(groupID, id string) (*mode } func (ps *PropertyService) getPropertyFields(groupID string, ids []string) ([]*model.PropertyField, error) { - return ps.fieldStore.GetMany(context.Background(), groupID, ids) + fields, err := ps.fieldStore.GetMany(context.Background(), groupID, ids) + if err != nil { + var resultsMismatchErr *store.ErrResultsMismatch + if errors.As(err, &resultsMismatchErr) { + return nil, fmt.Errorf("%w: %w", ErrFieldNotFound, err) + } + return nil, err + } + return fields, nil } func (ps *PropertyService) getPropertyFieldByName(groupID, targetID, name string) (*model.PropertyField, error) { @@ -197,6 +204,10 @@ func (ps *PropertyService) countAllPropertyFieldsForGroup(groupID string) (int64 return ps.fieldStore.CountForGroup(groupID, true) } +func (ps *PropertyService) countActivePropertyFieldsForGroupObjectType(groupID, objectType string) (int64, error) { + return ps.fieldStore.CountForGroupObjectType(groupID, objectType, false) +} + func (ps *PropertyService) countActivePropertyFieldsForTarget(groupID, targetType, targetID string) (int64, error) { return ps.fieldStore.CountForTarget(groupID, targetType, targetID, false) } @@ -213,25 +224,25 @@ func (ps *PropertyService) searchPropertyFields(groupID string, opts model.Prope return ps.fieldStore.SearchPropertyFields(opts) } -func (ps *PropertyService) updatePropertyField(groupID string, field *model.PropertyField) (*model.PropertyField, error) { - fields, _, err := ps.updatePropertyFields(groupID, []*model.PropertyField{field}) +func (ps *PropertyService) updatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, []string, error) { + fields, _, clearedIDs, err := ps.updatePropertyFields(rctx, groupID, []*model.PropertyField{field}) if err != nil { - return nil, err + return nil, nil, err } - return fields[0], nil + return fields[0], clearedIDs, nil } -func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model.PropertyField) (requested []*model.PropertyField, propagated []*model.PropertyField, err error) { +func (ps *PropertyService) updatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) (requested []*model.PropertyField, propagated []*model.PropertyField, clearedFieldIDs []string, err error) { if len(fields) == 0 { - return nil, nil, nil + return nil, nil, nil, nil } // Fetch existing fields to compare for changes that require conflict check ids := make([]string, len(fields)) for i, f := range fields { if f == nil { - return nil, nil, fmt.Errorf("field at index %d is nil", i) + return nil, nil, nil, fmt.Errorf("field at index %d is nil", i) } ids[i] = f.ID } @@ -241,7 +252,7 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. // TOCTOU window that a replica read would leave open. existingFields, err := ps.fieldStore.GetMany(store.WithMaster(context.Background()), groupID, ids) if err != nil { - return nil, nil, fmt.Errorf("failed to get existing fields for update: %w", err) + return nil, nil, nil, fmt.Errorf("failed to get existing fields for update: %w", err) } // Build a map of existing fields by ID for quick lookup @@ -253,7 +264,7 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. // Enforce version match between field and group for each field for _, field := range fields { if err := ps.enforceFieldGroupVersionMatch("UpdatePropertyFields", groupID, field); err != nil { - return nil, nil, err + return nil, nil, nil, err } } @@ -264,16 +275,14 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. continue } - // FIXME: Legacy properties (PSAv1) skip conflict check, but - // template fields still need it because they can have linked - // dependents. - if field.IsPSAv1() && field.ObjectType != model.PropertyFieldObjectTypeTemplate { + // Legacy properties (PSAv1) skip the conflict check. + if field.IsPSAv1() { continue } // Block type changes on linked fields if existing.LinkedFieldID != nil && *existing.LinkedFieldID != "" && field.Type != existing.Type { - return nil, nil, model.NewAppError( + return nil, nil, nil, model.NewAppError( "UpdatePropertyFields", "app.property_field.update.linked_type_change.app_error", nil, @@ -284,7 +293,7 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. // Block options changes on linked fields if existing.LinkedFieldID != nil && *existing.LinkedFieldID != "" && optionsChanged(existing.Attrs, field.Attrs) { - return nil, nil, model.NewAppError( + return nil, nil, nil, model.NewAppError( "UpdatePropertyFields", "app.property_field.update.linked_options_change.app_error", nil, @@ -308,7 +317,7 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. newIsLinked := field.LinkedFieldID != nil if !existingIsLinked && newIsLinked { - return nil, nil, model.NewAppError( + return nil, nil, nil, model.NewAppError( "UpdatePropertyFields", "app.property_field.update.cannot_link_existing.app_error", nil, @@ -320,7 +329,7 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. // Block changing link target. To re-link, unlink first then create a // new linked field. if existingIsLinked && newIsLinked && *field.LinkedFieldID != *existing.LinkedFieldID { - return nil, nil, model.NewAppError( + return nil, nil, nil, model.NewAppError( "UpdatePropertyFields", "app.property_field.update.cannot_change_link_target.app_error", nil, @@ -333,11 +342,11 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. if field.Type != existing.Type { count, cErr := ps.fieldStore.CountLinkedFields(field.ID) if cErr != nil { - return nil, nil, fmt.Errorf("failed to count linked fields: %w", cErr) + return nil, nil, nil, fmt.Errorf("failed to count linked fields: %w", cErr) } if count > 0 { - return nil, nil, model.NewAppError( + return nil, nil, nil, model.NewAppError( "UpdatePropertyFields", "app.property_field.update.type_change_with_dependents.app_error", nil, @@ -357,11 +366,11 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. existing.ObjectType != field.ObjectType { conflictLevel, cErr := ps.fieldStore.CheckPropertyNameConflict(field, field.ID) if cErr != nil { - return nil, nil, fmt.Errorf("failed to check property name conflict: %w", cErr) + return nil, nil, nil, fmt.Errorf("failed to check property name conflict: %w", cErr) } if conflictLevel != "" { - return nil, nil, model.NewAppError( + return nil, nil, nil, model.NewAppError( "UpdatePropertyFields", "app.property_field.update.name_conflict.app_error", map[string]any{"Name": field.Name, "ConflictLevel": string(conflictLevel)}, @@ -384,7 +393,7 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. // options to linked dependents automatically via a JOIN-based UPDATE. all, uErr := ps.fieldStore.Update(groupID, fields, expectedUpdateAts) if uErr != nil { - return nil, nil, uErr + return nil, nil, nil, uErr } // Partition the returned fields into requested vs propagated by matching @@ -405,7 +414,18 @@ func (ps *PropertyService) updatePropertyFields(groupID string, fields []*model. } } - return requested, propagated, nil + // Run post-hooks. prev is parallel to requested. Hooks may transform + // either the requested or propagated bucket (e.g. attr redaction); the + // dispatcher enforces cardinality preservation on both buckets so a buggy + // hook that drops fields surfaces an error rather than silently truncating + // the broadcast. cleared IDs are unioned across hooks. + prev := make([]*model.PropertyField, 0, len(requested)) + for _, r := range requested { + prev = append(prev, existingByID[r.ID]) + } + requested, propagated, clearedFieldIDs = ps.runPostUpdatePropertyFields(rctx, groupID, prev, requested, propagated) + + return requested, propagated, clearedFieldIDs, nil } func (ps *PropertyService) deletePropertyField(groupID, id string) error { @@ -438,169 +458,114 @@ func (ps *PropertyService) deletePropertyField(groupID, id string) error { return ps.fieldStore.Delete(groupID, id) } -// Public routing methods +// Public methods func (ps *PropertyService) CreatePropertyField(rctx request.CTX, field *model.PropertyField) (*model.PropertyField, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(field.GroupID) + field, err := ps.runPreCreatePropertyField(rctx, field) if err != nil { return nil, fmt.Errorf("CreatePropertyField: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.CreatePropertyField(callerID, field) - } - return ps.createPropertyField(field) } func (ps *PropertyService) GetPropertyField(rctx request.CTX, groupID, id string) (*model.PropertyField, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + field, err := ps.getPropertyField(groupID, id) if err != nil { return nil, fmt.Errorf("GetPropertyField: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.GetPropertyField(callerID, groupID, id) - } - - return ps.getPropertyField(groupID, id) + return ps.runPostGetPropertyField(rctx, field) } func (ps *PropertyService) GetPropertyFields(rctx request.CTX, groupID string, ids []string) ([]*model.PropertyField, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + fields, err := ps.getPropertyFields(groupID, ids) if err != nil { return nil, fmt.Errorf("GetPropertyFields: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.GetPropertyFields(callerID, groupID, ids) - } - - return ps.getPropertyFields(groupID, ids) + return ps.runPostGetPropertyFields(rctx, fields) } func (ps *PropertyService) GetPropertyFieldByName(rctx request.CTX, groupID, targetID, name string) (*model.PropertyField, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + field, err := ps.getPropertyFieldByName(groupID, targetID, name) if err != nil { return nil, fmt.Errorf("GetPropertyFieldByName: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.GetPropertyFieldByName(callerID, groupID, targetID, name) - } - - return ps.getPropertyFieldByName(groupID, targetID, name) + return ps.runPostGetPropertyField(rctx, field) } func (ps *PropertyService) CountActivePropertyFieldsForGroup(rctx request.CTX, groupID string) (int64, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreCountPropertyFields(rctx, groupID); err != nil { return 0, fmt.Errorf("CountActivePropertyFieldsForGroup: %w", err) } - - if requiresAC { - return ps.propertyAccess.CountActivePropertyFieldsForGroup(groupID) - } - return ps.countActivePropertyFieldsForGroup(groupID) } func (ps *PropertyService) CountAllPropertyFieldsForGroup(rctx request.CTX, groupID string) (int64, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreCountPropertyFields(rctx, groupID); err != nil { return 0, fmt.Errorf("CountAllPropertyFieldsForGroup: %w", err) } - - if requiresAC { - return ps.propertyAccess.CountAllPropertyFieldsForGroup(groupID) - } - return ps.countAllPropertyFieldsForGroup(groupID) } func (ps *PropertyService) CountActivePropertyFieldsForTarget(rctx request.CTX, groupID, targetType, targetID string) (int64, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreCountPropertyFields(rctx, groupID); err != nil { return 0, fmt.Errorf("CountActivePropertyFieldsForTarget: %w", err) } - - if requiresAC { - return ps.propertyAccess.CountActivePropertyFieldsForTarget(groupID, targetType, targetID) - } - return ps.countActivePropertyFieldsForTarget(groupID, targetType, targetID) } func (ps *PropertyService) CountAllPropertyFieldsForTarget(rctx request.CTX, groupID, targetType, targetID string) (int64, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreCountPropertyFields(rctx, groupID); err != nil { return 0, fmt.Errorf("CountAllPropertyFieldsForTarget: %w", err) } - - if requiresAC { - return ps.propertyAccess.CountAllPropertyFieldsForTarget(groupID, targetType, targetID) - } - return ps.countAllPropertyFieldsForTarget(groupID, targetType, targetID) } func (ps *PropertyService) SearchPropertyFields(rctx request.CTX, groupID string, opts model.PropertyFieldSearchOpts) ([]*model.PropertyField, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + fields, err := ps.searchPropertyFields(groupID, opts) if err != nil { return nil, fmt.Errorf("SearchPropertyFields: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.SearchPropertyFields(callerID, groupID, opts) - } - - return ps.searchPropertyFields(groupID, opts) + return ps.runPostGetPropertyFields(rctx, fields) } -func (ps *PropertyService) UpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) +// UpdatePropertyField updates a single field. It returns the updated field and +// the IDs of fields whose dependent property values were cleared as a side +// effect (e.g. by TypeChangeValueCleanupHook on a type change). Hooks may +// cascade clears to other fields, so the slice is not necessarily limited to +// the updated field's own ID. The caller is expected to publish any +// value-cleanup WS events. +func (ps *PropertyService) UpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField) (*model.PropertyField, []string, error) { + field, err := ps.runPreUpdatePropertyField(rctx, groupID, field) if err != nil { - return nil, fmt.Errorf("UpdatePropertyField: %w", err) + return nil, nil, fmt.Errorf("UpdatePropertyField: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.UpdatePropertyField(callerID, groupID, field) - } - - return ps.updatePropertyField(groupID, field) + return ps.updatePropertyField(rctx, groupID, field) } -func (ps *PropertyService) UpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) (requested []*model.PropertyField, propagated []*model.PropertyField, err error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) +// UpdatePropertyFields updates a batch of fields and returns the requested set, +// any linked-property propagated fields, and the IDs of fields whose dependent +// property values were cleared as a side effect. The caller is expected to +// publish any value-cleanup WS events. +func (ps *PropertyService) UpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField) (requested []*model.PropertyField, propagated []*model.PropertyField, clearedFieldIDs []string, err error) { + fields, err = ps.runPreUpdatePropertyFields(rctx, groupID, fields) if err != nil { - return nil, nil, fmt.Errorf("UpdatePropertyFields: %w", err) + return nil, nil, nil, fmt.Errorf("UpdatePropertyFields: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.UpdatePropertyFields(callerID, groupID, fields) - } - - return ps.updatePropertyFields(groupID, fields) + return ps.updatePropertyFields(rctx, groupID, fields) } func (ps *PropertyService) DeletePropertyField(rctx request.CTX, groupID, id string) error { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreDeletePropertyField(rctx, groupID, id); err != nil { return fmt.Errorf("DeletePropertyField: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.DeletePropertyField(callerID, groupID, id) - } - return ps.deletePropertyField(groupID, id) } @@ -667,9 +632,9 @@ func optionsChanged(oldAttrs, newAttrs model.StringInterface) bool { return false } -// extractOptionIDs extracts the "id" field from each option in the given options value +// extractOptionIDList extracts the "id" field from each option in the given options value // using direct type assertions (no JSON marshaling). -func extractOptionIDs(options any) []string { +func extractOptionIDList(options any) []string { if options == nil { return nil } diff --git a/server/channels/app/properties/property_field_test.go b/server/channels/app/properties/property_field_test.go index afff492e788..fbebec23a40 100644 --- a/server/channels/app/properties/property_field_test.go +++ b/server/channels/app/properties/property_field_test.go @@ -13,63 +13,39 @@ import ( "github.com/stretchr/testify/require" ) -func TestRequiresAccessControlFailsClosed(t *testing.T) { - th := Setup(t) +func TestHooksOnlyScopeToManagedGroups(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) rctx := th.Context - // Use an unregistered group — this means any call to - // requiresAccessControl will fail to look up the group. - // The service must return an error rather than silently bypassing - // access control. - unregisteredGroupID := model.NewId() + // Operations on an unmanaged group should bypass the access control + // hook entirely and proceed directly to the store layer. + unmanagedGroup, err := th.service.RegisterPropertyGroup(&model.PropertyGroup{Name: "unmanaged_group", Version: model.PropertyGroupVersionV2}) + require.NoError(t, err) - t.Run("CreatePropertyField returns error when group lookup fails", func(t *testing.T) { + t.Run("CreatePropertyField on unmanaged group bypasses hooks", func(t *testing.T) { field := &model.PropertyField{ - GroupID: unregisteredGroupID, - Name: "test-field", + GroupID: unmanagedGroup.ID, + Name: "test-field-" + model.NewId(), Type: model.PropertyFieldTypeText, - TargetType: "user", + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } - _, err := th.service.CreatePropertyField(rctx, field) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check access control") + result, err := th.service.CreatePropertyField(rctx, field) + require.NoError(t, err) + assert.NotEmpty(t, result.ID) }) - t.Run("GetPropertyField returns error when group lookup fails", func(t *testing.T) { - _, err := th.service.GetPropertyField(rctx, unregisteredGroupID, model.NewId()) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check access control") - }) - - t.Run("GetPropertyFields returns error when group lookup fails", func(t *testing.T) { - _, err := th.service.GetPropertyFields(rctx, unregisteredGroupID, []string{model.NewId()}) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check access control") - }) - - t.Run("UpdatePropertyField returns error when group lookup fails", func(t *testing.T) { - field := &model.PropertyField{ - ID: model.NewId(), - GroupID: unregisteredGroupID, - Name: "test-field", + t.Run("GetPropertyField on unmanaged group bypasses hooks", func(t *testing.T) { + field := th.CreatePropertyFieldDirect(t, &model.PropertyField{ + GroupID: unmanagedGroup.ID, + Name: "get-field-" + model.NewId(), Type: model.PropertyFieldTypeText, - TargetType: "user", - } - _, err := th.service.UpdatePropertyField(rctx, unregisteredGroupID, field) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check access control") - }) - - t.Run("DeletePropertyField returns error when group lookup fails", func(t *testing.T) { - err := th.service.DeletePropertyField(rctx, unregisteredGroupID, model.NewId()) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check access control") - }) - - t.Run("SearchPropertyFields returns error when group lookup fails", func(t *testing.T) { - _, err := th.service.SearchPropertyFields(rctx, unregisteredGroupID, model.PropertyFieldSearchOpts{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check access control") + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + }) + result, err := th.service.GetPropertyField(rctx, unmanagedGroup.ID, field.ID) + require.NoError(t, err) + assert.Equal(t, field.ID, result.ID) }) } @@ -616,18 +592,14 @@ func TestUpdatePropertyField(t *testing.T) { }, }) - // Update non-name fields (Type, Attrs) - field.Type = model.PropertyFieldTypeSelect + // Update non-name fields (Attrs only) field.Attrs = map[string]any{ - "options": []any{ - map[string]any{"name": "a"}, - map[string]any{"name": "b"}, - }, + "key": "updated", } - result, err := th.service.UpdatePropertyField(rctx, groupID, field) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, field) require.NoError(t, err) - assert.Equal(t, model.PropertyFieldTypeSelect, result.Type) + assert.Equal(t, "updated", result.Attrs["key"]) }) t.Run("updating name to non-conflicting value should succeed", func(t *testing.T) { @@ -644,7 +616,7 @@ func TestUpdatePropertyField(t *testing.T) { // Update name to non-conflicting value field.Name = "NewUniqueName" - result, err := th.service.UpdatePropertyField(rctx, groupID, field) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, field) require.NoError(t, err) assert.Equal(t, "NewUniqueName", result.Name) }) @@ -674,7 +646,7 @@ func TestUpdatePropertyField(t *testing.T) { // Try to update system-level to name that conflicts with team-level systemField.Name = "ExistingTeamProp" - result, err := th.service.UpdatePropertyField(rctx, groupID, systemField) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, systemField) require.Error(t, err) assert.Nil(t, result) appErr, ok := err.(*model.AppError) @@ -712,7 +684,7 @@ func TestUpdatePropertyField(t *testing.T) { // Update DM property to same name as regular channel property - should succeed // because DM channels have no team, so they don't conflict with team channels dmField.Name = "ChannelProp" - result, err := th.service.UpdatePropertyField(rctx, groupID, dmField) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, dmField) require.NoError(t, err) assert.Equal(t, "ChannelProp", result.Name) }) @@ -742,7 +714,7 @@ func TestUpdatePropertyField(t *testing.T) { // Try to update team-level to name that conflicts with system-level teamField.Name = "ExistingSystemProp" - result, err := th.service.UpdatePropertyField(rctx, groupID, teamField) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, teamField) require.Error(t, err) assert.Nil(t, result) appErr, ok := err.(*model.AppError) @@ -781,7 +753,7 @@ func TestUpdatePropertyField(t *testing.T) { channel2Field.TargetType = string(model.PropertyFieldTargetLevelSystem) channel2Field.TargetID = "" - result, err := th.service.UpdatePropertyField(rctx, groupID, channel2Field) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, channel2Field) require.Error(t, err) assert.Nil(t, result) appErr, ok := err.(*model.AppError) @@ -823,7 +795,7 @@ func TestUpdatePropertyField(t *testing.T) { // We only verify an error occurs without checking the specific error type. channel2Field.TargetID = channel1.Id - result, err := th.service.UpdatePropertyField(rctx, groupID, channel2Field) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, channel2Field) require.Error(t, err) assert.Nil(t, result) }) @@ -842,7 +814,7 @@ func TestUpdatePropertyField(t *testing.T) { // Update name should succeed without conflict check field.Name = "UpdatedLegacyProp" - result, err := th.service.UpdatePropertyField(rctx, groupID, field) + result, _, err := th.service.UpdatePropertyField(rctx, groupID, field) require.NoError(t, err) assert.Equal(t, "UpdatedLegacyProp", result.Name) }) @@ -860,8 +832,8 @@ func TestUpdatePropertyField(t *testing.T) { }) // Update with same name should succeed (no actual change to name) - field.Type = model.PropertyFieldTypeSelect // Change something else - result, err := th.service.UpdatePropertyField(rctx, groupID, field) + field.Attrs = map[string]any{"key": "changed"} // Change something else + result, _, err := th.service.UpdatePropertyField(rctx, groupID, field) require.NoError(t, err) assert.Equal(t, "SameName", result.Name) }) @@ -1024,7 +996,7 @@ func TestLinkedPropertyFields(t *testing.T) { }) linked.Type = model.PropertyFieldTypeText - _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) + _, _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) require.Error(t, err) appErr, ok := err.(*model.AppError) require.True(t, ok) @@ -1046,7 +1018,7 @@ func TestLinkedPropertyFields(t *testing.T) { linked.Attrs[model.PropertyFieldAttributeOptions] = []any{ map[string]any{"id": model.NewId(), "name": "Different"}, } - _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) + _, _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) require.Error(t, err) appErr, ok := err.(*model.AppError) require.True(t, ok) @@ -1066,7 +1038,7 @@ func TestLinkedPropertyFields(t *testing.T) { }) linked.Name = "NewName-" + model.NewId() - result, err := th.service.UpdatePropertyField(rctx, group.ID, linked) + result, _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) require.NoError(t, err) assert.Equal(t, linked.Name, result.Name) }) @@ -1100,7 +1072,7 @@ func TestLinkedPropertyFields(t *testing.T) { } source.Attrs[model.PropertyFieldAttributeOptions] = newOptions - result, propagated, err := th.service.UpdatePropertyFields(rctx, group.ID, []*model.PropertyField{source}) + result, propagated, _, err := th.service.UpdatePropertyFields(rctx, group.ID, []*model.PropertyField{source}) require.NoError(t, err) require.Len(t, result, 1) // only the requested source field require.Len(t, propagated, 2) // 2 linked fields @@ -1112,8 +1084,8 @@ func TestLinkedPropertyFields(t *testing.T) { require.NoError(t, err) for _, linked := range []*model.PropertyField{updatedLinked1, updatedLinked2} { - opts := extractOptionIDs(linked.Attrs[model.PropertyFieldAttributeOptions]) - expectedOpts := extractOptionIDs(newOptions) + opts := extractOptionIDList(linked.Attrs[model.PropertyFieldAttributeOptions]) + expectedOpts := extractOptionIDList(newOptions) assert.Equal(t, expectedOpts, opts) } }) @@ -1131,7 +1103,7 @@ func TestLinkedPropertyFields(t *testing.T) { }) source.Type = model.PropertyFieldTypeMultiselect - _, err := th.service.UpdatePropertyField(rctx, group.ID, source) + _, _, err := th.service.UpdatePropertyField(rctx, group.ID, source) require.Error(t, err) appErr, ok := err.(*model.AppError) require.True(t, ok) @@ -1192,14 +1164,14 @@ func TestLinkedPropertyFields(t *testing.T) { // Unlink by clearing LinkedFieldID linked.LinkedFieldID = nil - result, err := th.service.UpdatePropertyField(rctx, group.ID, linked) + result, _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) require.NoError(t, err) assert.Nil(t, result.LinkedFieldID) assert.Equal(t, source.Type, result.Type) // Verify options are preserved after unlinking - sourceOpts := extractOptionIDs(source.Attrs[model.PropertyFieldAttributeOptions]) - resultOpts := extractOptionIDs(result.Attrs[model.PropertyFieldAttributeOptions]) + sourceOpts := extractOptionIDList(source.Attrs[model.PropertyFieldAttributeOptions]) + resultOpts := extractOptionIDList(result.Attrs[model.PropertyFieldAttributeOptions]) require.NotEmpty(t, sourceOpts, "source should have options") assert.Equal(t, sourceOpts, resultOpts, "options should be preserved after unlinking") }) @@ -1251,7 +1223,7 @@ func TestLinkedPropertyFields(t *testing.T) { // Attempt to set LinkedFieldID on update — should be rejected source := createSourceField(t, "LinkAttemptSource-"+model.NewId()) regular.LinkedFieldID = &source.ID - _, err := th.service.UpdatePropertyField(rctx, group.ID, regular) + _, _, err := th.service.UpdatePropertyField(rctx, group.ID, regular) require.Error(t, err) appErr, ok := err.(*model.AppError) require.True(t, ok) @@ -1274,7 +1246,7 @@ func TestLinkedPropertyFields(t *testing.T) { // Attempt to change the link target — should be rejected linked.LinkedFieldID = &source2.ID - _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) + _, _, err := th.service.UpdatePropertyField(rctx, group.ID, linked) require.Error(t, err) appErr, ok := err.(*model.AppError) require.True(t, ok) @@ -1353,7 +1325,7 @@ func TestLinkedPropertyFields(t *testing.T) { map[string]any{"id": optCID, "name": "Option C", "color": "green"}, } - result, propagated, err := th.service.UpdatePropertyFields(rctx, group.ID, []*model.PropertyField{source}) + result, propagated, _, err := th.service.UpdatePropertyFields(rctx, group.ID, []*model.PropertyField{source}) require.NoError(t, err) require.Len(t, result, 1) // only the requested source field require.Len(t, propagated, 1) // 1 linked field @@ -1362,7 +1334,7 @@ func TestLinkedPropertyFields(t *testing.T) { updatedLinked, err := th.service.GetPropertyField(rctx, group.ID, linked.ID) require.NoError(t, err) - linkedOptIDs := extractOptionIDs(updatedLinked.Attrs[model.PropertyFieldAttributeOptions]) + linkedOptIDs := extractOptionIDList(updatedLinked.Attrs[model.PropertyFieldAttributeOptions]) assert.Equal(t, []string{optAID, optCID}, linkedOptIDs, "option B should be removed from linked field") // Verify option content (names, colors) was propagated correctly @@ -1374,12 +1346,10 @@ func TestLinkedPropertyFields(t *testing.T) { assert.Equal(t, "green", linkedOpts[1]["color"]) }) - // FIXME: remove this test once CPA is fully migrated to v2 — template - // fields should then only be created on v2 groups. - t.Run("template field creation is allowed on v1 group", func(t *testing.T) { + t.Run("template field creation is rejected on v1 group", func(t *testing.T) { v1Group := th.RegisterPropertyGroup(t, model.PropertyGroupVersionV1) - template, err := th.service.CreatePropertyField(rctx, &model.PropertyField{ + _, err := th.service.CreatePropertyField(rctx, &model.PropertyField{ GroupID: v1Group.ID, ObjectType: model.PropertyFieldObjectTypeTemplate, TargetType: string(model.PropertyFieldTargetLevelSystem), @@ -1391,8 +1361,7 @@ func TestLinkedPropertyFields(t *testing.T) { }, }, }) - require.NoError(t, err) - assert.Equal(t, model.PropertyFieldObjectTypeTemplate, template.ObjectType) + require.Error(t, err) }) t.Run("cross-group linking is rejected", func(t *testing.T) { diff --git a/server/channels/app/properties/property_value.go b/server/channels/app/properties/property_value.go index cd656d328f0..dec126b738d 100644 --- a/server/channels/app/properties/property_value.go +++ b/server/channels/app/properties/property_value.go @@ -130,23 +130,18 @@ func (ps *PropertyService) deletePropertyValuesForField(groupID, fieldID string) return ps.valueStore.DeleteForField(groupID, fieldID) } -// Public routing methods +// Public methods func (ps *PropertyService) CreatePropertyValue(rctx request.CTX, value *model.PropertyValue) (*model.PropertyValue, error) { if value == nil { return nil, fmt.Errorf("CreatePropertyValue: value cannot be nil") } - requiresAC, err := ps.requiresAccessControlForGroupID(value.GroupID) + value, err := ps.runPreCreatePropertyValue(rctx, value) if err != nil { return nil, fmt.Errorf("CreatePropertyValue: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.CreatePropertyValue(callerID, value) - } - return ps.createPropertyValue(value) } @@ -164,84 +159,71 @@ func (ps *PropertyService) CreatePropertyValues(rctx request.CTX, values []*mode } } - requiresAC, err := ps.requiresAccessControlForGroupID(values[0].GroupID) + values, err := ps.runPreCreatePropertyValues(rctx, values) if err != nil { return nil, fmt.Errorf("CreatePropertyValues: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.CreatePropertyValues(callerID, values) - } - return ps.createPropertyValues(values) } func (ps *PropertyService) GetPropertyValue(rctx request.CTX, groupID, id string) (*model.PropertyValue, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + value, err := ps.getPropertyValue(groupID, id) if err != nil { return nil, fmt.Errorf("GetPropertyValue: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.GetPropertyValue(callerID, groupID, id) - } - - return ps.getPropertyValue(groupID, id) + return ps.runPostGetPropertyValue(rctx, value) } func (ps *PropertyService) GetPropertyValues(rctx request.CTX, groupID string, ids []string) ([]*model.PropertyValue, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + values, err := ps.getPropertyValues(groupID, ids) if err != nil { return nil, fmt.Errorf("GetPropertyValues: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.GetPropertyValues(callerID, groupID, ids) - } - - return ps.getPropertyValues(groupID, ids) + return ps.runPostGetPropertyValues(rctx, values) } func (ps *PropertyService) SearchPropertyValues(rctx request.CTX, groupID string, opts model.PropertyValueSearchOpts) ([]*model.PropertyValue, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + values, err := ps.searchPropertyValues(groupID, opts) if err != nil { return nil, fmt.Errorf("SearchPropertyValues: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.SearchPropertyValues(callerID, groupID, opts) - } - - return ps.searchPropertyValues(groupID, opts) + return ps.runPostGetPropertyValues(rctx, values) } func (ps *PropertyService) UpdatePropertyValue(rctx request.CTX, groupID string, value *model.PropertyValue) (*model.PropertyValue, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) + value, err := ps.runPreUpdatePropertyValue(rctx, groupID, value) if err != nil { return nil, fmt.Errorf("UpdatePropertyValue: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.UpdatePropertyValue(callerID, groupID, value) - } - return ps.updatePropertyValue(groupID, value) } func (ps *PropertyService) UpdatePropertyValues(rctx request.CTX, groupID string, values []*model.PropertyValue) ([]*model.PropertyValue, error) { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { - return nil, fmt.Errorf("UpdatePropertyValues: %w", err) + if len(values) == 0 { + return values, nil } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.UpdatePropertyValues(callerID, groupID, values) + // Hooks gate on values[0].GroupID for batch operations, so enforce + // single-group batches at the public boundary — otherwise a mixed + // batch could silently bypass per-group hook logic (license, + // validation, access control). + for i, v := range values { + if v == nil { + return nil, fmt.Errorf("UpdatePropertyValues: nil element at index %d", i) + } + if v.GroupID != values[0].GroupID { + return nil, fmt.Errorf("UpdatePropertyValues: mixed group IDs in batch") + } + } + + values, err := ps.runPreUpdatePropertyValues(rctx, groupID, values) + if err != nil { + return nil, fmt.Errorf("UpdatePropertyValues: %w", err) } return ps.updatePropertyValues(groupID, values) @@ -252,16 +234,11 @@ func (ps *PropertyService) UpsertPropertyValue(rctx request.CTX, value *model.Pr return nil, fmt.Errorf("UpsertPropertyValue: value cannot be nil") } - requiresAC, err := ps.requiresAccessControlForGroupID(value.GroupID) + value, err := ps.runPreUpsertPropertyValue(rctx, value) if err != nil { return nil, fmt.Errorf("UpsertPropertyValue: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.UpsertPropertyValue(callerID, value) - } - return ps.upsertPropertyValue(value) } @@ -279,57 +256,34 @@ func (ps *PropertyService) UpsertPropertyValues(rctx request.CTX, values []*mode } } - requiresAC, err := ps.requiresAccessControlForGroupID(values[0].GroupID) + values, err := ps.runPreUpsertPropertyValues(rctx, values) if err != nil { return nil, fmt.Errorf("UpsertPropertyValues: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.UpsertPropertyValues(callerID, values) - } - return ps.upsertPropertyValues(values) } func (ps *PropertyService) DeletePropertyValue(rctx request.CTX, groupID, id string) error { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreDeletePropertyValue(rctx, groupID, id); err != nil { return fmt.Errorf("DeletePropertyValue: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.DeletePropertyValue(callerID, groupID, id) - } - return ps.deletePropertyValue(groupID, id) } func (ps *PropertyService) DeletePropertyValuesForTarget(rctx request.CTX, groupID string, targetType string, targetID string) error { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreDeletePropertyValuesForTarget(rctx, groupID, targetType, targetID); err != nil { return fmt.Errorf("DeletePropertyValuesForTarget: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.DeletePropertyValuesForTarget(callerID, groupID, targetType, targetID) - } - return ps.deletePropertyValuesForTarget(groupID, targetType, targetID) } func (ps *PropertyService) DeletePropertyValuesForField(rctx request.CTX, groupID, fieldID string) error { - requiresAC, err := ps.requiresAccessControlForGroupID(groupID) - if err != nil { + if err := ps.runPreDeletePropertyValuesForField(rctx, groupID, fieldID); err != nil { return fmt.Errorf("DeletePropertyValuesForField: %w", err) } - if requiresAC { - callerID := ps.extractCallerID(rctx) - return ps.propertyAccess.DeletePropertyValuesForField(callerID, groupID, fieldID) - } - return ps.deletePropertyValuesForField(groupID, fieldID) } diff --git a/server/channels/app/properties/service.go b/server/channels/app/properties/service.go index 50508cffcf3..0543a281a01 100644 --- a/server/channels/app/properties/service.go +++ b/server/channels/app/properties/service.go @@ -5,10 +5,8 @@ package properties import ( "errors" - "fmt" "sync" - "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/store" ) @@ -21,7 +19,7 @@ type PropertyService struct { groupStore store.PropertyGroupStore fieldStore store.PropertyFieldStore valueStore store.PropertyValueStore - propertyAccess *PropertyAccessService + hooks []PropertyHook callerIDExtractor CallerIDExtractor groupCache sync.Map // name -> *model.PropertyGroup groupIDCache sync.Map // id -> *model.PropertyGroup @@ -44,7 +42,6 @@ func New(c ServiceConfig) (*PropertyService, error) { fieldStore: c.PropertyFieldStore, valueStore: c.PropertyValueStore, callerIDExtractor: c.CallerIDExtractor, - propertyAccess: nil, }, nil } @@ -55,27 +52,6 @@ func (c *ServiceConfig) validate() error { return nil } -func (ps *PropertyService) SetPropertyAccessService(pas *PropertyAccessService) { - ps.propertyAccess = pas -} - -// requiresAccessControlForGroupID checks if a group ID requires access control enforcement. -// Currently, only the CPA group requires access control, but this may change in the future. -func (ps *PropertyService) requiresAccessControlForGroupID(groupID string) (bool, error) { - group, err := ps.Group(model.CustomProfileAttributesPropertyGroupName) - if err != nil { - return false, fmt.Errorf("failed to check access control for group %q: %w", groupID, err) - } - return groupID == group.ID, nil -} - -// setPluginCheckerForTests sets the plugin checker on the underlying PropertyAccessService. -func (ps *PropertyService) setPluginCheckerForTests(pluginChecker PluginChecker) { - if ps.propertyAccess != nil { - ps.propertyAccess.setPluginCheckerForTests(pluginChecker) - } -} - // extractCallerID gets the caller ID from a request context using the configured extractor. func (ps *PropertyService) extractCallerID(rctx request.CTX) string { if ps.callerIDExtractor == nil || rctx == nil { diff --git a/server/channels/app/properties/type_change_value_cleanup.go b/server/channels/app/properties/type_change_value_cleanup.go new file mode 100644 index 00000000000..65957f57e43 --- /dev/null +++ b/server/channels/app/properties/type_change_value_cleanup.go @@ -0,0 +1,66 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/public/shared/request" +) + +// TypeChangeValueCleanupHook deletes a field's dependent property values when +// the field's Type changes on update. The Type column is part of the schema +// contract for stored values (e.g. select-option IDs are only valid against a +// matching select field), so leaving values behind across a type change leaves +// the field functionally broken until callers manually reset the values. +// +// The hook runs in PostUpdatePropertyFields. Earlier hooks +// (linked-property checks at the store layer) already reject the type-change +// cases that would corrupt linked state, so by the time this hook runs the +// only remaining type changes are on standalone fields where cleanup is the +// expected behavior. Cleanup failures are logged and skipped — the field +// update is not rolled back — to keep the operation atomic from the caller's +// perspective. +type TypeChangeValueCleanupHook struct { + BasePropertyHook + propertyService *PropertyService +} + +var _ PropertyHook = (*TypeChangeValueCleanupHook)(nil) + +// NewTypeChangeValueCleanupHook constructs the hook. The PropertyService +// reference is used to delete dependent values via the unexported +// deletePropertyValuesForField path so the hook does not re-enter the public +// hook chain (which would deadlock on its own pre-hook gating). +func NewTypeChangeValueCleanupHook(ps *PropertyService) *TypeChangeValueCleanupHook { + return &TypeChangeValueCleanupHook{propertyService: ps} +} + +// PostUpdatePropertyFields returns the IDs of fields whose dependent values +// were cleared. The caller publishes the corresponding WS events. Linked- +// property propagation cannot trigger a type change (blocked upstream), so +// the propagated bucket is passed through unchanged. +func (h *TypeChangeValueCleanupHook) PostUpdatePropertyFields(rctx request.CTX, groupID string, prev, requested, propagated []*model.PropertyField) ([]*model.PropertyField, []*model.PropertyField, []string, error) { + var cleared []string + for i, u := range requested { + if i >= len(prev) || prev[i] == nil || u == nil { + continue + } + if prev[i].Type == u.Type { + continue + } + if err := h.propertyService.deletePropertyValuesForField(groupID, u.ID); err != nil { + rctx.Logger().Error("type-change value cleanup failed", + mlog.String("group_id", groupID), + mlog.String("field_id", u.ID), + mlog.String("from_type", string(prev[i].Type)), + mlog.String("to_type", string(u.Type)), + mlog.Err(err), + ) + continue + } + cleared = append(cleared, u.ID) + } + return requested, propagated, cleared, nil +} diff --git a/server/channels/app/properties/type_change_value_cleanup_test.go b/server/channels/app/properties/type_change_value_cleanup_test.go new file mode 100644 index 00000000000..4a121efdc63 --- /dev/null +++ b/server/channels/app/properties/type_change_value_cleanup_test.go @@ -0,0 +1,216 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package properties + +import ( + "encoding/json" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestTypeChangeValueCleanupHook verifies the post-update hook detects a Type +// change and deletes the field's dependent property values, surfacing the +// cleared field IDs to the caller. +func TestTypeChangeValueCleanupHook(t *testing.T) { + th := Setup(t).RegisterCPAPropertyGroup(t) + th.service.AddHook(NewTypeChangeValueCleanupHook(th.service)) + + t.Run("type change deletes values and reports cleared field id", func(t *testing.T) { + // Create a select field with two options. + optionAID := model.NewId() + optionBID := model.NewId() + field := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "select-field-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": optionAID, "name": "Option A"}, + {"id": optionBID, "name": "Option B"}, + }, + }, + } + created, err := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, err) + + // Seed a value referencing one of the options. + userID := model.NewId() + raw, err := json.Marshal(optionAID) + require.NoError(t, err) + _, err = th.service.UpsertPropertyValue(th.Context, &model.PropertyValue{ + GroupID: th.CPAGroupID, + FieldID: created.ID, + TargetID: userID, + TargetType: model.PropertyValueTargetTypeUser, + Value: raw, + }) + require.NoError(t, err) + + // Confirm the value exists pre-patch. + preValues, err := th.service.SearchPropertyValues(th.Context, th.CPAGroupID, model.PropertyValueSearchOpts{ + FieldID: created.ID, + PerPage: 10, + }) + require.NoError(t, err) + require.Len(t, preValues, 1) + + // Patch to type=text. AccessControlAttributeValidationHook strips the now-invalid + // options attr; TypeChangeValueCleanupHook deletes the dependent value. + created.Type = model.PropertyFieldTypeText + _, clearedIDs, err := th.service.UpdatePropertyField(th.Context, th.CPAGroupID, created) + require.NoError(t, err) + assert.Equal(t, []string{created.ID}, clearedIDs, "expected post-hook to report the type-changed field as cleared") + + // Confirm the value is gone. + postValues, err := th.service.SearchPropertyValues(th.Context, th.CPAGroupID, model.PropertyValueSearchOpts{ + FieldID: created.ID, + PerPage: 10, + }) + require.NoError(t, err) + assert.Empty(t, postValues, "expected dependent values to be cleared") + }) + + t.Run("multiselect type change deletes values and reports cleared field id", func(t *testing.T) { + // Same shape as the select case above, but for multiselect. + optionAID := model.NewId() + optionBID := model.NewId() + field := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "multiselect-field-" + model.NewId(), + Type: model.PropertyFieldTypeMultiselect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": optionAID, "name": "Option A"}, + {"id": optionBID, "name": "Option B"}, + }, + }, + } + created, err := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, err) + + // Multiselect value is a JSON array of option IDs. + userID := model.NewId() + raw, err := json.Marshal([]string{optionAID, optionBID}) + require.NoError(t, err) + _, err = th.service.UpsertPropertyValue(th.Context, &model.PropertyValue{ + GroupID: th.CPAGroupID, + FieldID: created.ID, + TargetID: userID, + TargetType: model.PropertyValueTargetTypeUser, + Value: raw, + }) + require.NoError(t, err) + + preValues, err := th.service.SearchPropertyValues(th.Context, th.CPAGroupID, model.PropertyValueSearchOpts{ + FieldID: created.ID, + PerPage: 10, + }) + require.NoError(t, err) + require.Len(t, preValues, 1) + + created.Type = model.PropertyFieldTypeText + _, clearedIDs, err := th.service.UpdatePropertyField(th.Context, th.CPAGroupID, created) + require.NoError(t, err) + assert.Equal(t, []string{created.ID}, clearedIDs, "expected post-hook to report the type-changed field as cleared") + + postValues, err := th.service.SearchPropertyValues(th.Context, th.CPAGroupID, model.PropertyValueSearchOpts{ + FieldID: created.ID, + PerPage: 10, + }) + require.NoError(t, err) + assert.Empty(t, postValues, "expected dependent values to be cleared") + }) + + t.Run("same-type patch is a no-op for cleanup", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "text-field-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + created, err := th.service.CreatePropertyField(th.Context, field) + require.NoError(t, err) + + raw, err := json.Marshal("hello") + require.NoError(t, err) + _, err = th.service.UpsertPropertyValue(th.Context, &model.PropertyValue{ + GroupID: th.CPAGroupID, + FieldID: created.ID, + TargetID: model.NewId(), + TargetType: model.PropertyValueTargetTypeUser, + Value: raw, + }) + require.NoError(t, err) + + // Rename only — no Type change. + created.Name = "text-field-renamed-" + model.NewId() + _, clearedIDs, err := th.service.UpdatePropertyField(th.Context, th.CPAGroupID, created) + require.NoError(t, err) + assert.Empty(t, clearedIDs, "rename without type change must not clear values") + + values, err := th.service.SearchPropertyValues(th.Context, th.CPAGroupID, model.PropertyValueSearchOpts{ + FieldID: created.ID, + PerPage: 10, + }) + require.NoError(t, err) + assert.Len(t, values, 1, "value must survive a rename") + }) + + t.Run("plural batch reports cleared ids per affected field", func(t *testing.T) { + // Field 1: select with a value, will be patched to text → cleanup expected. + optID := model.NewId() + f1 := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "batch-select-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": optID, "name": "Only Option"}, + }, + }, + } + created1, err := th.service.CreatePropertyField(th.Context, f1) + require.NoError(t, err) + + // Field 2: text, will be renamed only → no cleanup expected. + f2 := &model.PropertyField{ + GroupID: th.CPAGroupID, + Name: "batch-text-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + created2, err := th.service.CreatePropertyField(th.Context, f2) + require.NoError(t, err) + + raw, err := json.Marshal(optID) + require.NoError(t, err) + _, err = th.service.UpsertPropertyValue(th.Context, &model.PropertyValue{ + GroupID: th.CPAGroupID, + FieldID: created1.ID, + TargetID: model.NewId(), + TargetType: model.PropertyValueTargetTypeUser, + Value: raw, + }) + require.NoError(t, err) + + // Mutate both: f1 changes Type, f2 changes Name only. + created1.Type = model.PropertyFieldTypeText + created2.Name = "batch-text-renamed-" + model.NewId() + + _, _, clearedIDs, err := th.service.UpdatePropertyFields(th.Context, th.CPAGroupID, []*model.PropertyField{created1, created2}) + require.NoError(t, err) + assert.Equal(t, []string{created1.ID}, clearedIDs, "only the type-changed field should be in clearedIDs") + }) +} diff --git a/server/channels/app/property_errors.go b/server/channels/app/property_errors.go new file mode 100644 index 00000000000..3a1a8b3413f --- /dev/null +++ b/server/channels/app/property_errors.go @@ -0,0 +1,77 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "errors" + "net/http" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/app/properties" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +// mapPropertyServiceError translates known errors from the property service / +// PropertyHook chain — package sentinels (properties.Err*) and store-layer +// errors (*store.ErrNotFound, *store.ErrConflict, *store.ErrResultsMismatch) — +// into HTTP-shaped AppErrors. Returns nil if err is not recognised and does +// not wrap an AppError; callers should fall back to wrapping with their own +// default 500 in that case. +// +// Sentinel matches take priority over a wrapped AppError so that hook code +// wrapping an inner AppError with a sentinel still drives the mapping. +// +// User-facing DetailedError is left empty on access-control rejections to +// avoid leaking field IDs, plugin IDs, and sync source names. The full +// chain remains available for operator logs via Wrap(err). +func mapPropertyServiceError(where string, err error) *model.AppError { + if err == nil { + return nil + } + + switch { + case errors.Is(err, properties.ErrAccessDenied): + return model.NewAppError(where, "app.property.access_denied.app_error", nil, "", http.StatusForbidden).Wrap(err) + case errors.Is(err, properties.ErrSyncLocked): + return model.NewAppError(where, "app.property.sync_lock.app_error", nil, "", http.StatusForbidden).Wrap(err) + case errors.Is(err, properties.ErrInvalidAccessMode): + return model.NewAppError(where, "app.property.invalid_access_mode.app_error", nil, err.Error(), http.StatusBadRequest).Wrap(err) + case errors.Is(err, properties.ErrFieldLimitReached): + return model.NewAppError(where, "app.property_field.create.limit_reached.app_error", nil, err.Error(), http.StatusUnprocessableEntity).Wrap(err) + case errors.Is(err, properties.ErrGroupFieldLimitReached): + return model.NewAppError(where, "app.property_field.create.group_limit_reached.app_error", nil, err.Error(), http.StatusUnprocessableEntity).Wrap(err) + case errors.Is(err, properties.ErrLicenseRequired): + return model.NewAppError(where, "app.property.license_error", nil, "", http.StatusForbidden).Wrap(err) + case errors.Is(err, properties.ErrInvalidFieldAttrs): + return model.NewAppError(where, "app.property_field.invalid_attrs.app_error", nil, err.Error(), http.StatusBadRequest).Wrap(err) + case errors.Is(err, properties.ErrInvalidValue): + return model.NewAppError(where, "app.property_value.validate.app_error", nil, err.Error(), http.StatusBadRequest).Wrap(err) + case errors.Is(err, properties.ErrAdminRequired): + return model.NewAppError(where, "app.property_field.managed_admin.permission.app_error", nil, "", http.StatusForbidden).Wrap(err) + case errors.Is(err, properties.ErrFieldNotFound): + return model.NewAppError(where, "app.property_field.not_found.app_error", nil, "", http.StatusNotFound).Wrap(err) + } + + var conflictErr *store.ErrConflict + if errors.As(err, &conflictErr) { + return model.NewAppError(where, "app.property_field.update.conflict.app_error", nil, "concurrent modification detected; please retry", http.StatusConflict).Wrap(err) + } + + var notFoundErr *store.ErrNotFound + if errors.As(err, ¬FoundErr) { + return model.NewAppError(where, "app.property.not_found.app_error", nil, "", http.StatusNotFound).Wrap(err) + } + + var resultsMismatchErr *store.ErrResultsMismatch + if errors.As(err, &resultsMismatchErr) { + return model.NewAppError(where, "app.property.not_found.app_error", nil, "", http.StatusNotFound).Wrap(err) + } + + var appErr *model.AppError + if errors.As(err, &appErr) { + return appErr + } + + return nil +} diff --git a/server/channels/app/property_errors_test.go b/server/channels/app/property_errors_test.go new file mode 100644 index 00000000000..83bbbe0aac6 --- /dev/null +++ b/server/channels/app/property_errors_test.go @@ -0,0 +1,146 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/app/properties" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMapPropertyServiceError(t *testing.T) { + t.Run("nil err returns nil", func(t *testing.T) { + require.Nil(t, mapPropertyServiceError("Where", nil)) + }) + + t.Run("unknown err returns nil so caller can 500-wrap", func(t *testing.T) { + require.Nil(t, mapPropertyServiceError("Where", errors.New("db connection lost"))) + }) + + t.Run("unwrapped AppError is returned as-is via fallback", func(t *testing.T) { + orig := model.NewAppError("SomeSource", "some.id", nil, "detail", http.StatusTeapot) + got := mapPropertyServiceError("Where", orig) + require.NotNil(t, got) + assert.Same(t, orig, got) + }) + + sentinelCases := []struct { + name string + sentinel error + expectedID string + expectedStatus int + expectDetail bool + }{ + { + name: "access denied", + sentinel: properties.ErrAccessDenied, + expectedID: "app.property.access_denied.app_error", + expectedStatus: http.StatusForbidden, + expectDetail: false, + }, + { + name: "sync locked", + sentinel: properties.ErrSyncLocked, + expectedID: "app.property.sync_lock.app_error", + expectedStatus: http.StatusForbidden, + expectDetail: false, + }, + { + name: "invalid access mode", + sentinel: properties.ErrInvalidAccessMode, + expectedID: "app.property.invalid_access_mode.app_error", + expectedStatus: http.StatusBadRequest, + expectDetail: true, + }, + { + name: "field limit reached", + sentinel: properties.ErrFieldLimitReached, + expectedID: "app.property_field.create.limit_reached.app_error", + expectedStatus: http.StatusUnprocessableEntity, + expectDetail: true, + }, + { + name: "group field limit reached", + sentinel: properties.ErrGroupFieldLimitReached, + expectedID: "app.property_field.create.group_limit_reached.app_error", + expectedStatus: http.StatusUnprocessableEntity, + expectDetail: true, + }, + { + name: "license required", + sentinel: properties.ErrLicenseRequired, + expectedID: "app.property.license_error", + expectedStatus: http.StatusForbidden, + expectDetail: false, + }, + { + name: "invalid field attrs", + sentinel: properties.ErrInvalidFieldAttrs, + expectedID: "app.property_field.invalid_attrs.app_error", + expectedStatus: http.StatusBadRequest, + expectDetail: true, + }, + { + name: "invalid value", + sentinel: properties.ErrInvalidValue, + expectedID: "app.property_value.validate.app_error", + expectedStatus: http.StatusBadRequest, + expectDetail: true, + }, + { + name: "admin required", + sentinel: properties.ErrAdminRequired, + expectedID: "app.property_field.managed_admin.permission.app_error", + expectedStatus: http.StatusForbidden, + expectDetail: false, + }, + { + name: "field not found", + sentinel: properties.ErrFieldNotFound, + expectedID: "app.property_field.not_found.app_error", + expectedStatus: http.StatusNotFound, + expectDetail: false, + }, + } + + for _, tc := range sentinelCases { + t.Run("direct sentinel: "+tc.name, func(t *testing.T) { + got := mapPropertyServiceError("Where", tc.sentinel) + require.NotNil(t, got) + assert.Equal(t, tc.expectedID, got.Id) + assert.Equal(t, tc.expectedStatus, got.StatusCode) + assert.Equal(t, "Where", got.Where) + if tc.expectDetail { + assert.NotEmpty(t, got.DetailedError, "sentinel %s should carry operator-facing detail", tc.name) + } else { + assert.Empty(t, got.DetailedError, "sentinel %s should redact detail to avoid leaking internal identifiers", tc.name) + } + }) + + t.Run("wrapped sentinel detected through chain: "+tc.name, func(t *testing.T) { + wrapped := fmt.Errorf("outer context: %w", fmt.Errorf("inner context: %w", tc.sentinel)) + got := mapPropertyServiceError("Where", wrapped) + require.NotNil(t, got) + assert.Equal(t, tc.expectedID, got.Id) + assert.Equal(t, tc.expectedStatus, got.StatusCode) + }) + } + + t.Run("sentinel priority over wrapped AppError", func(t *testing.T) { + // A hook that wraps an AppError with a sentinel should be mapped by + // the sentinel, not by the embedded AppError. + inner := model.NewAppError("OldPath", "old.id", nil, "old detail", http.StatusTeapot) + wrapped := fmt.Errorf("authz denied: %w: %w", properties.ErrAccessDenied, inner) + got := mapPropertyServiceError("Where", wrapped) + require.NotNil(t, got) + assert.Equal(t, "app.property.access_denied.app_error", got.Id) + assert.Equal(t, http.StatusForbidden, got.StatusCode) + }) +} diff --git a/server/channels/app/property_field.go b/server/channels/app/property_field.go index 2634db9181c..2d749194857 100644 --- a/server/channels/app/property_field.go +++ b/server/channels/app/property_field.go @@ -5,15 +5,27 @@ package app import ( "encoding/json" - "errors" "net/http" + "reflect" + "strings" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" - "github.com/mattermost/mattermost/server/v8/channels/store" ) +// propertyFieldOptionsEqual reports whether two values from +// PropertyField.Attrs[options] are equivalent. Used to detect a no-op options +// patch on a linked field — see UpdatePropertyFields' linked-field invariants. +// Both nil/zero forms compare equal; otherwise reflect.DeepEqual handles the +// nested map/slice shape produced by JSON unmarshalling. +func propertyFieldOptionsEqual(a, b any) bool { + if a == nil && b == nil { + return true + } + return reflect.DeepEqual(a, b) +} + func propertyFieldBroadcastParams(rctx request.CTX, field *model.PropertyField) (teamID, channelID string, ok bool) { switch field.TargetType { case "team": @@ -57,6 +69,10 @@ func (a *App) CreatePropertyField(rctx request.CTX, field *model.PropertyField, return nil, model.NewAppError("CreatePropertyField", "app.property_field.invalid_input.app_error", nil, "property field is required", http.StatusBadRequest) } + // Intrinsic invariants (apply to every caller — HTTP, plugin, internal). + CanonicalizeSystemObjectField(field) + field.Name = strings.TrimSpace(field.Name) + if !bypassProtectedCheck && field.Protected { return nil, model.NewAppError( "CreatePropertyField", @@ -69,8 +85,7 @@ func (a *App) CreatePropertyField(rctx request.CTX, field *model.PropertyField, createdField, err := a.Srv().propertyService.CreatePropertyField(rctx, field) if err != nil { - var appErr *model.AppError - if errors.As(err, &appErr) { + if appErr := mapPropertyServiceError("CreatePropertyField", err); appErr != nil { return nil, appErr } return nil, model.NewAppError("CreatePropertyField", "app.property_field.create.app_error", nil, "", http.StatusInternalServerError).Wrap(err) @@ -85,6 +100,9 @@ func (a *App) CreatePropertyField(rctx request.CTX, field *model.PropertyField, func (a *App) GetPropertyField(rctx request.CTX, groupID, fieldID string) (*model.PropertyField, *model.AppError) { field, err := a.Srv().propertyService.GetPropertyField(rctx, groupID, fieldID) if err != nil { + if appErr := mapPropertyServiceError("GetPropertyField", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("GetPropertyField", "app.property_field.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return field, nil @@ -94,9 +112,8 @@ func (a *App) GetPropertyField(rctx request.CTX, groupID, fieldID string) (*mode func (a *App) GetPropertyFields(rctx request.CTX, groupID string, ids []string) ([]*model.PropertyField, *model.AppError) { fields, err := a.Srv().propertyService.GetPropertyFields(rctx, groupID, ids) if err != nil { - var resultsMismatchErr *store.ErrResultsMismatch - if errors.As(err, &resultsMismatchErr) { - return nil, model.NewAppError("GetPropertyFields", "app.property_field.get_many.fields_not_found.app_error", nil, "", http.StatusBadRequest).Wrap(err) + if appErr := mapPropertyServiceError("GetPropertyFields", err); appErr != nil { + return nil, appErr } return nil, model.NewAppError("GetPropertyFields", "app.property_field.get_many.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } @@ -107,6 +124,9 @@ func (a *App) GetPropertyFields(rctx request.CTX, groupID string, ids []string) func (a *App) GetPropertyFieldByName(rctx request.CTX, groupID, targetID, name string) (*model.PropertyField, *model.AppError) { field, err := a.Srv().propertyService.GetPropertyFieldByName(rctx, groupID, targetID, name) if err != nil { + if appErr := mapPropertyServiceError("GetPropertyFieldByName", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("GetPropertyFieldByName", "app.property_field.get_by_name.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return field, nil @@ -116,6 +136,9 @@ func (a *App) GetPropertyFieldByName(rctx request.CTX, groupID, targetID, name s func (a *App) SearchPropertyFields(rctx request.CTX, groupID string, opts model.PropertyFieldSearchOpts) ([]*model.PropertyField, *model.AppError) { fields, err := a.Srv().propertyService.SearchPropertyFields(rctx, groupID, opts) if err != nil { + if appErr := mapPropertyServiceError("SearchPropertyFields", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("SearchPropertyFields", "app.property_field.search.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return fields, nil @@ -132,6 +155,9 @@ func (a *App) CountPropertyFieldsForGroup(rctx request.CTX, groupID string, incl } if err != nil { + if appErr := mapPropertyServiceError("CountPropertyFieldsForGroup", err); appErr != nil { + return 0, appErr + } return 0, model.NewAppError("CountPropertyFieldsForGroup", "app.property_field.count_for_group.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return count, nil @@ -148,64 +174,140 @@ func (a *App) CountPropertyFieldsForTarget(rctx request.CTX, groupID, targetType } if err != nil { + if appErr := mapPropertyServiceError("CountPropertyFieldsForTarget", err); appErr != nil { + return 0, appErr + } return 0, model.NewAppError("CountPropertyFieldsForTarget", "app.property_field.count_for_target.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return count, nil } -// UpdatePropertyField updates an existing property field. -func (a *App) UpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField, bypassProtectedCheck bool, connectionID string) (*model.PropertyField, *model.AppError) { - fields, err := a.UpdatePropertyFields(rctx, groupID, []*model.PropertyField{field}, bypassProtectedCheck, connectionID) +// UpdatePropertyField updates an existing property field. The second return +// value lists the IDs of fields whose dependent property values were cleared +// as a side effect (e.g. by TypeChangeValueCleanupHook on a type change). +// Hooks may cascade clears to other fields, so the slice is not necessarily +// limited to the updated field's own ID. +func (a *App) UpdatePropertyField(rctx request.CTX, groupID string, field *model.PropertyField, bypassProtectedCheck bool, connectionID string) (*model.PropertyField, []string, *model.AppError) { + fields, clearedIDs, err := a.UpdatePropertyFields(rctx, groupID, []*model.PropertyField{field}, bypassProtectedCheck, connectionID) if err != nil { - return nil, err + return nil, nil, err } - return fields[0], nil + return fields[0], clearedIDs, nil } -// UpdatePropertyFields updates multiple property fields. -func (a *App) UpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField, bypassProtectedCheck bool, connectionID string) ([]*model.PropertyField, *model.AppError) { +// UpdatePropertyFields updates multiple property fields. The second return +// value lists the IDs of fields whose dependent property values were cleared +// as a side effect. +func (a *App) UpdatePropertyFields(rctx request.CTX, groupID string, fields []*model.PropertyField, bypassProtectedCheck bool, connectionID string) ([]*model.PropertyField, []string, *model.AppError) { if len(fields) == 0 { - return nil, model.NewAppError("UpdatePropertyFields", "app.property_field.invalid_input.app_error", nil, "property fields are required", http.StatusBadRequest) + return nil, nil, model.NewAppError("UpdatePropertyFields", "app.property_field.invalid_input.app_error", nil, "property fields are required", http.StatusBadRequest) } - if !bypassProtectedCheck { - ids := make([]string, len(fields)) - for i, f := range fields { - ids[i] = f.ID + // Intrinsic invariants — apply to every caller (HTTP, plugin, internal). + // Service returns DB-order, not input-order, so we'll build a lookup map + // keyed by ID below; collect IDs in this same pass. + ids := make([]string, len(fields)) + for i, f := range fields { + f.Name = strings.TrimSpace(f.Name) + ids[i] = f.ID + } + + // Load existing fields once. Used for: protected-check (gated by + // bypassProtectedCheck), PSAv1 reject (always-on), linked-field diff + // invariants (always-on). + + existingFields, err := a.Srv().propertyService.GetPropertyFields(rctx, groupID, ids) + if err != nil { + if appErr := mapPropertyServiceError("UpdatePropertyFields", err); appErr != nil { + return nil, nil, appErr + } + return nil, nil, model.NewAppError("UpdatePropertyFields", "app.property_field.update.get_existing.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + existingByID := make(map[string]*model.PropertyField, len(existingFields)) + for _, ex := range existingFields { + existingByID[ex.ID] = ex + } + + for _, f := range fields { + existing, ok := existingByID[f.ID] + if !ok { + // Service-level GetPropertyFields returns an ErrResultsMismatch when + // any input ID is missing, so this branch is defensive. + continue } - existingFields, err := a.Srv().propertyService.GetPropertyFields(rctx, groupID, ids) - if err != nil { - return nil, model.NewAppError("UpdatePropertyFields", "app.property_field.update.get_existing.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - } + // Linked-field diff invariants. "Linked" = LinkedFieldID != nil && + // *LinkedFieldID != "". Unlink (nil or "") is always allowed when + // existing was linked. + existingLinked := existing.LinkedFieldID != nil && *existing.LinkedFieldID != "" + incomingLinked := f.LinkedFieldID != nil && *f.LinkedFieldID != "" - for _, existing := range existingFields { - if existing.Protected { - return nil, model.NewAppError( + if existingLinked { + if f.Type != existing.Type { + return nil, nil, model.NewAppError( "UpdatePropertyFields", - "app.property_field.update.protected.app_error", + "app.property_field.update.linked_type_change.app_error", map[string]any{"FieldID": existing.ID}, - "cannot update protected field", - http.StatusForbidden, + "cannot modify type of a linked field", + http.StatusBadRequest, ) } + // Compare the options portion of Attrs. + var existingOpts, incomingOpts any + if existing.Attrs != nil { + existingOpts = existing.Attrs[model.PropertyFieldAttributeOptions] + } + if f.Attrs != nil { + incomingOpts = f.Attrs[model.PropertyFieldAttributeOptions] + } + if !propertyFieldOptionsEqual(existingOpts, incomingOpts) { + return nil, nil, model.NewAppError( + "UpdatePropertyFields", + "app.property_field.update.linked_options_change.app_error", + map[string]any{"FieldID": existing.ID}, + "cannot modify options of a linked field", + http.StatusBadRequest, + ) + } + if incomingLinked && *f.LinkedFieldID != *existing.LinkedFieldID { + return nil, nil, model.NewAppError( + "UpdatePropertyFields", + "app.property_field.update.cannot_change_link_target.app_error", + map[string]any{"FieldID": existing.ID}, + "cannot change link target", + http.StatusBadRequest, + ) + } + } else if incomingLinked { + return nil, nil, model.NewAppError( + "UpdatePropertyFields", + "app.property_field.update.cannot_link_existing.app_error", + map[string]any{"FieldID": existing.ID}, + "linked_field_id can only be set at creation time", + http.StatusBadRequest, + ) + } + + // Protected-check is the only invariant gated on the caller's opt-out. + if !bypassProtectedCheck && existing.Protected { + return nil, nil, model.NewAppError( + "UpdatePropertyFields", + "app.property_field.update.protected.app_error", + map[string]any{"FieldID": existing.ID}, + "cannot update protected field", + http.StatusForbidden, + ) } } - updated, propagated, err := a.Srv().propertyService.UpdatePropertyFields(rctx, groupID, fields) + updated, propagated, clearedFieldIDs, err := a.Srv().propertyService.UpdatePropertyFields(rctx, groupID, fields) if err != nil { - var appErr *model.AppError - if errors.As(err, &appErr) { - return nil, appErr + if appErr := mapPropertyServiceError("UpdatePropertyFields", err); appErr != nil { + return nil, nil, appErr } - - var conflictErr *store.ErrConflict - if errors.As(err, &conflictErr) { - return nil, model.NewAppError("UpdatePropertyFields", "app.property_field.update.conflict.app_error", nil, "concurrent modification detected; please retry", http.StatusConflict).Wrap(err) - } - - return nil, model.NewAppError("UpdatePropertyFields", "app.property_field.update.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + return nil, nil, model.NewAppError("UpdatePropertyFields", "app.property_field.update.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } // Broadcast websocket events for both requested and propagated fields @@ -216,13 +318,27 @@ func (a *App) UpdatePropertyFields(rctx request.CTX, groupID string, fields []*m a.publishPropertyFieldEvent(rctx, model.WebsocketEventPropertyFieldUpdated, field, "") } - return updated, nil + // For each field whose dependent values were cleared as a side effect of + // the update (e.g. a type change handled by TypeChangeValueCleanupHook), + // publish the generic property_values_updated event so subscribers refresh + // their local caches. Mirrors App.DeletePropertyValuesForField's wire shape. + for _, fieldID := range clearedFieldIDs { + message := model.NewWebSocketEvent(model.WebsocketEventPropertyValuesUpdated, "", "", "", nil, "") + message.Add("field_id", fieldID) + message.Add("values", "[]") + a.Publish(message) + } + + return updated, clearedFieldIDs, nil } // DeletePropertyField deletes a property field. func (a *App) DeletePropertyField(rctx request.CTX, groupID, id string, bypassProtectedCheck bool, connectionID string) *model.AppError { existing, err := a.Srv().propertyService.GetPropertyField(rctx, groupID, id) if err != nil { + if appErr := mapPropertyServiceError("DeletePropertyField", err); appErr != nil { + return appErr + } return model.NewAppError("DeletePropertyField", "app.property_field.delete.get_existing.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } if existing == nil { @@ -240,8 +356,7 @@ func (a *App) DeletePropertyField(rctx request.CTX, groupID, id string, bypassPr } if err := a.Srv().propertyService.DeletePropertyField(rctx, groupID, id); err != nil { - var appErr *model.AppError - if errors.As(err, &appErr) { + if appErr := mapPropertyServiceError("DeletePropertyField", err); appErr != nil { return appErr } return model.NewAppError("DeletePropertyField", "app.property_field.delete.app_error", nil, "", http.StatusInternalServerError).Wrap(err) diff --git a/server/channels/app/property_field_helpers.go b/server/channels/app/property_field_helpers.go new file mode 100644 index 00000000000..235b954661a --- /dev/null +++ b/server/channels/app/property_field_helpers.go @@ -0,0 +1,43 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "github.com/mattermost/mattermost/server/public/model" +) + +// DefaultPropertyFieldPermissionLevel returns the permission level that +// nil-fill / non-admin-pin should use for this field. Templates and system +// fields default to sysadmin (templates define the schema linked fields +// inherit; system fields attach to the Mattermost instance and only an +// administrator should write them). Other object types default to member. +func DefaultPropertyFieldPermissionLevel(field *model.PropertyField) model.PermissionLevel { + if field.ObjectType == model.PropertyFieldObjectTypeTemplate || + field.ObjectType == model.PropertyFieldObjectTypeSystem { + return model.PermissionLevelSysadmin + } + return model.PermissionLevelMember +} + +// CanonicalizeSystemObjectField forces a system-object field to its only +// valid shape: TargetType="system", TargetID="", and all three Permission* +// pinned to sysadmin. A system field's TargetType makes member-level scope +// checks resolve to "any authenticated user" (see hasPropertyFieldScopeAccess +// in app/authorization.go), so honouring a member-level permission would +// expose the field's definition, options, and values to every logged-in user. +// +// Idempotent. Safe to call from both the API handler (before scope check) +// and from inside App.CreatePropertyField (defense in depth, covers +// plugin/internal callers). +func CanonicalizeSystemObjectField(field *model.PropertyField) { + if field == nil || field.ObjectType != model.PropertyFieldObjectTypeSystem { + return + } + field.TargetType = string(model.PropertyFieldTargetLevelSystem) + field.TargetID = "" + sysadmin := model.PermissionLevelSysadmin + field.PermissionField = &sysadmin + field.PermissionValues = &sysadmin + field.PermissionOptions = &sysadmin +} diff --git a/server/channels/app/property_field_helpers_test.go b/server/channels/app/property_field_helpers_test.go new file mode 100644 index 00000000000..0f07e0839c5 --- /dev/null +++ b/server/channels/app/property_field_helpers_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" +) + +func TestDefaultPropertyFieldPermissionLevel(t *testing.T) { + t.Parallel() + + t.Run("template defaults to sysadmin", func(t *testing.T) { + f := &model.PropertyField{ObjectType: model.PropertyFieldObjectTypeTemplate} + assert.Equal(t, model.PermissionLevelSysadmin, DefaultPropertyFieldPermissionLevel(f)) + }) + + t.Run("system defaults to sysadmin", func(t *testing.T) { + f := &model.PropertyField{ObjectType: model.PropertyFieldObjectTypeSystem} + assert.Equal(t, model.PermissionLevelSysadmin, DefaultPropertyFieldPermissionLevel(f)) + }) + + t.Run("user defaults to member", func(t *testing.T) { + f := &model.PropertyField{ObjectType: model.PropertyFieldObjectTypeUser} + assert.Equal(t, model.PermissionLevelMember, DefaultPropertyFieldPermissionLevel(f)) + }) + + t.Run("channel defaults to member", func(t *testing.T) { + f := &model.PropertyField{ObjectType: model.PropertyFieldObjectTypeChannel} + assert.Equal(t, model.PermissionLevelMember, DefaultPropertyFieldPermissionLevel(f)) + }) + + t.Run("post defaults to member", func(t *testing.T) { + f := &model.PropertyField{ObjectType: model.PropertyFieldObjectTypePost} + assert.Equal(t, model.PermissionLevelMember, DefaultPropertyFieldPermissionLevel(f)) + }) +} + +func TestCanonicalizeSystemObjectField(t *testing.T) { + t.Parallel() + + t.Run("system object: forces TargetType=system, empty TargetID, all permissions sysadmin", func(t *testing.T) { + member := model.PermissionLevelMember + f := &model.PropertyField{ + ObjectType: model.PropertyFieldObjectTypeSystem, + TargetType: "channel", + TargetID: "ch1", + PermissionField: &member, + PermissionValues: &member, + PermissionOptions: &member, + } + CanonicalizeSystemObjectField(f) + assert.Equal(t, string(model.PropertyFieldTargetLevelSystem), f.TargetType) + assert.Empty(t, f.TargetID) + assert.NotNil(t, f.PermissionField) + assert.Equal(t, model.PermissionLevelSysadmin, *f.PermissionField) + assert.NotNil(t, f.PermissionValues) + assert.Equal(t, model.PermissionLevelSysadmin, *f.PermissionValues) + assert.NotNil(t, f.PermissionOptions) + assert.Equal(t, model.PermissionLevelSysadmin, *f.PermissionOptions) + }) + + t.Run("non-system object: untouched", func(t *testing.T) { + member := model.PermissionLevelMember + f := &model.PropertyField{ + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: "channel", + TargetID: "ch1", + PermissionField: &member, + PermissionValues: &member, + PermissionOptions: &member, + } + CanonicalizeSystemObjectField(f) + assert.Equal(t, "channel", f.TargetType) + assert.Equal(t, "ch1", f.TargetID) + assert.Equal(t, model.PermissionLevelMember, *f.PermissionField) + assert.Equal(t, model.PermissionLevelMember, *f.PermissionValues) + assert.Equal(t, model.PermissionLevelMember, *f.PermissionOptions) + }) + + t.Run("idempotent", func(t *testing.T) { + f := &model.PropertyField{ + ObjectType: model.PropertyFieldObjectTypeSystem, + TargetType: "channel", + TargetID: "ch1", + } + CanonicalizeSystemObjectField(f) + first := *f + CanonicalizeSystemObjectField(f) + assert.Equal(t, first.TargetType, f.TargetType) + assert.Equal(t, first.TargetID, f.TargetID) + }) + + t.Run("nil field: no panic", func(t *testing.T) { + assert.NotPanics(t, func() { + CanonicalizeSystemObjectField(nil) + }) + }) +} diff --git a/server/channels/app/property_field_test.go b/server/channels/app/property_field_test.go index 1ae94ddea97..c4e3fe23391 100644 --- a/server/channels/app/property_field_test.go +++ b/server/channels/app/property_field_test.go @@ -13,13 +13,23 @@ import ( "github.com/stretchr/testify/require" ) +// registerTestPropertyGroup creates a fresh, unmanaged PSAv2 property group +// for tests that exercise generic PropertyField CRUD. +func registerTestPropertyGroup(tb testing.TB, th *TestHelper) string { + tb.Helper() + group, appErr := th.App.RegisterPropertyGroup(th.Context, &model.PropertyGroup{ + Name: "test_" + model.NewId(), + Version: model.PropertyGroupVersionV2, + }) + require.Nil(tb, appErr) + return group.ID +} + func TestCreatePropertyField(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - group, appErr := th.App.RegisterPropertyGroup(th.Context, &model.PropertyGroup{Name: "test_create_field_v2_group", Version: model.PropertyGroupVersionV2}) - require.Nil(t, appErr) - groupID := group.ID + groupID := registerTestPropertyGroup(t, th) t.Run("should create a non-protected field without bypass", func(t *testing.T) { field := &model.PropertyField{ @@ -118,9 +128,7 @@ func TestUpdatePropertyField(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - group, appErr2 := th.App.RegisterPropertyGroup(th.Context, &model.PropertyGroup{Name: "test_update_field_v2_group", Version: model.PropertyGroupVersionV2}) - require.Nil(t, appErr2) - groupID := group.ID + groupID := registerTestPropertyGroup(t, th) t.Run("should update a non-protected field without bypass", func(t *testing.T) { field := &model.PropertyField{ @@ -134,7 +142,7 @@ func TestUpdatePropertyField(t *testing.T) { require.Nil(t, appErr) created.Name = "Updated Field Name" - updated, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") require.Nil(t, appErr) assert.Equal(t, "Updated Field Name", updated.Name) }) @@ -155,7 +163,7 @@ func TestUpdatePropertyField(t *testing.T) { require.Nil(t, appErr) created.Name = "Attempted Update" - updated, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") require.NotNil(t, appErr) assert.Nil(t, updated) assert.Equal(t, "app.property_field.update.protected.app_error", appErr.Id) @@ -178,7 +186,7 @@ func TestUpdatePropertyField(t *testing.T) { require.Nil(t, appErr) created.Name = "Successfully Updated Protected" - updated, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, true, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, true, "") require.Nil(t, appErr) assert.Equal(t, "Successfully Updated Protected", updated.Name) }) @@ -196,7 +204,7 @@ func TestUpdatePropertyField(t *testing.T) { // Try to update with empty name (invalid) created.Name = "" - updated, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") require.NotNil(t, appErr) assert.Nil(t, updated) }) @@ -206,9 +214,7 @@ func TestUpdatePropertyFields(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - group, appErr2 := th.App.RegisterPropertyGroup(th.Context, &model.PropertyGroup{Name: "test_update_fields_v2_group", Version: model.PropertyGroupVersionV2}) - require.Nil(t, appErr2) - groupID := group.ID + groupID := registerTestPropertyGroup(t, th) t.Run("should update multiple non-protected fields without bypass", func(t *testing.T) { field1 := &model.PropertyField{ @@ -234,7 +240,7 @@ func TestUpdatePropertyFields(t *testing.T) { created1.Name = "Updated Batch 1" created2.Name = "Updated Batch 2" - updated, appErr := th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{created1, created2}, false, "") + updated, _, appErr := th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{created1, created2}, false, "") require.Nil(t, appErr) require.Len(t, updated, 2) }) @@ -267,7 +273,7 @@ func TestUpdatePropertyFields(t *testing.T) { createdNonProtected.Name = "Updated Non-Protected" createdProtected.Name = "Updated Protected" - updated, appErr := th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{createdNonProtected, createdProtected}, false, "") + updated, _, appErr := th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{createdNonProtected, createdProtected}, false, "") require.NotNil(t, appErr) assert.Nil(t, updated) assert.Equal(t, "app.property_field.update.protected.app_error", appErr.Id) @@ -311,7 +317,7 @@ func TestUpdatePropertyFields(t *testing.T) { createdNonProtected.Name = "Bypass Updated Non-Protected" createdProtected.Name = "Bypass Updated Protected" - updated, appErr := th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{createdNonProtected, createdProtected}, true, "") + updated, _, appErr := th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{createdNonProtected, createdProtected}, true, "") require.Nil(t, appErr) require.Len(t, updated, 2) }) @@ -344,7 +350,7 @@ func TestUpdatePropertyFields(t *testing.T) { createdMain.Name = "Updated Main" createdOther.Name = "Updated Other" - _, appErr = th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{createdMain, createdOther}, false, "") + _, _, appErr = th.App.UpdatePropertyFields(th.Context, groupID, []*model.PropertyField{createdMain, createdOther}, false, "") require.NotNil(t, appErr) // Verify neither field was updated @@ -454,7 +460,7 @@ func TestUpdatePropertyFieldVersionEnforcement(t *testing.T) { // Attempt to update it as a v2 field (add ObjectType to make it v2) created.ObjectType = model.PropertyFieldObjectTypeUser created.TargetType = string(model.PropertyFieldTargetLevelSystem) - updated, appErr := th.App.UpdatePropertyField(th.Context, v1Group.ID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, v1Group.ID, created, false, "") require.NotNil(t, appErr) assert.Nil(t, updated) assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) @@ -478,7 +484,7 @@ func TestUpdatePropertyFieldVersionEnforcement(t *testing.T) { // Attempt to update it as a v1 field (remove ObjectType to make it v1) created.ObjectType = "" created.TargetType = "user" - updated, appErr := th.App.UpdatePropertyField(th.Context, v2Group.ID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, v2Group.ID, created, false, "") require.NotNil(t, appErr) assert.Nil(t, updated) assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) @@ -498,7 +504,7 @@ func TestUpdatePropertyFieldVersionEnforcement(t *testing.T) { require.Nil(t, appErr) created.Name = "V1 Field Updated" - updated, appErr := th.App.UpdatePropertyField(th.Context, v1Group.ID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, v1Group.ID, created, false, "") require.Nil(t, appErr) assert.Equal(t, "V1 Field Updated", updated.Name) }) @@ -518,7 +524,7 @@ func TestUpdatePropertyFieldVersionEnforcement(t *testing.T) { require.Nil(t, appErr) created.Name = "V2 Field Updated" - updated, appErr := th.App.UpdatePropertyField(th.Context, v2Group.ID, created, false, "") + updated, _, appErr := th.App.UpdatePropertyField(th.Context, v2Group.ID, created, false, "") require.Nil(t, appErr) assert.Equal(t, "V2 Field Updated", updated.Name) }) @@ -528,9 +534,7 @@ func TestDeletePropertyField(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - group, appErr2 := th.App.RegisterPropertyGroup(th.Context, &model.PropertyGroup{Name: "test_delete_field_v2_group", Version: model.PropertyGroupVersionV2}) - require.Nil(t, appErr2) - groupID := group.ID + groupID := registerTestPropertyGroup(t, th) t.Run("should delete a non-protected field without bypass", func(t *testing.T) { field := &model.PropertyField{ @@ -668,14 +672,15 @@ func TestGetPropertyField(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + groupID := registerTestPropertyGroup(t, th) t.Run("should get an existing field", func(t *testing.T) { field := &model.PropertyField{ - GroupID: groupID, - Name: "Field to Get", - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: "Field to Get", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created, appErr := th.App.CreatePropertyField(th.Context, field, false, "") require.Nil(t, appErr) @@ -696,19 +701,22 @@ func TestGetPropertyFields(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + groupID := registerTestPropertyGroup(t, th) t.Run("should get multiple fields", func(t *testing.T) { field1 := &model.PropertyField{ - GroupID: groupID, - Name: "Multi Get Field 1", - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: "Multi Get Field 1", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } field2 := &model.PropertyField{ - GroupID: groupID, - Name: "Multi Get Field 2", - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: "Multi Get Field 2", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } created1, appErr := th.App.CreatePropertyField(th.Context, field1, false, "") @@ -726,14 +734,15 @@ func TestSearchPropertyFields(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) - groupID, err := th.App.CpaGroupID() - require.Nil(t, err) + groupID := registerTestPropertyGroup(t, th) t.Run("should search for fields", func(t *testing.T) { field := &model.PropertyField{ - GroupID: groupID, - Name: "Searchable Field", - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: "Searchable Field", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), } _, appErr := th.App.CreatePropertyField(th.Context, field, false, "") require.Nil(t, appErr) @@ -747,3 +756,281 @@ func TestSearchPropertyFields(t *testing.T) { assert.NotEmpty(t, results) }) } + +func TestCreatePropertyField_SystemCanonicalization(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + t.Run("system object: TargetType+TargetID and Permission* are canonicalized", func(t *testing.T) { + member := model.PermissionLevelMember + field := &model.PropertyField{ + GroupID: groupID, + Name: "System Canonicalize", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeSystem, + TargetType: "channel", + TargetID: model.NewId(), + PermissionField: &member, + PermissionValues: &member, + PermissionOptions: &member, + } + + created, appErr := th.App.CreatePropertyField(th.Context, field, false, "") + require.Nil(t, appErr) + assert.Equal(t, string(model.PropertyFieldTargetLevelSystem), created.TargetType) + assert.Empty(t, created.TargetID) + require.NotNil(t, created.PermissionField) + assert.Equal(t, model.PermissionLevelSysadmin, *created.PermissionField) + require.NotNil(t, created.PermissionValues) + assert.Equal(t, model.PermissionLevelSysadmin, *created.PermissionValues) + require.NotNil(t, created.PermissionOptions) + assert.Equal(t, model.PermissionLevelSysadmin, *created.PermissionOptions) + }) +} + +func TestCreatePropertyField_TrimName(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + t.Run("trims whitespace around name", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: groupID, + Name: " trim-me ", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + + created, appErr := th.App.CreatePropertyField(th.Context, field, false, "") + require.Nil(t, appErr) + assert.Equal(t, "trim-me", created.Name) + }) +} + +func TestUpdatePropertyField_TrimNameOnUpdate(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + t.Run("trims whitespace on update", func(t *testing.T) { + field := &model.PropertyField{ + GroupID: groupID, + Name: "Trim Update", + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + created, appErr := th.App.CreatePropertyField(th.Context, field, false, "") + require.Nil(t, appErr) + + created.Name = " trimmed-on-update " + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, created, false, "") + require.Nil(t, appErr) + assert.Equal(t, "trimmed-on-update", updated.Name) + }) +} + +func TestUpdatePropertyField_LinkedFieldInvariants(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + makeLinkedPair := func(t *testing.T) (template, linked *model.PropertyField) { + t.Helper() + tmpl := &model.PropertyField{ + GroupID: groupID, + Name: "tmpl-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeTemplate, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": model.NewId(), "name": "opt1"}, + }, + }, + } + createdTmpl, appErr := th.App.CreatePropertyField(th.Context, tmpl, false, "") + require.Nil(t, appErr) + + linkedID := createdTmpl.ID + linkedField := &model.PropertyField{ + GroupID: groupID, + Name: "linked-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + LinkedFieldID: &linkedID, + } + createdLinked, appErr := th.App.CreatePropertyField(th.Context, linkedField, false, "") + require.Nil(t, appErr) + return createdTmpl, createdLinked + } + + t.Run("type immutable on linked field", func(t *testing.T) { + _, linked := makeLinkedPair(t) + linked.Type = model.PropertyFieldTypeText + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, linked, false, "") + require.NotNil(t, appErr) + assert.Nil(t, updated) + assert.Equal(t, "app.property_field.update.linked_type_change.app_error", appErr.Id) + assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) + }) + + t.Run("options immutable on linked field", func(t *testing.T) { + _, linked := makeLinkedPair(t) + linked.Attrs = model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": model.NewId(), "name": "different"}, + }, + } + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, linked, false, "") + require.NotNil(t, appErr) + assert.Nil(t, updated) + assert.Equal(t, "app.property_field.update.linked_options_change.app_error", appErr.Id) + assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) + }) + + t.Run("link target immutable: cannot change to different target", func(t *testing.T) { + altTmpl, linked := makeLinkedPair(t) + // Create another template to point to + _ = altTmpl + newTmpl := &model.PropertyField{ + GroupID: groupID, + Name: "tmpl-alt-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeTemplate, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": model.NewId(), "name": "x"}, + }, + }, + } + createdNew, appErr := th.App.CreatePropertyField(th.Context, newTmpl, false, "") + require.Nil(t, appErr) + + newID := createdNew.ID + linked.LinkedFieldID = &newID + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, linked, false, "") + require.NotNil(t, appErr) + assert.Nil(t, updated) + assert.Equal(t, "app.property_field.update.cannot_change_link_target.app_error", appErr.Id) + assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) + }) + + t.Run("cannot link a previously-unlinked field", func(t *testing.T) { + unlinked := &model.PropertyField{ + GroupID: groupID, + Name: "unlinked-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + createdUnlinked, appErr := th.App.CreatePropertyField(th.Context, unlinked, false, "") + require.Nil(t, appErr) + + // Create a template to link to + tmpl := &model.PropertyField{ + GroupID: groupID, + Name: "tmpl-late-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeTemplate, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + createdTmpl, appErr := th.App.CreatePropertyField(th.Context, tmpl, false, "") + require.Nil(t, appErr) + tID := createdTmpl.ID + + createdUnlinked.LinkedFieldID = &tID + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, createdUnlinked, false, "") + require.NotNil(t, appErr) + assert.Nil(t, updated) + assert.Equal(t, "app.property_field.update.cannot_link_existing.app_error", appErr.Id) + assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) + }) +} + +func TestUpdatePropertyField_LinkedFieldNoOpPatchOK(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + t.Run("setting Type to current value on a linked field passes", func(t *testing.T) { + // Build template + linked + tmpl := &model.PropertyField{ + GroupID: groupID, + Name: "tmpl-noop-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeTemplate, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{ + model.PropertyFieldAttributeOptions: []map[string]any{ + {"id": model.NewId(), "name": "n"}, + }, + }, + } + createdTmpl, appErr := th.App.CreatePropertyField(th.Context, tmpl, false, "") + require.Nil(t, appErr) + linkedID := createdTmpl.ID + + linked := &model.PropertyField{ + GroupID: groupID, + Name: "linked-noop-" + model.NewId(), + Type: model.PropertyFieldTypeSelect, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + LinkedFieldID: &linkedID, + } + createdLinked, appErr := th.App.CreatePropertyField(th.Context, linked, false, "") + require.Nil(t, appErr) + + // No-op update: Type unchanged. + createdLinked.Name = "linked-renamed" + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, createdLinked, false, "") + require.Nil(t, appErr) + assert.Equal(t, "linked-renamed", updated.Name) + }) +} + +func TestUpdatePropertyField_LinkedFieldUnlinkAllowed(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + t.Run("plugin path: setting LinkedFieldID = nil on a linked field unlinks it", func(t *testing.T) { + tmpl := &model.PropertyField{ + GroupID: groupID, + Name: "tmpl-unlink-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeTemplate, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + createdTmpl, appErr := th.App.CreatePropertyField(th.Context, tmpl, false, "") + require.Nil(t, appErr) + linkedID := createdTmpl.ID + + linked := &model.PropertyField{ + GroupID: groupID, + Name: "linked-unlink-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + LinkedFieldID: &linkedID, + } + createdLinked, appErr := th.App.CreatePropertyField(th.Context, linked, false, "") + require.Nil(t, appErr) + + createdLinked.LinkedFieldID = nil + updated, _, appErr := th.App.UpdatePropertyField(th.Context, groupID, createdLinked, false, "") + require.Nil(t, appErr) + assert.Nil(t, updated.LinkedFieldID) + }) +} diff --git a/server/channels/app/property_value.go b/server/channels/app/property_value.go index 56238be938b..8892deb0058 100644 --- a/server/channels/app/property_value.go +++ b/server/channels/app/property_value.go @@ -42,9 +42,13 @@ func (a *App) CreatePropertyValue(rctx request.CTX, value *model.PropertyValue) if value == nil { return nil, model.NewAppError("CreatePropertyValue", "app.property_value.invalid_input.app_error", nil, "property value is required", http.StatusBadRequest) } + value.Value = model.SanitizePropertyValue(value.Value) createdValue, err := a.Srv().propertyService.CreatePropertyValue(rctx, value) if err != nil { + if appErr := mapPropertyServiceError("CreatePropertyValue", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("CreatePropertyValue", "app.property_value.create.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return createdValue, nil @@ -55,9 +59,15 @@ func (a *App) CreatePropertyValues(rctx request.CTX, values []*model.PropertyVal if len(values) == 0 { return nil, model.NewAppError("CreatePropertyValues", "app.property_value.invalid_input.app_error", nil, "property values are required", http.StatusBadRequest) } + for _, v := range values { + v.Value = model.SanitizePropertyValue(v.Value) + } createdValues, err := a.Srv().propertyService.CreatePropertyValues(rctx, values) if err != nil { + if appErr := mapPropertyServiceError("CreatePropertyValues", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("CreatePropertyValues", "app.property_value.create_many.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return createdValues, nil @@ -67,6 +77,9 @@ func (a *App) CreatePropertyValues(rctx request.CTX, values []*model.PropertyVal func (a *App) GetPropertyValue(rctx request.CTX, groupID, valueID string) (*model.PropertyValue, *model.AppError) { value, err := a.Srv().propertyService.GetPropertyValue(rctx, groupID, valueID) if err != nil { + if appErr := mapPropertyServiceError("GetPropertyValue", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("GetPropertyValue", "app.property_value.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return value, nil @@ -76,6 +89,9 @@ func (a *App) GetPropertyValue(rctx request.CTX, groupID, valueID string) (*mode func (a *App) GetPropertyValues(rctx request.CTX, groupID string, ids []string) ([]*model.PropertyValue, *model.AppError) { values, err := a.Srv().propertyService.GetPropertyValues(rctx, groupID, ids) if err != nil { + if appErr := mapPropertyServiceError("GetPropertyValues", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("GetPropertyValues", "app.property_value.get_many.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return values, nil @@ -85,6 +101,9 @@ func (a *App) GetPropertyValues(rctx request.CTX, groupID string, ids []string) func (a *App) SearchPropertyValues(rctx request.CTX, groupID string, opts model.PropertyValueSearchOpts) ([]*model.PropertyValue, *model.AppError) { values, err := a.Srv().propertyService.SearchPropertyValues(rctx, groupID, opts) if err != nil { + if appErr := mapPropertyServiceError("SearchPropertyValues", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("SearchPropertyValues", "app.property_value.search.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return values, nil @@ -95,9 +114,13 @@ func (a *App) UpdatePropertyValue(rctx request.CTX, groupID string, value *model if value == nil { return nil, model.NewAppError("UpdatePropertyValue", "app.property_value.invalid_input.app_error", nil, "property value is required", http.StatusBadRequest) } + value.Value = model.SanitizePropertyValue(value.Value) updatedValue, err := a.Srv().propertyService.UpdatePropertyValue(rctx, groupID, value) if err != nil { + if appErr := mapPropertyServiceError("UpdatePropertyValue", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("UpdatePropertyValue", "app.property_value.update.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return updatedValue, nil @@ -108,9 +131,15 @@ func (a *App) UpdatePropertyValues(rctx request.CTX, groupID string, values []*m if len(values) == 0 { return nil, model.NewAppError("UpdatePropertyValues", "app.property_value.invalid_input.app_error", nil, "property values are required", http.StatusBadRequest) } + for _, v := range values { + v.Value = model.SanitizePropertyValue(v.Value) + } updatedValues, err := a.Srv().propertyService.UpdatePropertyValues(rctx, groupID, values) if err != nil { + if appErr := mapPropertyServiceError("UpdatePropertyValues", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("UpdatePropertyValues", "app.property_value.update_many.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return updatedValues, nil @@ -121,9 +150,13 @@ func (a *App) UpsertPropertyValue(rctx request.CTX, value *model.PropertyValue) if value == nil { return nil, model.NewAppError("UpsertPropertyValue", "app.property_value.invalid_input.app_error", nil, "property value is required", http.StatusBadRequest) } + value.Value = model.SanitizePropertyValue(value.Value) upsertedValue, err := a.Srv().propertyService.UpsertPropertyValue(rctx, value) if err != nil { + if appErr := mapPropertyServiceError("UpsertPropertyValue", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("UpsertPropertyValue", "app.property_value.upsert.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } return upsertedValue, nil @@ -131,14 +164,103 @@ func (a *App) UpsertPropertyValue(rctx request.CTX, value *model.PropertyValue) // UpsertPropertyValues creates or updates multiple property values. // When objectType is non-empty, WebSocket events are broadcast to notify -// clients of the updated values. +// clients of the updated values, and every referenced field is required +// to have a matching ObjectType. func (a *App) UpsertPropertyValues(rctx request.CTX, values []*model.PropertyValue, objectType, targetID, connectionID string) ([]*model.PropertyValue, *model.AppError) { if len(values) == 0 { return nil, model.NewAppError("UpsertPropertyValues", "app.property_value.invalid_input.app_error", nil, "property values are required", http.StatusBadRequest) } + // Intrinsic invariants — apply to every caller (HTTP, plugin, internal). + // Single-group invariant must run before the bulk-load below, since + // GetPropertyFields takes a single groupID. Guard values[0] explicitly + // because the per-element nil check inside the loop would otherwise be + // reached after the values[0].GroupID dereference. + if values[0] == nil { + return nil, model.NewAppError("UpsertPropertyValues", "app.property_value.invalid_input.app_error", nil, "nil property value in batch", http.StatusBadRequest) + } + groupID := values[0].GroupID + seenIDs := make(map[string]bool, len(values)) + fieldIDs := make([]string, 0, len(values)) + for _, v := range values { + if v == nil { + return nil, model.NewAppError("UpsertPropertyValues", "app.property_value.invalid_input.app_error", nil, "nil property value in batch", http.StatusBadRequest) + } + if v.GroupID != groupID { + return nil, model.NewAppError( + "UpsertPropertyValues", + "app.property_value.upsert.mixed_groups.app_error", + nil, + "all values in a batch must belong to the same group", + http.StatusBadRequest, + ) + } + if !model.IsValidId(v.FieldID) { + return nil, model.NewAppError( + "UpsertPropertyValues", + "app.property_value.upsert.invalid_field_id.app_error", + map[string]any{"FieldID": v.FieldID}, + "invalid field ID", + http.StatusBadRequest, + ) + } + if seenIDs[v.FieldID] { + return nil, model.NewAppError( + "UpsertPropertyValues", + "app.property_value.upsert.duplicate_field_id.app_error", + map[string]any{"FieldID": v.FieldID}, + "duplicate field ID in batch", + http.StatusBadRequest, + ) + } + seenIDs[v.FieldID] = true + fieldIDs = append(fieldIDs, v.FieldID) + v.Value = model.SanitizePropertyValue(v.Value) + } + + // ObjectType-mismatch check is gated on a non-empty objectType argument. + // Plugin API today always passes objectType="" and keeps its loose + // contract on this specific check. + if objectType != "" { + fields, fieldsErr := a.GetPropertyFields(rctx, groupID, fieldIDs) + if fieldsErr != nil { + return nil, fieldsErr + } + fieldByID := make(map[string]*model.PropertyField, len(fields)) + for _, f := range fields { + fieldByID[f.ID] = f + } + for _, v := range values { + f, ok := fieldByID[v.FieldID] + if !ok { + return nil, model.NewAppError( + "UpsertPropertyValues", + "app.property_value.upsert.field_not_found.app_error", + map[string]any{"FieldID": v.FieldID}, + "field not found", + http.StatusNotFound, + ) + } + if f.ObjectType != objectType { + // 404 matches the shape of a non-existent field so callers + // cannot distinguish "no such field" from "field exists but + // in a different object-type bucket". + return nil, model.NewAppError( + "UpsertPropertyValues", + "app.property_value.upsert.object_type_mismatch.app_error", + map[string]any{"FieldID": v.FieldID}, + "object type mismatch", + http.StatusNotFound, + ) + } + } + } + result, err := a.Srv().propertyService.UpsertPropertyValues(rctx, values) if err != nil { + if appErr := mapPropertyServiceError("UpsertPropertyValues", err); appErr != nil { + return nil, appErr + } return nil, model.NewAppError("UpsertPropertyValues", "app.property_value.upsert_many.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } @@ -172,6 +294,9 @@ func (a *App) DeletePropertyValue(rctx request.CTX, groupID, valueID string) *mo } if err := a.Srv().propertyService.DeletePropertyValue(rctx, groupID, valueID); err != nil { + if mappedErr := mapPropertyServiceError("DeletePropertyValue", err); mappedErr != nil { + return mappedErr + } return model.NewAppError("DeletePropertyValue", "app.property_value.delete.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } @@ -204,6 +329,9 @@ func (a *App) DeletePropertyValue(rctx request.CTX, groupID, valueID string) *mo // DeletePropertyValuesForTarget deletes all property values for a target and broadcasts a property_values_updated event. func (a *App) DeletePropertyValuesForTarget(rctx request.CTX, groupID, targetType, targetID string) *model.AppError { if err := a.Srv().propertyService.DeletePropertyValuesForTarget(rctx, groupID, targetType, targetID); err != nil { + if appErr := mapPropertyServiceError("DeletePropertyValuesForTarget", err); appErr != nil { + return appErr + } return model.NewAppError("DeletePropertyValuesForTarget", "app.property_value.delete_for_target.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } @@ -224,6 +352,9 @@ func (a *App) DeletePropertyValuesForTarget(rctx request.CTX, groupID, targetTyp // DeletePropertyValuesForField deletes all property values for a field and broadcasts a property_values_updated event. func (a *App) DeletePropertyValuesForField(rctx request.CTX, groupID, fieldID string) *model.AppError { if err := a.Srv().propertyService.DeletePropertyValuesForField(rctx, groupID, fieldID); err != nil { + if appErr := mapPropertyServiceError("DeletePropertyValuesForField", err); appErr != nil { + return appErr + } return model.NewAppError("DeletePropertyValuesForField", "app.property_value.delete_for_field.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } diff --git a/server/channels/app/property_value_test.go b/server/channels/app/property_value_test.go index 4b65660aed3..2ed0ea7dbd9 100644 --- a/server/channels/app/property_value_test.go +++ b/server/channels/app/property_value_test.go @@ -59,3 +59,103 @@ func TestResolveValueBroadcastParams(t *testing.T) { assert.Equal(t, http.StatusBadRequest, err.StatusCode) }) } + +func TestUpsertPropertyValues_Invariants(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + groupID := registerTestPropertyGroup(t, th) + + // Create a target user-typed field for the happy paths. + field := &model.PropertyField{ + GroupID: groupID, + Name: "upsert-target-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + createdField, appErr := th.App.CreatePropertyField(th.Context, field, false, "") + require.Nil(t, appErr) + + makeValue := func(fieldID string) *model.PropertyValue { + return &model.PropertyValue{ + TargetID: th.BasicUser.Id, + TargetType: model.PropertyFieldObjectTypeUser, + GroupID: groupID, + FieldID: fieldID, + Value: []byte("\"v\""), + CreatedBy: th.BasicUser.Id, + UpdatedBy: th.BasicUser.Id, + } + } + + t.Run("rejects duplicate FieldID", func(t *testing.T) { + v := []*model.PropertyValue{makeValue(createdField.ID), makeValue(createdField.ID)} + result, err := th.App.UpsertPropertyValues(th.Context, v, model.PropertyFieldObjectTypeUser, th.BasicUser.Id, "") + require.NotNil(t, err) + assert.Nil(t, result) + assert.Equal(t, "app.property_value.upsert.duplicate_field_id.app_error", err.Id) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + }) + + t.Run("rejects invalid FieldID", func(t *testing.T) { + v := []*model.PropertyValue{makeValue("not-an-id")} + result, err := th.App.UpsertPropertyValues(th.Context, v, model.PropertyFieldObjectTypeUser, th.BasicUser.Id, "") + require.NotNil(t, err) + assert.Nil(t, result) + assert.Equal(t, "app.property_value.upsert.invalid_field_id.app_error", err.Id) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + }) + + t.Run("rejects mixed group IDs as a clean 400", func(t *testing.T) { + altGroup, appErr := th.App.RegisterPropertyGroup(th.Context, &model.PropertyGroup{ + Name: "alt_mix_" + model.NewId(), + Version: model.PropertyGroupVersionV2, + }) + require.Nil(t, appErr) + + altField := &model.PropertyField{ + GroupID: altGroup.ID, + Name: "alt-field-" + model.NewId(), + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + } + createdAlt, appErr := th.App.CreatePropertyField(th.Context, altField, false, "") + require.Nil(t, appErr) + + v1 := makeValue(createdField.ID) + v2 := makeValue(createdAlt.ID) + v2.GroupID = altGroup.ID + result, err := th.App.UpsertPropertyValues(th.Context, []*model.PropertyValue{v1, v2}, model.PropertyFieldObjectTypeUser, th.BasicUser.Id, "") + require.NotNil(t, err) + assert.Nil(t, result) + assert.Equal(t, "app.property_value.upsert.mixed_groups.app_error", err.Id) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + }) + + t.Run("rejects ObjectType mismatch when objectType is non-empty", func(t *testing.T) { + // Field is ObjectType=user; request specifies channel. + v := []*model.PropertyValue{makeValue(createdField.ID)} + result, err := th.App.UpsertPropertyValues(th.Context, v, model.PropertyFieldObjectTypeChannel, "ch1", "") + require.NotNil(t, err) + assert.Nil(t, result) + assert.Equal(t, "app.property_value.upsert.object_type_mismatch.app_error", err.Id) + assert.Equal(t, http.StatusNotFound, err.StatusCode) + }) + + t.Run("plugin path: empty objectType skips ObjectType match", func(t *testing.T) { + // We don't actually need the upsert to succeed (target/etc may not + // satisfy schema), only to bypass the ObjectType-mismatch reject. + // Confirm by passing a wrong-typed field with objectType="" — the + // app-layer reject should not fire; any error must come from + // downstream layers, not "object_type_mismatch". + v := []*model.PropertyValue{makeValue(createdField.ID)} + _, err := th.App.UpsertPropertyValues(th.Context, v, "", "", "") + // Either succeeds, or fails for a different reason — never the + // object_type_mismatch reject. + if err != nil { + assert.NotEqual(t, "app.property_value.upsert.object_type_mismatch.app_error", err.Id) + } + }) +} diff --git a/server/channels/app/server.go b/server/channels/app/server.go index 7883817c7d9..06528291fc1 100644 --- a/server/channels/app/server.go +++ b/server/channels/app/server.go @@ -269,19 +269,9 @@ func NewServer(options ...Option) (*Server, error) { return nil, errors.Wrapf(err, "unable to create properties service") } - propertyAccessService := properties.NewPropertyAccessService(s.propertyService, func(pluginID string) bool { - if s.ch == nil { - return false - } - - _, err := s.ch.GetPluginStatus(pluginID) - return err == nil - }) - s.propertyService.SetPropertyAccessService(propertyAccessService) - - // Register builtin property groups after fully initializing the propertyService + // Register builtin property groups before creating hooks that reference them if err = s.propertyService.RegisterBuiltinGroups([]*model.PropertyGroup{ - {Name: model.CustomProfileAttributesPropertyGroupName, Version: model.PropertyGroupVersionV1}, + {Name: model.AccessControlPropertyGroupName, Version: model.PropertyGroupVersionV2}, {Name: model.ContentFlaggingGroupName, Version: model.PropertyGroupVersionV1}, {Name: model.ClassificationMarkingsPropertyGroupName, Version: model.PropertyGroupVersionV2}, }); err != nil { @@ -310,6 +300,64 @@ func NewServer(options ...Option) (*Server, error) { // After channel is initialized set it to the App object app := New(ServerConnector(channels)) + // Register property-service hooks AFTER s.ch is populated. The + // access-control and attribute-validation hooks capture s and use + // s.ch for plugin-status and permission lookups; registering them + // earlier leaves a window where hook invocations race against a + // nil s.ch. + cpaGroup, err := s.propertyService.Group(model.AccessControlPropertyGroupName) + if err != nil { + return nil, errors.Wrap(err, "failed to look up CPA property group") + } + + // License check hook — must run before other hooks so unlicensed + // operations are rejected early. + licenseCheckHook := properties.NewLicenseCheckHook(func() *model.License { + return s.License() + }, cpaGroup.ID) + s.propertyService.AddHook(licenseCheckHook) + + accessControlHook := properties.NewAccessControlHook(s.propertyService, func(pluginID string) bool { + _, err := s.ch.GetPluginStatus(pluginID) + return err == nil + }, cpaGroup.ID) + s.propertyService.AddHook(accessControlHook) + + // Attribute validation hook — validates visibility, sort_order on fields, + // field-type constraints on values (options, user IDs, value_type), and + // managed-flag authorization + permission level enforcement. + permChecker := func(userID string, perm *model.Permission) bool { + // Local-mode (unrestricted) sessions are tagged with + // CallerIDLocalAdmin by the HTTP layer; grant them admin + // permissions without a user lookup. + if userID == model.CallerIDLocalAdmin { + return true + } + return app.HasPermissionTo(userID, perm) + } + attrValidationHook := properties.NewAccessControlAttributeValidationHook(s.propertyService, permChecker, cpaGroup.ID) + s.propertyService.AddHook(attrValidationHook) + + // Field limit hook — enforces per-object-type and global field limits. + // Only "user" has a per-type cap today; when channel/team/post CPA fields + // are added, set their per-type caps here. Until then + // AccessControlGroupFieldLimit is the only ceiling for non-user + // object types within this group. + fieldLimitHook := properties.NewFieldLimitHook(s.propertyService) + fieldLimitHook.AddGroupLimit(cpaGroup.ID, &properties.FieldLimitConfig{ + PerObjectType: map[string]int64{ + model.PropertyFieldObjectTypeUser: 20, + }, + GlobalLimit: model.AccessControlGroupFieldLimit, + }) + s.propertyService.AddHook(fieldLimitHook) + + // Type-change value cleanup — registered last so the field write has + // passed every other gate (license, access control, validation, limit) + // before we cascade-delete dependent values. PostUpdate hooks run after + // the store write succeeds. + s.propertyService.AddHook(properties.NewTypeChangeValueCleanupHook(s.propertyService)) + // ------------------------------------------------------------------------- // Everything below this is not order sensitive and safe to be moved around. // If you are adding a new field that is non-channels specific, please add diff --git a/server/channels/db/migrations/migrations.list b/server/channels/db/migrations/migrations.list index 71a28e96008..72f6bccb373 100644 --- a/server/channels/db/migrations/migrations.list +++ b/server/channels/db/migrations/migrations.list @@ -347,3 +347,7 @@ channels/db/migrations/postgres/000174_set_posts_statistics_targets.down.sql channels/db/migrations/postgres/000174_set_posts_statistics_targets.up.sql channels/db/migrations/postgres/000175_add_board_channel_types.down.sql channels/db/migrations/postgres/000175_add_board_channel_types.up.sql +channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql +channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql +channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql +channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql diff --git a/server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql b/server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql new file mode 100644 index 00000000000..a82b8414260 --- /dev/null +++ b/server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql @@ -0,0 +1,16 @@ +-- Rename the group back to custom_profile_attributes and revert to V1. +UPDATE PropertyGroups +SET Name = 'custom_profile_attributes', + Version = 1 +WHERE Name = 'access_control'; + +-- Revert field metadata to the pre-migration state. +UPDATE PropertyFields +SET ObjectType = '', + TargetType = '', + PermissionField = NULL, + PermissionValues = NULL, + PermissionOptions = NULL +WHERE GroupID = (SELECT ID FROM PropertyGroups WHERE Name = 'custom_profile_attributes') + AND ObjectType = 'user' + AND TargetType = 'system'; diff --git a/server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql b/server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql new file mode 100644 index 00000000000..c7ec3832c23 --- /dev/null +++ b/server/channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql @@ -0,0 +1,22 @@ +-- Update all fields belonging to the CPA group before renaming it. +-- Row-level locks only; bounded by the per-group field limit (~200 max). +-- PermissionValues is 'sysadmin' for admin-managed fields, 'member' for all +-- others so that regular users can write their own profile values through the +-- generic property API. +UPDATE PropertyFields +SET ObjectType = 'user', + TargetType = 'system', + PermissionField = 'sysadmin', + PermissionValues = (CASE + WHEN Attrs->>'managed' = 'admin' THEN 'sysadmin' + ELSE 'member' + END)::permission_level, + PermissionOptions = 'sysadmin' +WHERE GroupID = (SELECT ID FROM PropertyGroups WHERE Name = 'custom_profile_attributes'); + +-- Rename the group and bump it to PSAv2. Single-row update, non-blocking. +-- The Version column was added in 000170; existing CPA groups default to V1. +UPDATE PropertyGroups +SET Name = 'access_control', + Version = 2 +WHERE Name = 'custom_profile_attributes'; diff --git a/server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql b/server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql new file mode 100644 index 00000000000..7885253b9c2 --- /dev/null +++ b/server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql @@ -0,0 +1,30 @@ +-- Restore the materialized view without the ObjectType filter (000137 version). +DROP MATERIALIZED VIEW IF EXISTS AttributeView; + +CREATE MATERIALIZED VIEW IF NOT EXISTS AttributeView AS +SELECT + pv.GroupID, + pv.TargetID, + pv.TargetType, + jsonb_object_agg( + pf.Name, + CASE + WHEN pf.Type = 'select' THEN ( + SELECT to_jsonb(options.name) + FROM jsonb_to_recordset(pf.Attrs->'options') AS options(id text, name text) + WHERE options.id = pv.Value #>> '{}' + LIMIT 1 + ) + WHEN pf.Type = 'multiselect' AND jsonb_typeof(pv.Value) = 'array' THEN ( + SELECT jsonb_agg(option_names.name) + FROM jsonb_array_elements_text(pv.Value) AS option_id + JOIN jsonb_to_recordset(pf.Attrs->'options') AS option_names(id text, name text) + ON option_id = option_names.id + ) + ELSE pv.Value + END + ) AS Attributes +FROM PropertyValues pv +LEFT JOIN PropertyFields pf ON pf.ID = pv.FieldID +WHERE (pv.DeleteAt = 0 OR pv.DeleteAt IS NULL) AND (pf.DeleteAt = 0 OR pf.DeleteAt IS NULL) +GROUP BY pv.GroupID, pv.TargetID, pv.TargetType; diff --git a/server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql b/server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql new file mode 100644 index 00000000000..be06808a004 --- /dev/null +++ b/server/channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql @@ -0,0 +1,35 @@ +-- Recreate the materialized view with an ObjectType = 'user' filter so it +-- only materializes user-scoped attributes. Split from 000172 so the row +-- locks taken by that migration's UPDATEs aren't held for the duration of +-- the matview scan. Same drop+create pattern as migration 000137. +DROP MATERIALIZED VIEW IF EXISTS AttributeView; + +CREATE MATERIALIZED VIEW IF NOT EXISTS AttributeView AS +SELECT + pv.GroupID, + pv.TargetID, + pv.TargetType, + jsonb_object_agg( + pf.Name, + CASE + WHEN pf.Type = 'select' THEN ( + SELECT to_jsonb(options.name) + FROM jsonb_to_recordset(pf.Attrs->'options') AS options(id text, name text) + WHERE options.id = pv.Value #>> '{}' + LIMIT 1 + ) + WHEN pf.Type = 'multiselect' AND jsonb_typeof(pv.Value) = 'array' THEN ( + SELECT jsonb_agg(option_names.name) + FROM jsonb_array_elements_text(pv.Value) AS option_id + JOIN jsonb_to_recordset(pf.Attrs->'options') AS option_names(id text, name text) + ON option_id = option_names.id + ) + ELSE pv.Value + END + ) AS Attributes +FROM PropertyValues pv +LEFT JOIN PropertyFields pf ON pf.ID = pv.FieldID +WHERE (pv.DeleteAt = 0 OR pv.DeleteAt IS NULL) + AND (pf.DeleteAt = 0 OR pf.DeleteAt IS NULL) + AND pf.ObjectType = 'user' +GROUP BY pv.GroupID, pv.TargetID, pv.TargetType; diff --git a/server/channels/store/sqlstore/migration_000172_test.go b/server/channels/store/sqlstore/migration_000172_test.go new file mode 100644 index 00000000000..7dba1969f71 --- /dev/null +++ b/server/channels/store/sqlstore/migration_000172_test.go @@ -0,0 +1,331 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/channels/db" +) + +func readMigrationSQL(t *testing.T, filename string) string { + t.Helper() + data, err := db.Assets().ReadFile("migrations/postgres/" + filename) + require.NoError(t, err, "failed to read migration file %s", filename) + return string(data) +} + +func TestMigration000172(t *testing.T) { + logger := mlog.CreateTestLogger(t) + + settings, err := makeSqlSettings(model.DatabaseDriverPostgres) + if err != nil { + t.Skip(err) + } + + store, err := New(*settings, logger, nil) + require.NoError(t, err) + defer store.Close() + + master := store.GetMaster() + + upSQL := readMigrationSQL(t, "000176_migrate_cpa_to_access_control.up.sql") + downSQL := readMigrationSQL(t, "000176_migrate_cpa_to_access_control.down.sql") + + // Insert a group simulating pre-migration CPA state. + groupID := model.NewId() + _, err = master.Exec("INSERT INTO PropertyGroups (ID, Name) VALUES (?, ?)", groupID, "custom_profile_attributes") + require.NoError(t, err) + + t.Cleanup(func() { + master.Exec("DELETE FROM PropertyValues WHERE GroupID = ?", groupID) //nolint:errcheck + master.Exec("DELETE FROM PropertyFields WHERE GroupID = ?", groupID) //nolint:errcheck + master.Exec("DELETE FROM PropertyGroups WHERE ID = ?", groupID) //nolint:errcheck + }) + + now := model.GetMillis() + + // Insert active fields with old format (no ObjectType, no permissions). + // fieldID1 and fieldID2 are non-managed; fieldID3 is admin-managed. + fieldID1 := model.NewId() + fieldID2 := model.NewId() + fieldID3 := model.NewId() + for _, f := range []struct { + id string + name string + ftype string + attrs string + }{ + {fieldID1, "Text Field", "text", `{"visibility":"always","sort_order":1}`}, + {fieldID2, "Select Field", "select", `{"options":[{"id":"opt1","name":"Option 1"}]}`}, + {fieldID3, "Admin Managed Field", "text", `{"visibility":"always","sort_order":3,"managed":"admin"}`}, + } { + _, err = master.Exec( + `INSERT INTO PropertyFields + (ID, GroupID, Name, Type, Attrs, TargetID, TargetType, ObjectType, CreateAt, UpdateAt, DeleteAt, Protected) + VALUES (?, ?, ?, ?, ?::jsonb, '', '', '', ?, ?, 0, false)`, + f.id, groupID, f.name, f.ftype, f.attrs, now, now, + ) + require.NoError(t, err, "inserting field %s", f.name) + } + + // Insert a soft-deleted field to verify all fields are migrated. + deletedFieldID := model.NewId() + _, err = master.Exec( + `INSERT INTO PropertyFields + (ID, GroupID, Name, Type, Attrs, TargetID, TargetType, ObjectType, CreateAt, UpdateAt, DeleteAt, Protected) + VALUES (?, ?, 'Deleted Field', 'text', '{}'::jsonb, '', '', '', ?, ?, ?, false)`, + deletedFieldID, groupID, now, now, now, + ) + require.NoError(t, err) + + // Insert a property value. + valueID := model.NewId() + targetUserID := model.NewId() + _, err = master.Exec( + `INSERT INTO PropertyValues + (ID, TargetID, TargetType, GroupID, FieldID, Value, CreateAt, UpdateAt, DeleteAt) + VALUES (?, ?, 'user', ?, ?, '"hello"'::jsonb, ?, ?, 0)`, + valueID, targetUserID, groupID, fieldID1, now, now, + ) + require.NoError(t, err) + + // ---- Run UP migration ---- + _, err = master.ExecNoTimeout(upSQL) + require.NoError(t, err, "up migration should succeed") + + // Verify: group renamed. + var groupName string + require.NoError(t, master.Get(&groupName, "SELECT Name FROM PropertyGroups WHERE ID = ?", groupID)) + assert.Equal(t, "access_control", groupName) + + // Verify: all fields (including soft-deleted) have new metadata. + // Non-managed fields get PermissionValues = 'member'. + // Admin-managed fields get PermissionValues = 'sysadmin'. + for _, tc := range []struct { + id string + label string + expectedPermissionValues string + }{ + {fieldID1, "non-managed text field", "member"}, + {fieldID2, "non-managed select field", "member"}, + {fieldID3, "admin-managed field", "sysadmin"}, + {deletedFieldID, "soft-deleted non-managed field", "member"}, + } { + var f struct { + ObjectType string `db:"objecttype"` + TargetType string `db:"targettype"` + PermissionField sql.NullString `db:"permissionfield"` + PermissionValues sql.NullString `db:"permissionvalues"` + PermissionOptions sql.NullString `db:"permissionoptions"` + } + require.NoError(t, master.Get(&f, "SELECT ObjectType, TargetType, PermissionField, PermissionValues, PermissionOptions FROM PropertyFields WHERE ID = ?", tc.id)) + assert.Equal(t, "user", f.ObjectType, "%s ObjectType", tc.label) + assert.Equal(t, "system", f.TargetType, "%s TargetType", tc.label) + assert.True(t, f.PermissionField.Valid, "%s PermissionField should be set", tc.label) + assert.Equal(t, "sysadmin", f.PermissionField.String, "%s PermissionField", tc.label) + assert.True(t, f.PermissionValues.Valid, "%s PermissionValues should be set", tc.label) + assert.Equal(t, tc.expectedPermissionValues, f.PermissionValues.String, "%s PermissionValues", tc.label) + assert.True(t, f.PermissionOptions.Valid, "%s PermissionOptions should be set", tc.label) + assert.Equal(t, "sysadmin", f.PermissionOptions.String, "%s PermissionOptions", tc.label) + } + + // Verify: property value is unchanged (GroupID still references the same ID). + var val struct { + GroupID string `db:"groupid"` + TargetID string `db:"targetid"` + TargetType string `db:"targettype"` + } + require.NoError(t, master.Get(&val, "SELECT GroupID, TargetID, TargetType FROM PropertyValues WHERE ID = ?", valueID)) + assert.Equal(t, groupID, val.GroupID, "value GroupID should be unchanged") + assert.Equal(t, targetUserID, val.TargetID, "value TargetID should be unchanged") + assert.Equal(t, "user", val.TargetType, "value TargetType should be unchanged") + + // Verify: AttributeView exists and includes the ObjectType filter (user-type fields only). + var viewDef string + err = master.Get(&viewDef, "SELECT definition FROM pg_matviews WHERE matviewname = 'attributeview'") + require.NoError(t, err, "AttributeView should exist") + assert.Contains(t, viewDef, "pf.objecttype", "view definition should filter by pf.ObjectType") + + // Verify: materialized view contains expected data after refresh. + _, err = master.ExecNoTimeout("REFRESH MATERIALIZED VIEW AttributeView") + require.NoError(t, err, "refreshing AttributeView should succeed") + + var viewRow struct { + GroupID string `db:"groupid"` + TargetID string `db:"targetid"` + TargetType string `db:"targettype"` + Attributes []byte `db:"attributes"` + } + err = master.Get(&viewRow, "SELECT GroupID, TargetID, TargetType, Attributes FROM AttributeView WHERE TargetID = ?", targetUserID) + require.NoError(t, err, "AttributeView should contain a row for the target user") + assert.Equal(t, groupID, viewRow.GroupID) + assert.Equal(t, targetUserID, viewRow.TargetID) + assert.Equal(t, "user", viewRow.TargetType) + + // The text field value "hello" should appear under the field name "Text Field". + var attrs map[string]json.RawMessage + require.NoError(t, json.Unmarshal(viewRow.Attributes, &attrs)) + assert.JSONEq(t, `"hello"`, string(attrs["Text Field"]), "text field value should be materialized") + + // ---- Run DOWN migration ---- + _, err = master.ExecNoTimeout(downSQL) + require.NoError(t, err, "down migration should succeed") + + // Verify: group name reverted. + require.NoError(t, master.Get(&groupName, "SELECT Name FROM PropertyGroups WHERE ID = ?", groupID)) + assert.Equal(t, "custom_profile_attributes", groupName) + + // Verify: fields reverted. + for _, fid := range []string{fieldID1, fieldID2, fieldID3, deletedFieldID} { + var f struct { + ObjectType string `db:"objecttype"` + TargetType string `db:"targettype"` + PermissionField sql.NullString `db:"permissionfield"` + PermissionValues sql.NullString `db:"permissionvalues"` + PermissionOptions sql.NullString `db:"permissionoptions"` + } + require.NoError(t, master.Get(&f, "SELECT ObjectType, TargetType, PermissionField, PermissionValues, PermissionOptions FROM PropertyFields WHERE ID = ?", fid)) + assert.Equal(t, "", f.ObjectType, "field %s ObjectType should revert", fid) + assert.Equal(t, "", f.TargetType, "field %s TargetType should revert", fid) + assert.False(t, f.PermissionField.Valid, "field %s PermissionField should be NULL", fid) + assert.False(t, f.PermissionValues.Valid, "field %s PermissionValues should be NULL", fid) + assert.False(t, f.PermissionOptions.Valid, "field %s PermissionOptions should be NULL", fid) + } + + // Verify: value still unchanged after down migration. + require.NoError(t, master.Get(&val, "SELECT GroupID, TargetID, TargetType FROM PropertyValues WHERE ID = ?", valueID)) + assert.Equal(t, groupID, val.GroupID, "value GroupID should remain unchanged after down") +} + +func TestMigration000172DownPreservesNonUserFields(t *testing.T) { + logger := mlog.CreateTestLogger(t) + + settings, err := makeSqlSettings(model.DatabaseDriverPostgres) + if err != nil { + t.Skip(err) + } + + store, err := New(*settings, logger, nil) + require.NoError(t, err) + defer store.Close() + + master := store.GetMaster() + + upSQL := readMigrationSQL(t, "000176_migrate_cpa_to_access_control.up.sql") + downSQL := readMigrationSQL(t, "000176_migrate_cpa_to_access_control.down.sql") + + groupID := model.NewId() + _, err = master.Exec("INSERT INTO PropertyGroups (ID, Name) VALUES (?, ?)", groupID, "custom_profile_attributes") + require.NoError(t, err) + + t.Cleanup(func() { + master.Exec("DELETE FROM PropertyFields WHERE GroupID = ?", groupID) //nolint:errcheck + master.Exec("DELETE FROM PropertyGroups WHERE ID = ?", groupID) //nolint:errcheck + }) + + now := model.GetMillis() + + // Insert a legacy user field that the up migration will touch. + userFieldID := model.NewId() + _, err = master.Exec( + `INSERT INTO PropertyFields + (ID, GroupID, Name, Type, Attrs, TargetID, TargetType, ObjectType, CreateAt, UpdateAt, DeleteAt, Protected) + VALUES (?, ?, 'Legacy User Field', 'text', '{}'::jsonb, '', '', '', ?, ?, 0, false)`, + userFieldID, groupID, now, now, + ) + require.NoError(t, err) + + // Run UP migration — legacy user field gets ObjectType='user', TargetType='system'. + _, err = master.ExecNoTimeout(upSQL) + require.NoError(t, err, "up migration should succeed") + + // Simulate a post-migration channel-scoped field created via the + // generic property API against the (now renamed) access_control + // group. + channelFieldID := model.NewId() + channelTargetID := model.NewId() + _, err = master.Exec( + `INSERT INTO PropertyFields + (ID, GroupID, Name, Type, Attrs, TargetID, TargetType, ObjectType, PermissionField, PermissionValues, PermissionOptions, CreateAt, UpdateAt, DeleteAt, Protected) + VALUES (?, ?, 'Channel Classification', 'select', '{}'::jsonb, ?, 'channel', 'channel', 'sysadmin', 'member', 'sysadmin', ?, ?, 0, false)`, + channelFieldID, groupID, channelTargetID, now, now, + ) + require.NoError(t, err) + + // Run DOWN migration — must revert only user/system fields, not the channel one. + _, err = master.ExecNoTimeout(downSQL) + require.NoError(t, err, "down migration should succeed") + + // The original user field reverts to legacy metadata. + var userField struct { + ObjectType string `db:"objecttype"` + TargetType string `db:"targettype"` + PermissionField sql.NullString `db:"permissionfield"` + PermissionValues sql.NullString `db:"permissionvalues"` + PermissionOptions sql.NullString `db:"permissionoptions"` + } + require.NoError(t, master.Get(&userField, "SELECT ObjectType, TargetType, PermissionField, PermissionValues, PermissionOptions FROM PropertyFields WHERE ID = ?", userFieldID)) + assert.Equal(t, "", userField.ObjectType, "user field ObjectType should revert") + assert.Equal(t, "", userField.TargetType, "user field TargetType should revert") + assert.False(t, userField.PermissionField.Valid, "user field PermissionField should be NULL") + + // The post-migration channel field keeps its PSAv2 metadata intact. + var channelField struct { + ObjectType string `db:"objecttype"` + TargetType string `db:"targettype"` + TargetID string `db:"targetid"` + PermissionField sql.NullString `db:"permissionfield"` + PermissionValues sql.NullString `db:"permissionvalues"` + PermissionOptions sql.NullString `db:"permissionoptions"` + } + require.NoError(t, master.Get(&channelField, "SELECT ObjectType, TargetType, TargetID, PermissionField, PermissionValues, PermissionOptions FROM PropertyFields WHERE ID = ?", channelFieldID)) + assert.Equal(t, "channel", channelField.ObjectType, "channel field ObjectType must survive rollback") + assert.Equal(t, "channel", channelField.TargetType, "channel field TargetType must survive rollback") + assert.Equal(t, channelTargetID, channelField.TargetID, "channel field TargetID must survive rollback") + assert.True(t, channelField.PermissionField.Valid, "channel field PermissionField must survive rollback") + assert.Equal(t, "sysadmin", channelField.PermissionField.String) + assert.True(t, channelField.PermissionValues.Valid) + assert.Equal(t, "member", channelField.PermissionValues.String) + assert.True(t, channelField.PermissionOptions.Valid) + assert.Equal(t, "sysadmin", channelField.PermissionOptions.String) +} + +func TestMigration000172NoOpOnFreshDB(t *testing.T) { + logger := mlog.CreateTestLogger(t) + + settings, err := makeSqlSettings(model.DatabaseDriverPostgres) + if err != nil { + t.Skip(err) + } + + store, err := New(*settings, logger, nil) + require.NoError(t, err) + defer store.Close() + + master := store.GetMaster() + + upSQL := readMigrationSQL(t, "000176_migrate_cpa_to_access_control.up.sql") + downSQL := readMigrationSQL(t, "000176_migrate_cpa_to_access_control.down.sql") + + // On a fresh database with no CPA group, both up and down should be + // safe no-ops (the UPDATE statements match zero rows). + _, err = master.ExecNoTimeout(upSQL) + assert.NoError(t, err, "up migration should be a safe no-op on fresh DB") + + // Even with no CPA data, the view should be (re)created. + var viewExists bool + require.NoError(t, master.Get(&viewExists, "SELECT EXISTS (SELECT 1 FROM pg_matviews WHERE matviewname = 'attributeview')")) + assert.True(t, viewExists, "AttributeView should exist after up migration on fresh DB") + + _, err = master.ExecNoTimeout(downSQL) + assert.NoError(t, err, "down migration should be a safe no-op on fresh DB") +} diff --git a/server/channels/store/sqlstore/property_field_store.go b/server/channels/store/sqlstore/property_field_store.go index adaa4fd8cee..9bd23e0125e 100644 --- a/server/channels/store/sqlstore/property_field_store.go +++ b/server/channels/store/sqlstore/property_field_store.go @@ -67,7 +67,7 @@ func (s *SqlPropertyFieldStore) Get(ctx context.Context, groupID, id string) (*m var field model.PropertyField if err := s.DBXFromContext(ctx).GetBuilder(&field, builder); err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, store.NewErrNotFound("PropertyField", id) } return nil, errors.Wrap(err, "property_field_get_select") @@ -85,6 +85,9 @@ func (s *SqlPropertyFieldStore) GetFieldByName(ctx context.Context, groupID, tar var field model.PropertyField if err := s.DBXFromContext(ctx).GetBuilder(&field, builder); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.NewErrNotFound("PropertyField", name) + } return nil, errors.Wrap(err, "property_field_get_by_name_select") } @@ -127,6 +130,24 @@ func (s *SqlPropertyFieldStore) CountForGroup(groupID string, includeDeleted boo return count, nil } +func (s *SqlPropertyFieldStore) CountForGroupObjectType(groupID, objectType string, includeDeleted bool) (int64, error) { + var count int64 + builder := s.getQueryBuilder(). + Select("COUNT(id)"). + From("PropertyFields"). + Where(sq.Eq{"GroupID": groupID}). + Where(sq.Eq{"ObjectType": objectType}) + + if !includeDeleted { + builder = builder.Where(sq.Eq{"DeleteAt": 0}) + } + + if err := s.GetReplica().GetBuilder(&count, builder); err != nil { + return int64(0), errors.Wrap(err, "failed to count property fields for group and object type") + } + return count, nil +} + func (s *SqlPropertyFieldStore) CountForTarget(groupID, targetType, targetID string, includeDeleted bool) (int64, error) { var count int64 builder := s.getQueryBuilder(). @@ -444,8 +465,7 @@ func (s *SqlPropertyFieldStore) buildConflictSubquery(level string, objectType, // new fields. func (s *SqlPropertyFieldStore) CheckPropertyNameConflict(field *model.PropertyField, excludeID string) (model.PropertyFieldTargetLevel, error) { // Legacy properties (PSAv1) use old uniqueness via DB constraint - // FIXME: explicitly excluding templates from the shortcircuit, should be removed after CPA is fully migrated to v2 - if field.IsPSAv1() && field.ObjectType != model.PropertyFieldObjectTypeTemplate { + if field.IsPSAv1() { return "", nil } diff --git a/server/channels/store/sqlstore/property_value_store.go b/server/channels/store/sqlstore/property_value_store.go index 89cac63f0b1..f5a790bb83e 100644 --- a/server/channels/store/sqlstore/property_value_store.go +++ b/server/channels/store/sqlstore/property_value_store.go @@ -4,6 +4,7 @@ package sqlstore import ( + "database/sql" "fmt" sq "github.com/mattermost/squirrel" @@ -105,6 +106,9 @@ func (s *SqlPropertyValueStore) Get(groupID, id string) (*model.PropertyValue, e var value model.PropertyValue if err := s.GetReplica().GetBuilder(&value, builder); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.NewErrNotFound("PropertyValue", id) + } return nil, errors.Wrap(err, "property_value_get_select") } @@ -269,6 +273,11 @@ func (s *SqlPropertyValueStore) Upsert(values []*model.PropertyValue) (_ []*mode updatedValues := make([]*model.PropertyValue, len(values)) updateTime := model.GetMillis() for i, value := range values { + // Pin CreateAt to updateTime so PreSave does not capture a later + // GetMillis() — keeping CreateAt == UpdateAt on insert. + if value.CreateAt == 0 { + value.CreateAt = updateTime + } value.PreSave() value.UpdateAt = updateTime diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 33068beef67..d09b52d8895 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -1149,6 +1149,7 @@ type PropertyFieldStore interface { GetMany(ctx context.Context, groupID string, ids []string) ([]*model.PropertyField, error) GetFieldByName(ctx context.Context, groupID, targetID, name string) (*model.PropertyField, error) CountForGroup(groupID string, includeDeleted bool) (int64, error) + CountForGroupObjectType(groupID, objectType string, includeDeleted bool) (int64, error) CountForTarget(groupID, targetType, targetID string, includeDeleted bool) (int64, error) CountLinkedFields(fieldID string) (int64, error) SearchPropertyFields(opts model.PropertyFieldSearchOpts) ([]*model.PropertyField, error) diff --git a/server/channels/store/storetest/attributes_store.go b/server/channels/store/storetest/attributes_store.go index 7f24b1981b9..b455e89dff2 100644 --- a/server/channels/store/storetest/attributes_store.go +++ b/server/channels/store/storetest/attributes_store.go @@ -99,15 +99,19 @@ func createTestUsers(t *testing.T, rctx request.CTX, ss store.Store) ([]*model.U groupID := group.ID fieldA, err := ss.PropertyField().Create(&model.PropertyField{ - GroupID: groupID, - Name: testPropertyA, - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: testPropertyA, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), }) require.NoError(t, err) fieldB, err := ss.PropertyField().Create(&model.PropertyField{ - GroupID: groupID, - Name: testPropertyB, - Type: model.PropertyFieldTypeText, + GroupID: groupID, + Name: testPropertyB, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), }) require.NoError(t, err) attrs := map[string]any{ @@ -117,10 +121,12 @@ func createTestUsers(t *testing.T, rctx request.CTX, ss store.Store) ([]*model.U }, } fieldC, err := ss.PropertyField().Create(&model.PropertyField{ - GroupID: groupID, - Name: "test_property_c", - Type: model.PropertyFieldTypeSelect, - Attrs: attrs, + GroupID: groupID, + Name: "test_property_c", + Type: model.PropertyFieldTypeSelect, + Attrs: attrs, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), }) require.NoError(t, err) diff --git a/server/channels/store/storetest/mocks/PropertyFieldStore.go b/server/channels/store/storetest/mocks/PropertyFieldStore.go index 996d401df3e..e8f8ca7568f 100644 --- a/server/channels/store/storetest/mocks/PropertyFieldStore.go +++ b/server/channels/store/storetest/mocks/PropertyFieldStore.go @@ -72,6 +72,34 @@ func (_m *PropertyFieldStore) CountForGroup(groupID string, includeDeleted bool) return r0, r1 } +// CountForGroupObjectType provides a mock function with given fields: groupID, objectType, includeDeleted +func (_m *PropertyFieldStore) CountForGroupObjectType(groupID string, objectType string, includeDeleted bool) (int64, error) { + ret := _m.Called(groupID, objectType, includeDeleted) + + if len(ret) == 0 { + panic("no return value specified for CountForGroupObjectType") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(string, string, bool) (int64, error)); ok { + return rf(groupID, objectType, includeDeleted) + } + if rf, ok := ret.Get(0).(func(string, string, bool) int64); ok { + r0 = rf(groupID, objectType, includeDeleted) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(string, string, bool) error); ok { + r1 = rf(groupID, objectType, includeDeleted) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CountForTarget provides a mock function with given fields: groupID, targetType, targetID, includeDeleted func (_m *PropertyFieldStore) CountForTarget(groupID string, targetType string, targetID string, includeDeleted bool) (int64, error) { ret := _m.Called(groupID, targetType, targetID, includeDeleted) diff --git a/server/channels/store/storetest/property_field_store.go b/server/channels/store/storetest/property_field_store.go index 714f5fbfab3..f1fe1cb4368 100644 --- a/server/channels/store/storetest/property_field_store.go +++ b/server/channels/store/storetest/property_field_store.go @@ -5,7 +5,6 @@ package storetest import ( "context" - "database/sql" "fmt" "testing" "time" @@ -342,7 +341,8 @@ func testGetFieldByName(t *testing.T, _ request.CTX, ss store.Store) { t.Run("should fail on nonexisting field", func(t *testing.T) { field, err := ss.PropertyField().GetFieldByName(context.Background(), "", "", "nonexistent-field-name") require.Zero(t, field) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) groupID := model.NewId() @@ -373,13 +373,15 @@ func testGetFieldByName(t *testing.T, _ request.CTX, ss store.Store) { t.Run("should not be able to retrieve an existing field when specifying a different group ID", func(t *testing.T) { field, err := ss.PropertyField().GetFieldByName(context.Background(), model.NewId(), targetID, "unique-field-name") require.Zero(t, field) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) t.Run("should not be able to retrieve an existing field when specifying a different target ID", func(t *testing.T) { field, err := ss.PropertyField().GetFieldByName(context.Background(), groupID, model.NewId(), "unique-field-name") require.Zero(t, field) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) // Test with multiple fields with the same name but different groups @@ -470,7 +472,8 @@ func testGetFieldByName(t *testing.T, _ request.CTX, ss store.Store) { // Verify it can't be retrieved after deletion field, err = ss.PropertyField().GetFieldByName(context.Background(), groupID, targetID, "to-be-deleted-field") require.Zero(t, field) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) t.Run("should not retrieve fields with matching name but different DeleteAt status", func(t *testing.T) { diff --git a/server/channels/store/storetest/property_value_store.go b/server/channels/store/storetest/property_value_store.go index a8378293c98..c0b86f852e5 100644 --- a/server/channels/store/storetest/property_value_store.go +++ b/server/channels/store/storetest/property_value_store.go @@ -4,7 +4,6 @@ package storetest import ( - "database/sql" "encoding/json" "fmt" "testing" @@ -350,7 +349,8 @@ func testGetPropertyValue(t *testing.T, _ request.CTX, ss store.Store, s SqlStor t.Run("should fail on nonexisting value", func(t *testing.T) { value, err := ss.PropertyValue().Get("", model.NewId()) require.Zero(t, value) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) groupID := model.NewId() @@ -381,7 +381,8 @@ func testGetPropertyValue(t *testing.T, _ request.CTX, ss store.Store, s SqlStor t.Run("should not be able to retrieve an existing value when specifying a different group ID", func(t *testing.T) { value, err := ss.PropertyValue().Get(model.NewId(), newValue.ID) require.Zero(t, value) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) t.Run("should be able to retrieve an existing property value with matching groupID", func(t *testing.T) { @@ -418,7 +419,8 @@ func testGetPropertyValue(t *testing.T, _ request.CTX, ss store.Store, s SqlStor // Try to get the value with a different group ID value, err := ss.PropertyValue().Get(model.NewId(), newValue.ID) require.Zero(t, value) - require.ErrorIs(t, err, sql.ErrNoRows) + var enf *store.ErrNotFound + require.ErrorAs(t, err, &enf) }) t.Run("null columns, before createdBy and updatedBy migrations", func(t *testing.T) { diff --git a/server/channels/testlib/store.go b/server/channels/testlib/store.go index 261eac37136..9aee6d57e96 100644 --- a/server/channels/testlib/store.go +++ b/server/channels/testlib/store.go @@ -146,11 +146,13 @@ func GetMockStoreForSetupFunctions() *mocks.Store { groupsByName := map[string]*model.PropertyGroup{} - cpaGroup := &model.PropertyGroup{ID: model.NewId(), Name: model.CustomProfileAttributesPropertyGroupName, Version: model.PropertyGroupVersionV1} + accessControlGroup := &model.PropertyGroup{ID: model.NewId(), Name: model.AccessControlPropertyGroupName, Version: model.PropertyGroupVersionV2} + contentFlaggingGroup := &model.PropertyGroup{ID: model.NewId(), Name: model.ContentFlaggingGroupName, Version: model.PropertyGroupVersionV1} managedCategoryGroup := &model.PropertyGroup{ID: model.NewId(), Name: model.ManagedCategoryPropertyGroupName, Version: model.PropertyGroupVersionV2} boardsGroup := &model.PropertyGroup{ID: model.NewId(), Name: model.BoardsPropertyGroupName, Version: model.PropertyGroupVersionV2} - groupsByName[cpaGroup.Name] = cpaGroup + groupsByName[accessControlGroup.Name] = accessControlGroup + groupsByName[contentFlaggingGroup.Name] = contentFlaggingGroup groupsByName[managedCategoryGroup.Name] = managedCategoryGroup groupsByName[boardsGroup.Name] = boardsGroup @@ -177,7 +179,8 @@ func GetMockStoreForSetupFunctions() *mocks.Store { return nil }, ) - propertyGroupStore.On("Get", model.CustomProfileAttributesPropertyGroupName).Return(cpaGroup, nil) + propertyGroupStore.On("Get", model.AccessControlPropertyGroupName).Return(accessControlGroup, nil) + propertyGroupStore.On("Get", model.ContentFlaggingGroupName).Return(contentFlaggingGroup, nil) propertyGroupStore.On("Get", model.ManagedCategoryPropertyGroupName).Return(managedCategoryGroup, nil) propertyGroupStore.On("Get", model.BoardsPropertyGroupName).Return(boardsGroup, nil) diff --git a/server/cmd/mmctl/commands/user_attributes_field_e2e_test.go b/server/cmd/mmctl/commands/user_attributes_field_e2e_test.go index 53fbfd3ceb3..77e5cc25bbc 100644 --- a/server/cmd/mmctl/commands/user_attributes_field_e2e_test.go +++ b/server/cmd/mmctl/commands/user_attributes_field_e2e_test.go @@ -4,6 +4,8 @@ package commands import ( + "context" + "github.com/mattermost/mattermost/server/public/model" "github.com/spf13/cobra" @@ -11,13 +13,52 @@ import ( "github.com/mattermost/mattermost/server/v8/cmd/mmctl/printer" ) +// createCPAField posts the given CPAField via the admin HTTP client and +// returns the server response reshaped as a typed CPAField. +func (s *MmctlE2ETestSuite) createCPAField(field *model.CPAField) *model.CPAField { + s.T().Helper() + created, _, err := s.th.SystemAdminClient.CreateCPAField(context.Background(), field.ToPropertyField()) + s.Require().NoError(err) + cpa, err := model.NewCPAFieldFromPropertyField(created) + s.Require().NoError(err) + return cpa +} + +// listCPAFields fetches all CPA fields via the admin HTTP client, returning +// them as typed CPAFields. +func (s *MmctlE2ETestSuite) listCPAFields() []*model.CPAField { + s.T().Helper() + fields, _, err := s.th.SystemAdminClient.ListCPAFields(context.Background()) + s.Require().NoError(err) + out := make([]*model.CPAField, 0, len(fields)) + for _, pf := range fields { + cpa, err := model.NewCPAFieldFromPropertyField(pf) + s.Require().NoError(err) + out = append(out, cpa) + } + return out +} + +// getCPAField fetches a single CPA field by ID. There is no single-field HTTP +// endpoint, so this filters the full list — sufficient for verifying updates +// in tests with a clean fixture state. +func (s *MmctlE2ETestSuite) getCPAField(id string) *model.CPAField { + s.T().Helper() + for _, f := range s.listCPAFields() { + if f.ID == id { + return f + } + } + s.T().Fatalf("CPA field %q not found", id) + return nil +} + // cleanCPAFields removes all existing CPA fields to ensure clean test state func (s *MmctlE2ETestSuite) cleanCPAFields() { - existingFields, appErr := s.th.App.ListCPAFields(nil) - s.Require().Nil(appErr) - for _, field := range existingFields { - appErr := s.th.App.DeleteCPAField(nil, field.ID) - s.Require().Nil(appErr) + s.T().Helper() + for _, field := range s.listCPAFields() { + _, err := s.th.SystemAdminClient.DeleteCPAField(context.Background(), field.ID) + s.Require().NoError(err) } } @@ -66,12 +107,10 @@ func (s *MmctlE2ETestSuite) TestCPAFieldListCmd() { }, } - createdTextField, appErr := s.th.App.CreateCPAField(nil, textField) - s.Require().Nil(appErr) + createdTextField := s.createCPAField(textField) s.Require().NotNil(createdTextField) - createdSelectField, appErr := s.th.App.CreateCPAField(nil, selectField) - s.Require().Nil(appErr) + createdSelectField := s.createCPAField(selectField) s.Require().NotNil(createdSelectField) // Now test the list command @@ -114,8 +153,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldCreateCmd() { s.Require().Contains(output, "Field Department correctly created") // Verify field was actually created in the database - fields, appErr := s.th.App.ListCPAFields(nil) - s.Require().Nil(appErr) + fields := s.listCPAFields() s.Require().Len(fields, 1) s.Require().Equal("Department", fields[0].Name) s.Require().Equal(model.PropertyFieldTypeText, fields[0].Type) @@ -150,8 +188,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldCreateCmd() { s.Require().Contains(output, "Field Skills correctly created") // Verify field was actually created in the database with correct options - fields, appErr := s.th.App.ListCPAFields(nil) - s.Require().Nil(appErr) + fields := s.listCPAFields() s.Require().Len(fields, 1) s.Require().Equal("Skills", fields[0].Name) s.Require().Equal(model.PropertyFieldTypeMultiselect, fields[0].Type) @@ -210,8 +247,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, field) - s.Require().Nil(appErr) + createdField := s.createCPAField(field) // Now edit the field cmd := &cobra.Command{} @@ -237,8 +273,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { s.Require().Contains(output, "Field Programming Languages successfully updated") // Verify field was actually updated - updatedField, appErr := s.th.App.GetCPAField(nil, createdField.ID) - s.Require().Nil(appErr) + updatedField := s.getCPAField(createdField.ID) s.Require().Equal("Programming Languages", updatedField.Name) // Check options @@ -268,8 +303,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, field) - s.Require().Nil(appErr) + createdField := s.createCPAField(field) // Now edit the field with --managed flag cmd := &cobra.Command{} @@ -287,8 +321,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { s.Require().Len(printer.GetErrorLines(), 0) // Verify field was actually updated - updatedField, appErr := s.th.App.GetCPAField(nil, createdField.ID) - s.Require().Nil(appErr) + updatedField := s.getCPAField(createdField.ID) // Verify that managed flag was set correctly s.Require().Equal("admin", updatedField.Attrs.Managed) @@ -310,8 +343,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, field) - s.Require().Nil(appErr) + createdField := s.createCPAField(field) // Now edit the field using its name instead of ID cmd := &cobra.Command{} @@ -336,8 +368,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { s.Require().Contains(output, "Field Team successfully updated") // Verify field was actually updated by retrieving it - updatedField, appErr := s.th.App.GetCPAField(nil, createdField.ID) - s.Require().Nil(appErr) + updatedField := s.getCPAField(createdField.ID) s.Require().Equal("Team", updatedField.Name) // Check managed status @@ -363,8 +394,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, field) - s.Require().Nil(appErr) + createdField := s.createCPAField(field) // Get the original option IDs to verify they are preserved s.Require().Len(createdField.Attrs.Options, 2) @@ -406,8 +436,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldEditCmd() { s.Require().Contains(output, "Field Programming Languages successfully updated") // Verify field was actually updated and options are preserved correctly - updatedField, appErr := s.th.App.GetCPAField(nil, createdField.ID) - s.Require().Nil(appErr) + updatedField := s.getCPAField(createdField.ID) // Check options s.Require().Len(updatedField.Attrs.Options, 3) @@ -456,8 +485,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldDeleteCmd() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, field) - s.Require().Nil(appErr) + createdField := s.createCPAField(field) cmd := &cobra.Command{} cmd.Flags().Bool("confirm", false, "") @@ -475,8 +503,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldDeleteCmd() { s.Require().Contains(output, "Successfully deleted CPA field") // Verify field was actually deleted by checking if it exists in the list - fields, appErr := s.th.App.ListCPAFields(nil) - s.Require().Nil(appErr) + fields := s.listCPAFields() // Field should not be in the list anymore fieldExists := false @@ -502,8 +529,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldDeleteCmd() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, field) - s.Require().Nil(appErr) + createdField := s.createCPAField(field) cmd := &cobra.Command{} cmd.Flags().Bool("confirm", false, "") @@ -522,8 +548,7 @@ func (s *MmctlE2ETestSuite) TestCPAFieldDeleteCmd() { s.Require().Contains(output, "Successfully deleted CPA field: Department") // Verify field was actually deleted by checking if it exists in the list - fields, appErr := s.th.App.ListCPAFields(nil) - s.Require().Nil(appErr) + fields := s.listCPAFields() // Field should not be in the list anymore fieldExists := false diff --git a/server/cmd/mmctl/commands/user_attributes_value_e2e_test.go b/server/cmd/mmctl/commands/user_attributes_value_e2e_test.go index 4370d4af4fe..79f0cce2dbc 100644 --- a/server/cmd/mmctl/commands/user_attributes_value_e2e_test.go +++ b/server/cmd/mmctl/commands/user_attributes_value_e2e_test.go @@ -4,6 +4,7 @@ package commands import ( + "context" "encoding/json" "github.com/mattermost/mattermost/server/public/model" @@ -12,21 +13,31 @@ import ( "github.com/mattermost/mattermost/server/v8/cmd/mmctl/printer" ) +// listCPAValuesForUser fetches the user's CPA values via the admin HTTP +// client (field-id → raw-JSON map, same shape the command returns). +func (s *MmctlE2ETestSuite) listCPAValuesForUser(userID string) map[string]json.RawMessage { + s.T().Helper() + values, _, err := s.th.SystemAdminClient.ListCPAValues(context.Background(), userID) + s.Require().NoError(err) + return values +} + // cleanCPAValuesForUser removes all CPA values for a user func (s *MmctlE2ETestSuite) cleanCPAValuesForUser(userID string) { - existingValues, appErr := s.th.App.ListCPAValues(nil, userID) - s.Require().Nil(appErr) + s.T().Helper() + existing := s.listCPAValuesForUser(userID) + if len(existing) == 0 { + return + } // Clear all existing values by setting them to null - updates := make(map[string]json.RawMessage) - for _, value := range existingValues { - updates[value.FieldID] = json.RawMessage("null") + updates := make(map[string]json.RawMessage, len(existing)) + for fieldID := range existing { + updates[fieldID] = json.RawMessage("null") } - if len(updates) > 0 { - _, appErr = s.th.App.PatchCPAValues(nil, userID, updates, false) - s.Require().Nil(appErr) - } + _, _, err := s.th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), userID, updates) + s.Require().NoError(err) } func (s *MmctlE2ETestSuite) TestCPAValueList() { @@ -64,19 +75,18 @@ func (s *MmctlE2ETestSuite) TestCPAValueList() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, textField) - s.Require().Nil(appErr) + createdField := s.createCPAField(textField) - // Set a text value using the app layer + // Seed a text value via the admin HTTP client. updates := map[string]json.RawMessage{ createdField.ID: json.RawMessage(`"Engineering"`), } - _, appErr = s.th.App.PatchCPAValues(nil, s.th.BasicUser.Id, updates, false) - s.Require().Nil(appErr) + _, _, err := s.th.SystemAdminClient.PatchCPAValuesForUser(context.Background(), s.th.BasicUser.Id, updates) + s.Require().NoError(err) // Test listing the values with plain format (human-readable) printer.SetFormat(printer.FormatPlain) - err := cpaValueListCmdF(c, &cobra.Command{}, []string{s.th.BasicUser.Email}) + err = cpaValueListCmdF(c, &cobra.Command{}, []string{s.th.BasicUser.Email}) s.Require().Nil(err) s.Require().Len(printer.GetLines(), 1) s.Require().Len(printer.GetErrorLines(), 0) @@ -122,8 +132,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, textField) - s.Require().Nil(appErr) + createdField := s.createCPAField(textField) // Set a text value cmd := &cobra.Command{} @@ -136,11 +145,9 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // Verify the value was set - values, appErr := s.th.App.ListCPAValues(nil, s.th.BasicUser.Id) - s.Require().Nil(appErr) + values := s.listCPAValuesForUser(s.th.BasicUser.Id) s.Require().Len(values, 1) - s.Require().Equal(createdField.ID, values[0].FieldID) - s.Require().Equal(`"Engineering"`, string(values[0].Value)) + s.Require().Equal(`"Engineering"`, string(values[createdField.ID])) }) s.Run("Set value for select type field", func() { @@ -166,8 +173,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, selectField) - s.Require().Nil(appErr) + createdField := s.createCPAField(selectField) // Set a select value using the option name cmd := &cobra.Command{} @@ -180,10 +186,8 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // Verify the value was set (should be stored as option ID) - values, appErr := s.th.App.ListCPAValues(nil, s.th.BasicUser.Id) - s.Require().Nil(appErr) + values := s.listCPAValuesForUser(s.th.BasicUser.Id) s.Require().Len(values, 1) - s.Require().Equal(createdField.ID, values[0].FieldID) // Find the Senior option ID for verification var seniorOptionID string @@ -193,7 +197,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { break } } - s.Require().Equal(`"`+seniorOptionID+`"`, string(values[0].Value)) + s.Require().Equal(`"`+seniorOptionID+`"`, string(values[createdField.ID])) }) s.Run("Set value for multiselect type field", func() { @@ -220,8 +224,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, multiselectField) - s.Require().Nil(appErr) + createdField := s.createCPAField(multiselectField) // Set multiple values using option names cmd := &cobra.Command{} @@ -239,10 +242,8 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // Verify the values were set (should be stored as option IDs) - values, appErr := s.th.App.ListCPAValues(nil, s.th.BasicUser.Id) - s.Require().Nil(appErr) + values := s.listCPAValuesForUser(s.th.BasicUser.Id) s.Require().Len(values, 1) - s.Require().Equal(createdField.ID, values[0].FieldID) // Find the option IDs for verification var goOptionID, reactOptionID, pythonOptionID string @@ -259,7 +260,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // The multiselect values should be stored as an array of option IDs // The JSON serialization may include spaces, so we need to compare the content, not exact string - actualValue := string(values[0].Value) + actualValue := string(values[createdField.ID]) s.Require().Contains(actualValue, goOptionID) s.Require().Contains(actualValue, reactOptionID) s.Require().Contains(actualValue, pythonOptionID) @@ -288,8 +289,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, multiselectField) - s.Require().Nil(appErr) + createdField := s.createCPAField(multiselectField) // Set a single value using option name cmd := &cobra.Command{} @@ -303,10 +303,8 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // Verify the value was set (should be stored as an array with single option ID) - values, appErr := s.th.App.ListCPAValues(nil, s.th.BasicUser.Id) - s.Require().Nil(appErr) + values := s.listCPAValuesForUser(s.th.BasicUser.Id) s.Require().Len(values, 1) - s.Require().Equal(createdField.ID, values[0].FieldID) // Find the option ID for verification var pythonOptionID string @@ -319,7 +317,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // The multiselect value should be stored as an array with single option ID // Even for single value, multiselect fields store values as arrays - actualValue := string(values[0].Value) + actualValue := string(values[createdField.ID]) s.Require().Contains(actualValue, pythonOptionID) s.Require().Contains(actualValue, "[") s.Require().Contains(actualValue, "]") @@ -349,8 +347,7 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { }, } - createdField, appErr := s.th.App.CreateCPAField(nil, userField) - s.Require().Nil(appErr) + createdField := s.createCPAField(userField) // Set a user value using the system admin user ID cmd := &cobra.Command{} @@ -363,10 +360,8 @@ func (s *MmctlE2ETestSuite) TestCPAValueSet() { // Verify the value was set - values, appErr := s.th.App.ListCPAValues(nil, s.th.BasicUser.Id) - s.Require().Nil(appErr) + values := s.listCPAValuesForUser(s.th.BasicUser.Id) s.Require().Len(values, 1) - s.Require().Equal(createdField.ID, values[0].FieldID) - s.Require().Equal(`"`+s.th.SystemAdminUser.Id+`"`, string(values[0].Value)) + s.Require().Equal(`"`+s.th.SystemAdminUser.Id+`"`, string(values[createdField.ID])) }) } diff --git a/server/i18n/en.json b/server/i18n/en.json index 05abb45be52..e3cf0f0c27c 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -2045,10 +2045,6 @@ "id": "api.custom_profile_attributes.invalid_field_patch", "translation": "invalid User Attribute field patch" }, - { - "id": "api.custom_profile_attributes.license_error", - "translation": "Your license does not support User Attributes." - }, { "id": "api.custom_status.disabled", "translation": "Custom status feature has been disabled. Please contact your system administrator for details." @@ -3182,10 +3178,6 @@ "id": "api.property_field.delete.no_permission.app_error", "translation": "You do not have permission to delete this property field." }, - { - "id": "api.property_field.delete.protected_via_api.app_error", - "translation": "Cannot delete a protected property field via API." - }, { "id": "api.property_field.get.invalid_target_type.app_error", "translation": "A valid target_type (system, team, or channel) is required." @@ -3202,26 +3194,10 @@ "id": "api.property_field.object_type_mismatch.app_error", "translation": "Property field object type does not match URL." }, - { - "id": "api.property_field.patch.cannot_link_existing.app_error", - "translation": "Cannot set linked_field_id on an existing field. It can only be set at creation time." - }, { "id": "api.property_field.patch.legacy_field.app_error", "translation": "Cannot patch a v1 property field via this API." }, - { - "id": "api.property_field.patch.linked_field_change.app_error", - "translation": "Cannot change link target. Unlink first, then create a new linked field." - }, - { - "id": "api.property_field.patch.linked_options_change.app_error", - "translation": "Cannot modify options of a linked field. Options are inherited from the source." - }, - { - "id": "api.property_field.patch.linked_type_change.app_error", - "translation": "Cannot modify type of a linked field. Type is inherited from the source." - }, { "id": "api.property_field.update.no_field_permission.app_error", "translation": "You do not have permission to edit this property field." @@ -3230,14 +3206,6 @@ "id": "api.property_field.update.no_options_permission.app_error", "translation": "You do not have permission to manage options for this property field." }, - { - "id": "api.property_field.update.protected_via_api.app_error", - "translation": "Cannot update a protected property field via API." - }, - { - "id": "api.property_value.field_object_type_mismatch.app_error", - "translation": "One or more property fields do not match the route's object type." - }, { "id": "api.property_value.invalid_object_type.app_error", "translation": "The provided object type is not valid." @@ -3250,6 +3218,10 @@ "id": "api.property_value.patch.empty_body.app_error", "translation": "Request body must contain at least one property value update." }, + { + "id": "api.property_value.patch.field_not_found.app_error", + "translation": "Property field {{.FieldID}} was not found in this group." + }, { "id": "api.property_value.patch.invalid_field_id.app_error", "translation": "One or more field IDs in the request are invalid." @@ -3266,10 +3238,6 @@ "id": "api.property_value.system_use_dedicated_route.app_error", "translation": "System values must use the dedicated system values endpoint." }, - { - "id": "api.property_value.target_user.forbidden.app_error", - "translation": "You do not have permission to access property values for another user." - }, { "id": "api.property_value.template_no_values.app_error", "translation": "Template fields cannot have values." @@ -5906,82 +5874,10 @@ "id": "app.custom_group.unique_name", "translation": "group name is not unique" }, - { - "id": "app.custom_profile_attributes.count_property_fields.app_error", - "translation": "Unable to count the number of fields for the User Attributes group" - }, - { - "id": "app.custom_profile_attributes.cpa_group_id.app_error", - "translation": "Unable to retrieve the User Attributes property group." - }, - { - "id": "app.custom_profile_attributes.delete_property_values_for_user.app_error", - "translation": "Unable to delete User Attribute values for user" - }, - { - "id": "app.custom_profile_attributes.get_property_field.app_error", - "translation": "Unable to get User Attribute field" - }, - { - "id": "app.custom_profile_attributes.get_property_value.app_error", - "translation": "Unable to get User Attribute value" - }, - { - "id": "app.custom_profile_attributes.limit_reached.app_error", - "translation": "User Attributes field limit reached" - }, - { - "id": "app.custom_profile_attributes.list_property_values.app_error", - "translation": "Unable to get User Attribute values" - }, - { - "id": "app.custom_profile_attributes.patch_field.app_error", - "translation": "Unable to patch User Attribute field" - }, { "id": "app.custom_profile_attributes.property_field_conversion.app_error", "translation": "Unable to convert the property field to a User Attribute field" }, - { - "id": "app.custom_profile_attributes.property_field_delete.app_error", - "translation": "Unable to delete User Attribute field" - }, - { - "id": "app.custom_profile_attributes.property_field_is_managed.app_error", - "translation": "Cannot update value for an admin-managed User Attribute field" - }, - { - "id": "app.custom_profile_attributes.property_field_is_synced.app_error", - "translation": "Cannot update value for a synced User Attribute field" - }, - { - "id": "app.custom_profile_attributes.property_field_not_found.app_error", - "translation": "User Attribute field not found" - }, - { - "id": "app.custom_profile_attributes.property_field_update.app_error", - "translation": "Unable to update User Attribute field" - }, - { - "id": "app.custom_profile_attributes.property_value_upsert.app_error", - "translation": "Unable to upsert User Attribute fields" - }, - { - "id": "app.custom_profile_attributes.sanitize_and_validate.app_error", - "translation": "Invalid property value attributes : {{.AttributeName}} ({{.Reason}})." - }, - { - "id": "app.custom_profile_attributes.sanitize_and_validate.display_name_too_long.app_error", - "translation": "CPA field display_name exceeds the maximum length of {{.MaxRunes}} characters." - }, - { - "id": "app.custom_profile_attributes.search_property_fields.app_error", - "translation": "Unable to search User Attribute fields" - }, - { - "id": "app.custom_profile_attributes.validate_value.app_error", - "translation": "Failed to validate property value" - }, { "id": "app.data_spillage.assign_reviewer.no_reviewer_field.app_error", "translation": "No Reviewer ID property field found." @@ -8112,6 +8008,26 @@ "id": "app.prepackged-plugin.invalid_version.app_error", "translation": "Prepackged plugin version could not be parsed." }, + { + "id": "app.property.access_denied.app_error", + "translation": "You do not have permission to perform this operation." + }, + { + "id": "app.property.invalid_access_mode.app_error", + "translation": "The access_mode attribute is invalid." + }, + { + "id": "app.property.license_error", + "translation": "Your license does not support this property group." + }, + { + "id": "app.property.not_found.app_error", + "translation": "The specified property does not exist." + }, + { + "id": "app.property.sync_lock.app_error", + "translation": "This property field is managed by external sync and cannot be modified directly." + }, { "id": "app.property_field.count_for_group.app_error", "translation": "Unable to count property fields for group." @@ -8124,6 +8040,14 @@ "id": "app.property_field.create.app_error", "translation": "Unable to create property field." }, + { + "id": "app.property_field.create.group_limit_reached.app_error", + "translation": "The maximum number of property fields for this group has been reached." + }, + { + "id": "app.property_field.create.limit_reached.app_error", + "translation": "The maximum number of property fields for this object type has been reached." + }, { "id": "app.property_field.create.linked_source_cross_group.app_error", "translation": "Cannot link to a field in a different group." @@ -8197,13 +8121,21 @@ "translation": "Unable to get property fields." }, { - "id": "app.property_field.get_many.fields_not_found.app_error", - "translation": "One or more property field IDs were not found in the specified group." + "id": "app.property_field.invalid_attrs.app_error", + "translation": "Invalid property field attributes." }, { "id": "app.property_field.invalid_input.app_error", "translation": "Invalid input provided." }, + { + "id": "app.property_field.managed_admin.permission.app_error", + "translation": "You do not have permission to mark this property field as admin-managed." + }, + { + "id": "app.property_field.not_found.app_error", + "translation": "The specified property field does not exist." + }, { "id": "app.property_field.search.app_error", "translation": "Unable to search property fields." @@ -8316,10 +8248,34 @@ "id": "app.property_value.upsert.app_error", "translation": "Unable to upsert property values." }, + { + "id": "app.property_value.upsert.duplicate_field_id.app_error", + "translation": "Duplicate field ID in property value batch." + }, + { + "id": "app.property_value.upsert.field_not_found.app_error", + "translation": "Property field {{.FieldID}} was not found." + }, + { + "id": "app.property_value.upsert.invalid_field_id.app_error", + "translation": "Invalid property field ID." + }, + { + "id": "app.property_value.upsert.mixed_groups.app_error", + "translation": "All property values in a batch must belong to the same property group." + }, + { + "id": "app.property_value.upsert.object_type_mismatch.app_error", + "translation": "Property field object type does not match the request." + }, { "id": "app.property_value.upsert_many.app_error", "translation": "Unable to upsert property values." }, + { + "id": "app.property_value.validate.app_error", + "translation": "Property value failed validation." + }, { "id": "app.reaction.bulk_get_for_post_ids.app_error", "translation": "Unable to get reactions for post." diff --git a/server/public/model/custom_profile_attributes.go b/server/public/model/custom_profile_attributes.go index 9db88a2bdad..1f615f00f02 100644 --- a/server/public/model/custom_profile_attributes.go +++ b/server/public/model/custom_profile_attributes.go @@ -14,33 +14,33 @@ import ( "errors" "fmt" "net/http" - "net/url" "regexp" - "strings" - "unicode/utf8" + "sort" ) -const CustomProfileAttributesPropertyGroupName = "custom_profile_attributes" - +// CPA-prefixed aliases for the canonical PropertyField* constants in +// property_field_attrs_validation.go. Aliasing (not redeclaring) keeps CPA +// writes and property-hook reads keyed on the same string at compile time, +// so a rename to one side cannot silently diverge from the other. const ( // Attributes keys - CustomProfileAttributesPropertyAttrsSortOrder = "sort_order" - CustomProfileAttributesPropertyAttrsValueType = "value_type" - CustomProfileAttributesPropertyAttrsVisibility = "visibility" - CustomProfileAttributesPropertyAttrsLDAP = "ldap" - CustomProfileAttributesPropertyAttrsSAML = "saml" - CustomProfileAttributesPropertyAttrsManaged = "managed" - CustomProfileAttributesPropertyAttrsDisplayName = "display_name" + CustomProfileAttributesPropertyAttrsSortOrder = PropertyFieldAttrSortOrder + CustomProfileAttributesPropertyAttrsValueType = PropertyFieldAttrValueType + CustomProfileAttributesPropertyAttrsVisibility = PropertyFieldAttrVisibility + CustomProfileAttributesPropertyAttrsLDAP = PropertyFieldAttrLDAP + CustomProfileAttributesPropertyAttrsSAML = PropertyFieldAttrSAML + CustomProfileAttributesPropertyAttrsManaged = PropertyFieldAttrManaged + CustomProfileAttributesPropertyAttrsDisplayName = PropertyFieldAttrDisplayName // Value Types - CustomProfileAttributesValueTypeEmail = "email" - CustomProfileAttributesValueTypeURL = "url" - CustomProfileAttributesValueTypePhone = "phone" + CustomProfileAttributesValueTypeEmail = PropertyFieldValueTypeEmail + CustomProfileAttributesValueTypeURL = PropertyFieldValueTypeURL + CustomProfileAttributesValueTypePhone = PropertyFieldValueTypePhone // Visibility - CustomProfileAttributesVisibilityHidden = "hidden" - CustomProfileAttributesVisibilityWhenSet = "when_set" - CustomProfileAttributesVisibilityAlways = "always" + CustomProfileAttributesVisibilityHidden = PropertyFieldVisibilityHidden + CustomProfileAttributesVisibilityWhenSet = PropertyFieldVisibilityWhenSet + CustomProfileAttributesVisibilityAlways = PropertyFieldVisibilityAlways CustomProfileAttributesVisibilityDefault = CustomProfileAttributesVisibilityWhenSet // CPA options @@ -48,31 +48,9 @@ const ( CPAOptionColorMaxLength = 128 // CPA value constraints - CPAValueTypeTextMaxLength = 64 + CPAValueTypeTextMaxLength = PropertyFieldValueTypeTextMaxLength ) -func IsKnownCPAValueType(valueType string) bool { - switch valueType { - case CustomProfileAttributesValueTypeEmail, - CustomProfileAttributesValueTypeURL, - CustomProfileAttributesValueTypePhone: - return true - } - - return false -} - -func IsKnownCPAVisibility(visibility string) bool { - switch visibility { - case CustomProfileAttributesVisibilityHidden, - CustomProfileAttributesVisibilityWhenSet, - CustomProfileAttributesVisibilityAlways: - return true - } - - return false -} - // CPAFieldNamePattern defines the character set allowed for CPA field names. // Matches the CEL IDENTIFIER grammar (^[A-Za-z_][A-Za-z0-9_]*$) used by the // ABAC engine (cel-go v0.27.0). Leading underscore is permitted — this is consistent @@ -200,13 +178,6 @@ func (c *CPAField) IsAdminManaged() bool { return c.Attrs.Managed == "admin" } -// SetDefaults sets default values for CPAField attributes -func (c *CPAField) SetDefaults() { - if c.Attrs.Visibility == "" { - c.Attrs.Visibility = CustomProfileAttributesVisibilityDefault - } -} - // Patch applies a PropertyFieldPatch to the CPAField by converting to PropertyField, // applying the patch, and converting back. This ensures we only maintain one patch logic path. // Custom profile attributes doesn't use targets, so TargetID and TargetType are cleared. @@ -253,101 +224,6 @@ func (c *CPAField) ToPropertyField() *PropertyField { return &pf } -// SupportsOptions checks the CPAField type and determines if the type -// supports the use of options -func (c *CPAField) SupportsOptions() bool { - return c.Type == PropertyFieldTypeSelect || c.Type == PropertyFieldTypeMultiselect -} - -// SupportsSyncing checks the CPAField type and determines if it -// supports syncing with external sources of truth -func (c *CPAField) SupportsSyncing() bool { - return c.Type == PropertyFieldTypeText -} - -func (c *CPAField) SanitizeAndValidate() *AppError { - c.SetDefaults() - - // first we clean unused attributes depending on the field type - if !c.SupportsOptions() { - c.Attrs.Options = nil - } - if !c.SupportsSyncing() { - c.Attrs.LDAP = "" - c.Attrs.SAML = "" - } - - // Clear sync properties if managed is set (mutual exclusivity) - if c.IsAdminManaged() { - c.Attrs.LDAP = "" - c.Attrs.SAML = "" - } - - switch c.Type { - case PropertyFieldTypeText: - if valueType := strings.TrimSpace(c.Attrs.ValueType); valueType != "" { - if !IsKnownCPAValueType(valueType) { - return NewAppError("SanitizeAndValidate", "app.custom_profile_attributes.sanitize_and_validate.app_error", map[string]any{ - "AttributeName": CustomProfileAttributesPropertyAttrsValueType, - "Reason": "unknown value type", - }, "", http.StatusUnprocessableEntity) - } - c.Attrs.ValueType = valueType - } - - case PropertyFieldTypeSelect, PropertyFieldTypeMultiselect: - options := c.Attrs.Options - - // add an ID to options with no ID - for i := range options { - if options[i].ID == "" { - options[i].ID = NewId() - } - } - - if err := options.IsValid(); err != nil { - return NewAppError("SanitizeAndValidate", "app.custom_profile_attributes.sanitize_and_validate.app_error", map[string]any{ - "AttributeName": PropertyFieldAttributeOptions, - "Reason": err.Error(), - }, "", http.StatusUnprocessableEntity).Wrap(err) - } - c.Attrs.Options = options - } - - // Validate visibility - if visibilityAttr := strings.TrimSpace(c.Attrs.Visibility); visibilityAttr != "" { - if !IsKnownCPAVisibility(visibilityAttr) { - return NewAppError("SanitizeAndValidate", "app.custom_profile_attributes.sanitize_and_validate.app_error", map[string]any{ - "AttributeName": CustomProfileAttributesPropertyAttrsVisibility, - "Reason": "unknown visibility", - }, "", http.StatusUnprocessableEntity) - } - c.Attrs.Visibility = visibilityAttr - } - - // Validate managed field - if managed := strings.TrimSpace(c.Attrs.Managed); managed != "" { - if managed != "admin" { - return NewAppError("SanitizeAndValidate", "app.custom_profile_attributes.sanitize_and_validate.app_error", map[string]any{ - "AttributeName": CustomProfileAttributesPropertyAttrsManaged, - "Reason": "unknown managed type", - }, "", http.StatusBadRequest) - } - c.Attrs.Managed = managed - } - - // Sanitize and validate display_name - // Reuses PropertyFieldNameMaxRunes to keep the DisplayName cap aligned with the Name cap; do NOT introduce a separate constant. - c.Attrs.DisplayName = strings.TrimSpace(c.Attrs.DisplayName) - if utf8.RuneCountInString(c.Attrs.DisplayName) > PropertyFieldNameMaxRunes { - return NewAppError("SanitizeAndValidate", "app.custom_profile_attributes.sanitize_and_validate.display_name_too_long.app_error", map[string]any{ - "MaxRunes": PropertyFieldNameMaxRunes, - }, "", http.StatusUnprocessableEntity) - } - - return nil -} - func NewCPAFieldFromPropertyField(pf *PropertyField) (*CPAField, error) { attrsJSON, err := json.Marshal(pf.Attrs) if err != nil { @@ -365,83 +241,27 @@ func NewCPAFieldFromPropertyField(pf *PropertyField) (*CPAField, error) { Attrs: attrs, } - cpaField.SetDefaults() - return cpaField, nil } -// SanitizeAndValidatePropertyValue validates and sanitizes the given -// property value based on the field type -func SanitizeAndValidatePropertyValue(cpaField *CPAField, rawValue json.RawMessage) (json.RawMessage, error) { - fieldType := cpaField.Type - - // build a list of existing options so we can check later if the values exist - optionsMap := map[string]struct{}{} - for _, v := range cpaField.Attrs.Options { - optionsMap[v.ID] = struct{}{} - } - - switch fieldType { - case PropertyFieldTypeText, PropertyFieldTypeDate, PropertyFieldTypeSelect, PropertyFieldTypeUser: - var value string - if err := json.Unmarshal(rawValue, &value); err != nil { +// CPAFieldsFromPropertyFields converts a slice of PropertyFields to CPAFields +// and sorts the result by Attrs.SortOrder ascending. +func CPAFieldsFromPropertyFields(pfs []*PropertyField) ([]*CPAField, error) { + cpaFields := make([]*CPAField, 0, len(pfs)) + for _, pf := range pfs { + cpaField, err := NewCPAFieldFromPropertyField(pf) + if err != nil { return nil, err } - value = strings.TrimSpace(value) - - if fieldType == PropertyFieldTypeText { - if len(value) > CPAValueTypeTextMaxLength { - return nil, fmt.Errorf("value too long") - } - - if cpaField.Attrs.ValueType == CustomProfileAttributesValueTypeEmail && !IsValidEmail(value) { - return nil, fmt.Errorf("invalid email") - } - - if cpaField.Attrs.ValueType == CustomProfileAttributesValueTypeURL { - _, err := url.Parse(value) - if err != nil { - return nil, fmt.Errorf("invalid url: %w", err) - } - } - } - - if fieldType == PropertyFieldTypeSelect && value != "" { - if _, ok := optionsMap[value]; !ok { - return nil, fmt.Errorf("option \"%s\" does not exist", value) - } - } - - if fieldType == PropertyFieldTypeUser && value != "" && !IsValidId(value) { - return nil, fmt.Errorf("invalid user id") - } - return json.Marshal(value) - - case PropertyFieldTypeMultiselect, PropertyFieldTypeMultiuser: - var values []string - if err := json.Unmarshal(rawValue, &values); err != nil { - return nil, err - } - filteredValues := make([]string, 0, len(values)) - for _, v := range values { - trimmed := strings.TrimSpace(v) - if trimmed == "" { - continue - } - if fieldType == PropertyFieldTypeMultiselect { - if _, ok := optionsMap[v]; !ok { - return nil, fmt.Errorf("option \"%s\" does not exist", v) - } - } - - if fieldType == PropertyFieldTypeMultiuser && !IsValidId(trimmed) { - return nil, fmt.Errorf("invalid user id: %s", trimmed) - } - filteredValues = append(filteredValues, trimmed) - } - return json.Marshal(filteredValues) - - default: - return nil, fmt.Errorf("unknown field type: %s", fieldType) + cpaFields = append(cpaFields, cpaField) } + + sort.Slice(cpaFields, func(i, j int) bool { + if cpaFields[i].Attrs.SortOrder != cpaFields[j].Attrs.SortOrder { + return cpaFields[i].Attrs.SortOrder < cpaFields[j].Attrs.SortOrder + } + return cpaFields[i].ID < cpaFields[j].ID + }) + + return cpaFields, nil } diff --git a/server/public/model/custom_profile_attributes_test.go b/server/public/model/custom_profile_attributes_test.go index 4015c08d28f..85d651fd62e 100644 --- a/server/public/model/custom_profile_attributes_test.go +++ b/server/public/model/custom_profile_attributes_test.go @@ -4,7 +4,6 @@ package model import ( - "encoding/json" "fmt" "strings" "testing" @@ -24,7 +23,7 @@ func TestNewCPAFieldFromPropertyField(t *testing.T) { name: "valid property field with all attributes", propertyField: &PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Test Field", Type: PropertyFieldTypeSelect, Attrs: StringInterface{ @@ -60,7 +59,7 @@ func TestNewCPAFieldFromPropertyField(t *testing.T) { name: "valid property field with minimal attributes", propertyField: &PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Test Field", Type: PropertyFieldTypeText, Attrs: StringInterface{ @@ -79,22 +78,20 @@ func TestNewCPAFieldFromPropertyField(t *testing.T) { wantErr: false, }, { - name: "property field with empty attributes returns default values", + // Conversion is a pure data operation: empty PropertyField.Attrs + // produces empty CPAAttrs. The visibility default is applied at + // write time by AccessControlAttributeValidationHook, not at read time. + name: "property field with empty attributes returns empty CPAAttrs", propertyField: &PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Empty Field", Type: PropertyFieldTypeText, CreateAt: GetMillis(), UpdateAt: GetMillis(), }, - wantAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityWhenSet, // Defaults are applied during conversion - SortOrder: 0, - ValueType: "", - Options: nil, - }, - wantErr: false, + wantAttrs: CPAAttrs{}, + wantErr: false, }, } @@ -146,7 +143,7 @@ func TestCPAFieldToPropertyField(t *testing.T) { cpaField: &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Test Field", Type: PropertyFieldTypeSelect, CreateAt: GetMillis(), @@ -171,7 +168,7 @@ func TestCPAFieldToPropertyField(t *testing.T) { cpaField: &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Test Field", Type: PropertyFieldTypeText, CreateAt: GetMillis(), @@ -188,7 +185,7 @@ func TestCPAFieldToPropertyField(t *testing.T) { cpaField: &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Empty Field", Type: PropertyFieldTypeText, CreateAt: GetMillis(), @@ -238,7 +235,7 @@ func TestCPAFieldToPropertyField(t *testing.T) { cpaField: &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Managed Field", Type: PropertyFieldTypeText, CreateAt: GetMillis(), @@ -256,7 +253,7 @@ func TestCPAFieldToPropertyField(t *testing.T) { cpaField: &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "Non-managed Field", Type: PropertyFieldTypeText, CreateAt: GetMillis(), @@ -390,565 +387,8 @@ func TestCustomProfileAttributeSelectOptionIsValid(t *testing.T) { } } -func TestCPAField_SanitizeAndValidate(t *testing.T) { - tests := []struct { - name string - field *CPAField - expectError bool - errorId string - expectedAttrs CPAAttrs - checkOptionsID bool - }{ - { - name: "valid text field with no value type", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: "when_set", - }, - }, - { - name: "valid text field with valid value type and whitespace", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - ValueType: " email ", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: "when_set", - ValueType: CustomProfileAttributesValueTypeEmail, - }, - }, - { - name: "valid text field with visibility and whitespace", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Visibility: " hidden ", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityHidden, - }, - }, - { - name: "invalid text field with invalid value type", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - ValueType: "invalid_type", - }, - }, - expectError: true, - errorId: "app.custom_profile_attributes.sanitize_and_validate.app_error", - }, - { - name: "valid select field with valid options", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeSelect, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - Name: "Option 1", - Color: "#123456", - }, - { - Name: "Option 2", - Color: "#654321", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {Name: "Option 1", Color: "#123456"}, - {Name: "Option 2", Color: "#654321"}, - }, - }, - }, - { - name: "valid select field with valid options with ids", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeSelect, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - ID: "t9ceh651eir4zkhyh4m54s5r7w", - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {ID: "t9ceh651eir4zkhyh4m54s5r7w", Name: "Option 1", Color: "#123456"}, - }, - }, - checkOptionsID: true, - }, - { - name: "invalid select field with duplicate option names", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeSelect, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - Name: "Option 1", - Color: "opt1", - }, - { - Name: "Option 1", - Color: "opt2", - }, - }, - }, - }, - expectError: true, - errorId: "app.custom_profile_attributes.sanitize_and_validate.app_error", - }, - { - name: "invalid field with unknown visibility", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Visibility: "unknown", - }, - }, - expectError: true, - errorId: "app.custom_profile_attributes.sanitize_and_validate.app_error", - }, - - // Test options cleaning for types that don't support options - { - name: "text field with options should clean options", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - ID: NewId(), - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: nil, // Options should be cleaned - }, - }, - { - name: "date field with options should clean options", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeDate, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - ID: NewId(), - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: nil, // Options should be cleaned - }, - }, - { - name: "user field with options should clean options", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeUser, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - ID: NewId(), - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: nil, // Options should be cleaned - }, - }, - - // Test options preservation for types that support options - { - name: "select field with options should preserve options", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeSelect, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - ID: NewId(), - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {Name: "Option 1", Color: "#123456"}, - }, - }, - }, - { - name: "multiselect field with options should preserve options", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeMultiselect, - }, - Attrs: CPAAttrs{ - Options: []*CustomProfileAttributesSelectOption{ - { - ID: NewId(), - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {Name: "Option 1", Color: "#123456"}, - }, - }, - }, - - // Test syncing attributes cleaning for types that don't support syncing - { - name: "select field with LDAP and SAML should clean syncing attributes", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeSelect, - }, - Attrs: CPAAttrs{ - LDAP: "ldap_attribute", - SAML: "saml_attribute", - Options: []*CustomProfileAttributesSelectOption{ - { - ID: NewId(), - Name: "Option 1", - Color: "#123456", - }, - }, - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - LDAP: "", // Should be cleaned - SAML: "", // Should be cleaned - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {Name: "Option 1", Color: "#123456"}, - }, - }, - }, - { - name: "date field with LDAP and SAML should clean syncing attributes", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeDate, - }, - Attrs: CPAAttrs{ - LDAP: "ldap_attribute", - SAML: "saml_attribute", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - LDAP: "", // Should be cleaned - SAML: "", // Should be cleaned - }, - }, - - // Test syncing attributes preservation for types that support syncing - { - name: "text field with LDAP and SAML should preserve syncing attributes", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - LDAP: "ldap_attribute", - SAML: "saml_attribute", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - LDAP: "ldap_attribute", // Should be preserved - SAML: "saml_attribute", // Should be preserved - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.field.SanitizeAndValidate() - if tt.expectError { - require.NotNil(t, err) - require.Equal(t, tt.errorId, err.Id) - } else { - var ogErr error - if err != nil { - ogErr = err.Unwrap() - } - require.Nilf(t, err, "unexpected error: %v, with original error: %v", err, ogErr) - - assert.Equal(t, tt.expectedAttrs.Visibility, tt.field.Attrs.Visibility) - assert.Equal(t, tt.expectedAttrs.ValueType, tt.field.Attrs.ValueType) - - for i := range tt.expectedAttrs.Options { - if tt.checkOptionsID { - assert.Equal(t, tt.expectedAttrs.Options[i].ID, tt.field.Attrs.Options[i].ID) - } - assert.Equal(t, tt.expectedAttrs.Options[i].Name, tt.field.Attrs.Options[i].Name) - assert.Equal(t, tt.expectedAttrs.Options[i].Color, tt.field.Attrs.Options[i].Color) - } - } - }) - } - - // Test managed fields functionality - t.Run("managed fields", func(t *testing.T) { - managedTests := []struct { - name string - field *CPAField - expectError bool - errorId string - expectedAttrs CPAAttrs - }{ - { - name: "valid managed field with admin value", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Managed: "admin", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Managed: "admin", - }, - }, - { - name: "managed field with whitespace should be trimmed", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Managed: " admin ", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Managed: "admin", - }, - }, - { - name: "field with empty managed should be allowed", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Managed: "", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Managed: "", - }, - }, - { - name: "field with invalid managed value should fail", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Managed: "invalid", - }, - }, - expectError: true, - errorId: "app.custom_profile_attributes.sanitize_and_validate.app_error", - }, - { - name: "managed field should clear LDAP sync properties", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - Managed: "admin", - LDAP: "ldap_attribute", - SAML: "saml_attribute", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Managed: "admin", - LDAP: "", // Should be cleared - SAML: "", // Should be cleared - }, - }, - { - name: "managed field should clear sync properties even when field supports syncing", - field: &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, // Text fields support syncing - }, - Attrs: CPAAttrs{ - Managed: "admin", - LDAP: "ldap_attribute", - }, - }, - expectError: false, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - Managed: "admin", - LDAP: "", // Should be cleared due to mutual exclusivity - SAML: "", - }, - }, - } - - for _, tt := range managedTests { - t.Run(tt.name, func(t *testing.T) { - err := tt.field.SanitizeAndValidate() - if tt.expectError { - require.NotNil(t, err) - require.Equal(t, tt.errorId, err.Id) - } else { - require.Nil(t, err) - assert.Equal(t, tt.expectedAttrs.Visibility, tt.field.Attrs.Visibility) - assert.Equal(t, tt.expectedAttrs.Managed, tt.field.Attrs.Managed) - assert.Equal(t, tt.expectedAttrs.LDAP, tt.field.Attrs.LDAP) - assert.Equal(t, tt.expectedAttrs.SAML, tt.field.Attrs.SAML) - } - }) - } - }) - - t.Run("display_name sanitization", func(t *testing.T) { - displayNameTests := []struct { - name string - displayName string - expectError bool - errorId string - expectedValue string - }{ - { - name: "empty display_name is allowed", - displayName: "", - expectError: false, - expectedValue: "", - }, - { - name: "display_name with surrounding whitespace is trimmed", - displayName: " Department Head ", - expectError: false, - expectedValue: "Department Head", - }, - { - name: "all-whitespace display_name is trimmed to empty and allowed", - displayName: " ", - expectError: false, - expectedValue: "", - }, - { - name: "display_name at exactly 255 runes is accepted", - displayName: strings.Repeat("a", PropertyFieldNameMaxRunes), - expectError: false, - expectedValue: strings.Repeat("a", PropertyFieldNameMaxRunes), - }, - { - name: "display_name at 256 runes is rejected", - displayName: strings.Repeat("a", PropertyFieldNameMaxRunes+1), - expectError: true, - errorId: "app.custom_profile_attributes.sanitize_and_validate.display_name_too_long.app_error", - }, - } - - for _, tt := range displayNameTests { - t.Run(tt.name, func(t *testing.T) { - field := &CPAField{ - PropertyField: PropertyField{ - Type: PropertyFieldTypeText, - }, - Attrs: CPAAttrs{ - DisplayName: tt.displayName, - }, - } - appErr := field.SanitizeAndValidate() - if tt.expectError { - require.NotNil(t, appErr) - require.Equal(t, tt.errorId, appErr.Id) - } else { - require.Nil(t, appErr) - assert.Equal(t, tt.expectedValue, field.Attrs.DisplayName, - "DisplayName must be trimmed after SanitizeAndValidate") - } - }) - } - }) -} +// TestCPAField_SanitizeAndValidate removed: behavior moved into AccessControlAttributeValidationHook; +// see TestAccessControlAttributeValidationHook in server/channels/app/properties/access_control_attribute_validation_test.go. func TestValidateCPAFieldName(t *testing.T) { tests := []struct { @@ -1016,7 +456,7 @@ func TestCPAField_ToPropertyField_DisplayName(t *testing.T) { original := &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "department", Type: PropertyFieldTypeText, }, @@ -1043,7 +483,7 @@ func TestCPAField_ToPropertyField_DisplayName(t *testing.T) { field := &CPAField{ PropertyField: PropertyField{ ID: NewId(), - GroupID: CustomProfileAttributesPropertyGroupName, + GroupID: AccessControlPropertyGroupName, Name: "department", Type: PropertyFieldTypeText, }, @@ -1061,203 +501,6 @@ func TestCPAField_ToPropertyField_DisplayName(t *testing.T) { }) } -func TestSanitizeAndValidatePropertyValue(t *testing.T) { - t.Run("text field type", func(t *testing.T) { - t.Run("valid text", func(t *testing.T) { - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeText}}, json.RawMessage(`"hello world"`)) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Equal(t, "hello world", value) - }) - - t.Run("empty text should be allowed", func(t *testing.T) { - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeText}}, json.RawMessage(`""`)) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Empty(t, value) - }) - - t.Run("invalid JSON", func(t *testing.T) { - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeText}}, json.RawMessage(`invalid`)) - require.Error(t, err) - }) - - t.Run("wrong type", func(t *testing.T) { - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeText}}, json.RawMessage(`123`)) - require.Error(t, err) - require.Contains(t, err.Error(), "json: cannot unmarshal number into Go value of type string") - }) - - t.Run("value too long", func(t *testing.T) { - longValue := strings.Repeat("a", CPAValueTypeTextMaxLength+1) - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeText}}, json.RawMessage(fmt.Sprintf(`"%s"`, longValue))) - require.Error(t, err) - require.Equal(t, "value too long", err.Error()) - }) - }) - - t.Run("date field type", func(t *testing.T) { - t.Run("valid date", func(t *testing.T) { - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeDate}}, json.RawMessage(`"2023-01-01"`)) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Equal(t, "2023-01-01", value) - }) - - t.Run("empty date should be allowed", func(t *testing.T) { - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeDate}}, json.RawMessage(`""`)) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Empty(t, value) - }) - }) - - t.Run("select field type", func(t *testing.T) { - t.Run("valid option", func(t *testing.T) { - result, err := SanitizeAndValidatePropertyValue(&CPAField{ - PropertyField: PropertyField{Type: PropertyFieldTypeSelect}, - Attrs: CPAAttrs{ - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {ID: "option1"}, - }, - }}, json.RawMessage(`"option1"`)) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Equal(t, "option1", value) - }) - - t.Run("invalid option", func(t *testing.T) { - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeSelect}}, json.RawMessage(`"option1"`)) - require.Error(t, err) - }) - - t.Run("empty option should be allowed", func(t *testing.T) { - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeSelect}}, json.RawMessage(`""`)) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Empty(t, value) - }) - }) - - t.Run("user field type", func(t *testing.T) { - t.Run("valid user ID", func(t *testing.T) { - validID := NewId() - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeUser}}, json.RawMessage(fmt.Sprintf(`"%s"`, validID))) - require.NoError(t, err) - var value string - require.NoError(t, json.Unmarshal(result, &value)) - require.Equal(t, validID, value) - }) - - t.Run("empty user ID should be allowed", func(t *testing.T) { - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeUser}}, json.RawMessage(`""`)) - require.NoError(t, err) - }) - - t.Run("invalid user ID format", func(t *testing.T) { - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeUser}}, json.RawMessage(`"invalid-id"`)) - require.Error(t, err) - require.Equal(t, "invalid user id", err.Error()) - }) - }) - - t.Run("multiselect field type", func(t *testing.T) { - t.Run("valid options", func(t *testing.T) { - option1ID := NewId() - option2ID := NewId() - option3ID := NewId() - result, err := SanitizeAndValidatePropertyValue(&CPAField{ - PropertyField: PropertyField{Type: PropertyFieldTypeMultiselect}, - Attrs: CPAAttrs{ - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {ID: option1ID}, - {ID: option2ID}, - {ID: option3ID}, - }, - }}, json.RawMessage(fmt.Sprintf(`["%s", "%s"]`, option1ID, option2ID))) - require.NoError(t, err) - var values []string - require.NoError(t, json.Unmarshal(result, &values)) - require.Equal(t, []string{option1ID, option2ID}, values) - }) - - t.Run("empty array", func(t *testing.T) { - option1ID := NewId() - option2ID := NewId() - option3ID := NewId() - _, err := SanitizeAndValidatePropertyValue(&CPAField{ - PropertyField: PropertyField{Type: PropertyFieldTypeMultiselect}, - Attrs: CPAAttrs{ - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {ID: option1ID}, - {ID: option2ID}, - {ID: option3ID}, - }, - }}, json.RawMessage(`[]`)) - require.NoError(t, err) - }) - - t.Run("array with empty values should filter them out", func(t *testing.T) { - option1ID := NewId() - option2ID := NewId() - option3ID := NewId() - result, err := SanitizeAndValidatePropertyValue(&CPAField{ - PropertyField: PropertyField{Type: PropertyFieldTypeMultiselect}, - Attrs: CPAAttrs{ - Options: PropertyOptions[*CustomProfileAttributesSelectOption]{ - {ID: option1ID}, - {ID: option2ID}, - {ID: option3ID}, - }, - }}, json.RawMessage(fmt.Sprintf(`["%s", "", "%s", " ", "%s"]`, option1ID, option2ID, option3ID))) - require.NoError(t, err) - var values []string - require.NoError(t, json.Unmarshal(result, &values)) - require.Equal(t, []string{option1ID, option2ID, option3ID}, values) - }) - }) - - t.Run("multiuser field type", func(t *testing.T) { - t.Run("valid user IDs", func(t *testing.T) { - validID1 := NewId() - validID2 := NewId() - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeMultiuser}}, json.RawMessage(fmt.Sprintf(`["%s", "%s"]`, validID1, validID2))) - require.NoError(t, err) - var values []string - require.NoError(t, json.Unmarshal(result, &values)) - require.Equal(t, []string{validID1, validID2}, values) - }) - - t.Run("empty array", func(t *testing.T) { - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeMultiuser}}, json.RawMessage(`[]`)) - require.NoError(t, err) - }) - - t.Run("array with empty strings should be filtered out", func(t *testing.T) { - validID1 := NewId() - validID2 := NewId() - result, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeMultiuser}}, json.RawMessage(fmt.Sprintf(`["%s", "", " ", "%s"]`, validID1, validID2))) - require.NoError(t, err) - var values []string - require.NoError(t, json.Unmarshal(result, &values)) - require.Equal(t, []string{validID1, validID2}, values) - }) - - t.Run("array with invalid ID should return error", func(t *testing.T) { - validID1 := NewId() - _, err := SanitizeAndValidatePropertyValue(&CPAField{PropertyField: PropertyField{Type: PropertyFieldTypeMultiuser}}, json.RawMessage(fmt.Sprintf(`["%s", "invalid-id"]`, validID1))) - require.Error(t, err) - require.Equal(t, "invalid user id: invalid-id", err.Error()) - }) - }) -} - func TestCPAField_IsAdminManaged(t *testing.T) { tests := []struct { name string @@ -1308,71 +551,8 @@ func TestCPAField_IsAdminManaged(t *testing.T) { } } -func TestCPAField_SetDefaults(t *testing.T) { - testCases := []struct { - name string - field *CPAField - expectedAttrs CPAAttrs - }{ - { - name: "field with empty visibility should set default", - field: &CPAField{ - Attrs: CPAAttrs{ - Visibility: "", - SortOrder: 5.0, - }, - }, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - SortOrder: 5.0, - }, - }, - { - name: "field with existing visibility should not change", - field: &CPAField{ - Attrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityAlways, - SortOrder: 10.0, - }, - }, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityAlways, - SortOrder: 10.0, - }, - }, - { - name: "field with zero values should set visibility default, keep sort order zero", - field: &CPAField{ - Attrs: CPAAttrs{}, - }, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityDefault, - SortOrder: 0.0, - }, - }, - { - name: "field with hidden visibility should preserve it", - field: &CPAField{ - Attrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityHidden, - SortOrder: 3.5, - }, - }, - expectedAttrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityHidden, - SortOrder: 3.5, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.field.SetDefaults() - assert.Equal(t, tc.expectedAttrs.Visibility, tc.field.Attrs.Visibility) - assert.Equal(t, tc.expectedAttrs.SortOrder, tc.field.Attrs.SortOrder) - }) - } -} +// TestCPAField_SetDefaults removed: visibility default is now applied by AccessControlAttributeValidationHook +// (see access_control_attribute_validation.go), exercised in TestAccessControlAttributeValidationHook. func TestCPAField_Patch(t *testing.T) { testCases := []struct { @@ -1508,6 +688,10 @@ func TestCPAField_Patch(t *testing.T) { expectError: false, }, { + // Patch with non-nil Attrs replaces the whole Attrs map; visibility + // drops to "" because the patch doesn't include it. The visibility + // default is reapplied at write time by AccessControlAttributeValidationHook, + // not by Patch itself. name: "patch sort order", field: &CPAField{ PropertyField: PropertyField{ @@ -1534,8 +718,7 @@ func TestCPAField_Patch(t *testing.T) { Type: PropertyFieldTypeText, }, Attrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityWhenSet, - SortOrder: 10.5, + SortOrder: 10.5, }, }, expectError: false, @@ -1567,8 +750,7 @@ func TestCPAField_Patch(t *testing.T) { Type: PropertyFieldTypeText, }, Attrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityWhenSet, - Managed: "admin", + Managed: "admin", }, }, expectError: false, @@ -1599,8 +781,7 @@ func TestCPAField_Patch(t *testing.T) { Type: PropertyFieldTypeText, }, Attrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityWhenSet, - LDAP: "ldap_attribute", + LDAP: "ldap_attribute", }, }, expectError: false, @@ -1637,7 +818,6 @@ func TestCPAField_Patch(t *testing.T) { Type: PropertyFieldTypeSelect, }, Attrs: CPAAttrs{ - Visibility: CustomProfileAttributesVisibilityWhenSet, Options: []*CustomProfileAttributesSelectOption{ {ID: "opt1", Name: "Option 1"}, {ID: "opt2", Name: "Option 2"}, @@ -1785,3 +965,87 @@ func TestCPAField_Patch(t *testing.T) { }) } } + +func TestCPAFieldsFromPropertyFields(t *testing.T) { + mkField := func(name string, sortOrder float64) *PropertyField { + return &PropertyField{ + ID: NewId(), + GroupID: AccessControlPropertyGroupName, + Name: name, + Type: PropertyFieldTypeText, + Attrs: StringInterface{ + CustomProfileAttributesPropertyAttrsSortOrder: sortOrder, + }, + } + } + + t.Run("empty slice returns empty slice", func(t *testing.T) { + result, err := CPAFieldsFromPropertyFields(nil) + require.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("sorts by SortOrder ascending", func(t *testing.T) { + input := []*PropertyField{ + mkField("c", 2), + mkField("a", 0), + mkField("b", 1), + } + + result, err := CPAFieldsFromPropertyFields(input) + require.NoError(t, err) + require.Len(t, result, 3) + assert.Equal(t, "a", result[0].Name) + assert.Equal(t, "b", result[1].Name) + assert.Equal(t, "c", result[2].Name) + }) + + t.Run("preserves fields with equal SortOrder in encounter order", func(t *testing.T) { + input := []*PropertyField{ + mkField("first", 0), + mkField("second", 0), + } + + result, err := CPAFieldsFromPropertyFields(input) + require.NoError(t, err) + require.Len(t, result, 2) + // sort.Slice is not stable, but the test asserts both possible stable outcomes + // — we care that both fields are present, not stability. + names := []string{result[0].Name, result[1].Name} + assert.Contains(t, names, "first") + assert.Contains(t, names, "second") + }) + + t.Run("propagates conversion errors", func(t *testing.T) { + // options stored as an invalid JSON-marshallable type so that + // json.Marshal fails inside NewCPAFieldFromPropertyField + input := []*PropertyField{{ + ID: NewId(), + GroupID: AccessControlPropertyGroupName, + Name: "bad", + Type: PropertyFieldTypeText, + Attrs: StringInterface{ + PropertyFieldAttributeOptions: make(chan int), + }, + }} + + result, err := CPAFieldsFromPropertyFields(input) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("preserves empty visibility from PropertyField (defaults are applied at write time by AccessControlAttributeValidationHook, not at read time)", func(t *testing.T) { + input := []*PropertyField{{ + ID: NewId(), + GroupID: AccessControlPropertyGroupName, + Name: "no_visibility", + Type: PropertyFieldTypeText, + Attrs: StringInterface{}, + }} + + result, err := CPAFieldsFromPropertyFields(input) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Empty(t, result[0].Attrs.Visibility) + }) +} diff --git a/server/public/model/property_access_control.go b/server/public/model/property_access_control.go index 21bdea6d868..1bab63c96d4 100644 --- a/server/public/model/property_access_control.go +++ b/server/public/model/property_access_control.go @@ -11,6 +11,25 @@ type AccessControlContextKey string // AccessControlCallerIDContextKey is the context key for access control caller ID. const AccessControlCallerIDContextKey AccessControlContextKey = "access_control_caller_id" +// Well-known caller IDs for internal services that need to write property +// values on synced fields. These are set on the request context by the +// respective sync services so that the access control hook can identify them. +// +// The "system:" prefix contains a colon, which is not a valid character in a +// plugin ID (see IsValidPluginId). That guarantees these values cannot be +// forged by a plugin whose manifest ID is used as its caller ID. +// +// CallerIDLocalAdmin marks a request as originating from a local-mode +// (unrestricted) session, which has an empty Session.UserId but full admin +// privileges. HTTP handlers tag the rctx with this caller ID when +// Session().IsUnrestricted() is true, so the attribute validation hook's +// permission checker can grant admin privileges without a user lookup. +const ( + CallerIDLDAPSync = "system:ldap_sync" + CallerIDSAMLSync = "system:saml_sync" + CallerIDLocalAdmin = "system:local_admin" +) + // WithCallerID adds the caller ID to a context.Context for access control purposes. func WithCallerID(ctx context.Context, callerID string) context.Context { return context.WithValue(ctx, AccessControlCallerIDContextKey, callerID) diff --git a/server/public/model/property_field.go b/server/public/model/property_field.go index 25027c8e16c..73c64e445f1 100644 --- a/server/public/model/property_field.go +++ b/server/public/model/property_field.go @@ -404,12 +404,8 @@ func (pf *PropertyField) Patch(patch *PropertyFieldPatch, mergeAttrs bool) { // Legacy properties have an empty ObjectType and rely on simple TargetID uniqueness // enforced by the idx_propertyfields_unique_legacy database constraint, rather than // the hierarchical uniqueness model used by PSAv2 (ObjectType-based) properties. -// -// FIXME: treating template fields as PSAv1 is a temporary measure until the -// CPA feature fully transitions to v2. Once that happens, remove the -// PropertyFieldObjectTypeTemplate check. func (pf *PropertyField) IsPSAv1() bool { - return pf.ObjectType == "" || pf.ObjectType == PropertyFieldObjectTypeTemplate + return pf.ObjectType == "" } // IsPSAv2 returns true if this property field uses the PSAv2 schema. diff --git a/server/public/model/property_field_attrs_validation.go b/server/public/model/property_field_attrs_validation.go new file mode 100644 index 00000000000..2e2924455c7 --- /dev/null +++ b/server/public/model/property_field_attrs_validation.go @@ -0,0 +1,192 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "encoding/json" + "fmt" + "net/url" + "strings" +) + +// Attribute keys used across property groups. These are the canonical keys +// stored in PropertyField.Attrs and referenced by hooks. +const ( + PropertyFieldAttrVisibility = "visibility" + PropertyFieldAttrSortOrder = "sort_order" + PropertyFieldAttrValueType = "value_type" + PropertyFieldAttrLDAP = "ldap" + PropertyFieldAttrSAML = "saml" + PropertyFieldAttrManaged = "managed" + PropertyFieldAttrDisplayName = "display_name" +) + +// Valid visibility values for property fields. +const ( + PropertyFieldVisibilityHidden = "hidden" + PropertyFieldVisibilityWhenSet = "when_set" + PropertyFieldVisibilityAlways = "always" +) + +// Valid value types for text property fields. +const ( + PropertyFieldValueTypeEmail = "email" + PropertyFieldValueTypeURL = "url" + PropertyFieldValueTypePhone = "phone" +) + +// PropertyFieldValueTypeTextMaxLength is the maximum character length for text field values. +const PropertyFieldValueTypeTextMaxLength = 64 + +// IsValidPropertyFieldVisibility reports whether the given string is a known visibility value. +func IsValidPropertyFieldVisibility(v string) bool { + switch v { + case PropertyFieldVisibilityHidden, + PropertyFieldVisibilityWhenSet, + PropertyFieldVisibilityAlways: + return true + default: + return false + } +} + +// IsValidPropertyFieldValueType reports whether the given string is a known value type. +func IsValidPropertyFieldValueType(v string) bool { + switch v { + case PropertyFieldValueTypeEmail, + PropertyFieldValueTypeURL, + PropertyFieldValueTypePhone: + return true + default: + return false + } +} + +// ValidatePropertyFieldVisibility checks that the visibility attr on a +// PropertyField is either empty or one of hidden/when_set/always. +func ValidatePropertyFieldVisibility(field *PropertyField) error { + if field.Attrs == nil { + return nil + } + + raw, ok := field.Attrs[PropertyFieldAttrVisibility] + if !ok { + return nil + } + + v, ok := raw.(string) + if !ok { + return fmt.Errorf("visibility must be a string") + } + + v = strings.TrimSpace(v) + if v == "" { + return nil + } + + if !IsValidPropertyFieldVisibility(v) { + return fmt.Errorf("invalid visibility %q: must be one of hidden, when_set, always", v) + } + + return nil +} + +// ValidatePropertyFieldSortOrder checks that the sort_order attr on a +// PropertyField is numeric (float64 or json.Number) or absent. +func ValidatePropertyFieldSortOrder(field *PropertyField) error { + if field.Attrs == nil { + return nil + } + + raw, ok := field.Attrs[PropertyFieldAttrSortOrder] + if !ok { + return nil + } + + switch raw.(type) { + case float64, json.Number, int, int64: + return nil + default: + return fmt.Errorf("sort_order must be numeric, got %T", raw) + } +} + +// ValidatePropertyValueForValueType validates a raw JSON value against the +// given value type constraint. This is called for text fields that have a +// value_type attr (email, url, phone). +func ValidatePropertyValueForValueType(valueType string, value json.RawMessage) error { + if valueType == "" { + return nil + } + + var str string + if err := json.Unmarshal(value, &str); err != nil { + return fmt.Errorf("expected string value for value_type %q: %w", valueType, err) + } + + str = strings.TrimSpace(str) + if str == "" { + return nil + } + + switch valueType { + case PropertyFieldValueTypeEmail: + if !IsValidEmail(str) { + return fmt.Errorf("invalid email: %q", str) + } + case PropertyFieldValueTypeURL: + // ParseRequestURI rejects relative references (url.Parse accepts them), + // and we additionally require a non-empty Host so bare schemes like + // "http:" or "file:///..." without an authority are rejected. + u, err := url.ParseRequestURI(str) + if err != nil { + return fmt.Errorf("invalid url: %w", err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("invalid url: %q", str) + } + case PropertyFieldValueTypePhone: + // Phone values are accepted as-is; no structural validation. + default: + return fmt.Errorf("unknown value_type %q", valueType) + } + + return nil +} + +// GetPropertyFieldValueType extracts the value_type string from a +// PropertyField's attrs. Returns empty string if not set. +func GetPropertyFieldValueType(field *PropertyField) string { + if field.Attrs == nil { + return "" + } + v, _ := field.Attrs[PropertyFieldAttrValueType].(string) + return strings.TrimSpace(v) +} + +// IsPropertyFieldSynced reports whether the field has an ldap or saml attr set, +// meaning its values are managed by an external sync service. +func IsPropertyFieldSynced(field *PropertyField) bool { + if field.Attrs == nil { + return false + } + ldap, _ := field.Attrs[PropertyFieldAttrLDAP].(string) + saml, _ := field.Attrs[PropertyFieldAttrSAML].(string) + return ldap != "" || saml != "" +} + +// GetPropertyFieldSyncSource returns the sync source for a field: "ldap", +// "saml", or empty string if not synced. If both are set, ldap takes priority. +func GetPropertyFieldSyncSource(field *PropertyField) string { + if field.Attrs == nil { + return "" + } + if ldap, _ := field.Attrs[PropertyFieldAttrLDAP].(string); ldap != "" { + return "ldap" + } + if saml, _ := field.Attrs[PropertyFieldAttrSAML].(string); saml != "" { + return "saml" + } + return "" +} diff --git a/server/public/model/property_field_attrs_validation_test.go b/server/public/model/property_field_attrs_validation_test.go new file mode 100644 index 00000000000..a90f4902fb7 --- /dev/null +++ b/server/public/model/property_field_attrs_validation_test.go @@ -0,0 +1,157 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidatePropertyFieldVisibility(t *testing.T) { + tests := []struct { + name string + attrs StringInterface + wantErr bool + }{ + {name: "nil attrs", attrs: nil}, + {name: "no visibility key", attrs: StringInterface{"other": "val"}}, + {name: "empty string", attrs: StringInterface{PropertyFieldAttrVisibility: ""}}, + {name: "hidden", attrs: StringInterface{PropertyFieldAttrVisibility: "hidden"}}, + {name: "when_set", attrs: StringInterface{PropertyFieldAttrVisibility: "when_set"}}, + {name: "always", attrs: StringInterface{PropertyFieldAttrVisibility: "always"}}, + {name: "invalid", attrs: StringInterface{PropertyFieldAttrVisibility: "public"}, wantErr: true}, + {name: "non-string type", attrs: StringInterface{PropertyFieldAttrVisibility: 42}, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + field := &PropertyField{Attrs: tt.attrs} + err := ValidatePropertyFieldVisibility(field) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidatePropertyFieldSortOrder(t *testing.T) { + tests := []struct { + name string + attrs StringInterface + wantErr bool + }{ + {name: "nil attrs", attrs: nil}, + {name: "no sort_order key", attrs: StringInterface{"other": "val"}}, + {name: "float64", attrs: StringInterface{PropertyFieldAttrSortOrder: float64(1.5)}}, + {name: "int", attrs: StringInterface{PropertyFieldAttrSortOrder: 1}}, + {name: "int64", attrs: StringInterface{PropertyFieldAttrSortOrder: int64(42)}}, + {name: "json.Number", attrs: StringInterface{PropertyFieldAttrSortOrder: json.Number("3.14")}}, + {name: "string", attrs: StringInterface{PropertyFieldAttrSortOrder: "not_a_number"}, wantErr: true}, + {name: "bool", attrs: StringInterface{PropertyFieldAttrSortOrder: true}, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + field := &PropertyField{Attrs: tt.attrs} + err := ValidatePropertyFieldSortOrder(field) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidatePropertyValueForValueType(t *testing.T) { + tests := []struct { + name string + valueType string + value string + wantErr bool + }{ + {name: "empty value type", valueType: "", value: `"anything"`}, + {name: "valid email", valueType: "email", value: `"test@example.com"`}, + {name: "invalid email", valueType: "email", value: `"not-an-email"`, wantErr: true}, + {name: "empty email string", valueType: "email", value: `""`}, + {name: "valid url", valueType: "url", value: `"https://example.com"`}, + {name: "valid url with path", valueType: "url", value: `"https://example.com/path?q=1"`}, + {name: "invalid url - plain string", valueType: "url", value: `"not a url"`, wantErr: true}, + {name: "invalid url - relative path", valueType: "url", value: `"/relative/path"`, wantErr: true}, + {name: "invalid url - missing host", valueType: "url", value: `"http://"`, wantErr: true}, + {name: "invalid url - missing scheme", valueType: "url", value: `"example.com"`, wantErr: true}, + {name: "phone (any string)", valueType: "phone", value: `"+1-555-0123"`}, + {name: "unknown value type", valueType: "fax", value: `"test"`, wantErr: true}, + {name: "non-string json", valueType: "email", value: `42`, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidatePropertyValueForValueType(tt.valueType, json.RawMessage(tt.value)) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestIsPropertyFieldSynced(t *testing.T) { + assert.False(t, IsPropertyFieldSynced(&PropertyField{})) + assert.False(t, IsPropertyFieldSynced(&PropertyField{Attrs: StringInterface{}})) + assert.True(t, IsPropertyFieldSynced(&PropertyField{Attrs: StringInterface{PropertyFieldAttrLDAP: "attr"}})) + assert.True(t, IsPropertyFieldSynced(&PropertyField{Attrs: StringInterface{PropertyFieldAttrSAML: "attr"}})) + assert.True(t, IsPropertyFieldSynced(&PropertyField{Attrs: StringInterface{PropertyFieldAttrLDAP: "a", PropertyFieldAttrSAML: "b"}})) +} + +func TestGetPropertyFieldSyncSource(t *testing.T) { + assert.Equal(t, "", GetPropertyFieldSyncSource(&PropertyField{})) + assert.Equal(t, "ldap", GetPropertyFieldSyncSource(&PropertyField{Attrs: StringInterface{PropertyFieldAttrLDAP: "attr"}})) + assert.Equal(t, "saml", GetPropertyFieldSyncSource(&PropertyField{Attrs: StringInterface{PropertyFieldAttrSAML: "attr"}})) + // ldap takes priority + assert.Equal(t, "ldap", GetPropertyFieldSyncSource(&PropertyField{Attrs: StringInterface{PropertyFieldAttrLDAP: "a", PropertyFieldAttrSAML: "b"}})) +} + +func TestIsValidPropertyFieldVisibility(t *testing.T) { + assert.True(t, IsValidPropertyFieldVisibility("hidden")) + assert.True(t, IsValidPropertyFieldVisibility("when_set")) + assert.True(t, IsValidPropertyFieldVisibility("always")) + assert.False(t, IsValidPropertyFieldVisibility("")) + assert.False(t, IsValidPropertyFieldVisibility("public")) +} + +func TestIsValidPropertyFieldValueType(t *testing.T) { + assert.True(t, IsValidPropertyFieldValueType("email")) + assert.True(t, IsValidPropertyFieldValueType("url")) + assert.True(t, IsValidPropertyFieldValueType("phone")) + assert.False(t, IsValidPropertyFieldValueType("")) + assert.False(t, IsValidPropertyFieldValueType("fax")) +} + +func TestGetPropertyFieldValueType(t *testing.T) { + assert.Equal(t, "", GetPropertyFieldValueType(&PropertyField{})) + assert.Equal(t, "", GetPropertyFieldValueType(&PropertyField{Attrs: StringInterface{}})) + assert.Equal(t, "email", GetPropertyFieldValueType(&PropertyField{Attrs: StringInterface{PropertyFieldAttrValueType: "email"}})) + assert.Equal(t, "email", GetPropertyFieldValueType(&PropertyField{Attrs: StringInterface{PropertyFieldAttrValueType: " email "}})) +} + +func TestCallerIDConstants(t *testing.T) { + require.NotEmpty(t, CallerIDLDAPSync) + require.NotEmpty(t, CallerIDSAMLSync) + require.NotEqual(t, CallerIDLDAPSync, CallerIDSAMLSync) + + // The sync caller IDs must not be valid plugin IDs, otherwise an + // admin-installed plugin could set its manifest ID to one of these + // values and bypass the sync-lock check for LDAP/SAML-managed fields. + require.False(t, IsValidPluginId(CallerIDLDAPSync), + "CallerIDLDAPSync must not be a valid plugin ID") + require.False(t, IsValidPluginId(CallerIDSAMLSync), + "CallerIDSAMLSync must not be a valid plugin ID") +} diff --git a/server/public/model/property_group.go b/server/public/model/property_group.go index 0d6644902b7..9ed5cf0ed9e 100644 --- a/server/public/model/property_group.go +++ b/server/public/model/property_group.go @@ -8,6 +8,23 @@ import ( "regexp" ) +const AccessControlPropertyGroupName = "access_control" + +// DeprecatedCPAPropertyGroupName is the old group name for custom profile attributes. +// It was renamed to "access_control". The plugin API still accepts this name +// for backward compatibility, but plugin authors should migrate to +// AccessControlPropertyGroupName. +const DeprecatedCPAPropertyGroupName = "custom_profile_attributes" + +// AccessControlGroupFieldLimit is the global cap on the number of +// property fields that can exist in the access_control group across +// all object types. Call sites read all fields/values in a single page +// (PerPage = AccessControlGroupFieldLimit + 5) instead of paginating, +// on the assumption that the result set is bounded by this limit. If the +// limit is ever raised significantly or removed, every call site that uses +// AccessControlGroupFieldLimit + 5 must be converted to paginate. +const AccessControlGroupFieldLimit = 200 + var validPropertyGroupNameRegex = regexp.MustCompile(`^[a-z0-9][a-z0-9_]*$`) const ( diff --git a/server/public/model/property_value.go b/server/public/model/property_value.go index 43d50db4170..665e23ef483 100644 --- a/server/public/model/property_value.go +++ b/server/public/model/property_value.go @@ -6,6 +6,7 @@ package model import ( "encoding/json" "net/http" + "strings" "unicode/utf8" "github.com/pkg/errors" @@ -151,3 +152,60 @@ type PropertyValuePatchItem struct { FieldID string `json:"field_id"` Value json.RawMessage `json:"value"` } + +// SanitizePropertyValue normalizes a raw property value's JSON: +// - a top-level JSON string has surrounding whitespace trimmed; +// - a top-level JSON array of strings has each element trimmed and empty +// entries dropped; +// - any other shape (numbers, booleans, objects, nested arrays) passes +// through unchanged. +// +// Returns the original bytes when no change is needed so callers can +// compare by identity if they want to skip writes. +func SanitizePropertyValue(raw json.RawMessage) json.RawMessage { + if len(raw) == 0 { + return raw + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + trimmed := strings.TrimSpace(s) + if trimmed == s { + return raw + } + out, err := json.Marshal(trimmed) + if err != nil { + return raw + } + return out + } + + var arr []string + if err := json.Unmarshal(raw, &arr); err == nil { + filtered := make([]string, 0, len(arr)) + changed := false + for _, v := range arr { + t := strings.TrimSpace(v) + if t != v { + changed = true + } + if t == "" { + if v != "" { + changed = true + } + continue + } + filtered = append(filtered, t) + } + if !changed && len(filtered) == len(arr) { + return raw + } + out, err := json.Marshal(filtered) + if err != nil { + return raw + } + return out + } + + return raw +} diff --git a/server/public/model/property_value_test.go b/server/public/model/property_value_test.go index f0bacdb13b0..382fefb907f 100644 --- a/server/public/model/property_value_test.go +++ b/server/public/model/property_value_test.go @@ -252,3 +252,38 @@ func TestPropertyValueSearchCursor_IsValid(t *testing.T) { assert.Error(t, cursor.IsValid()) }) } + +func TestSanitizePropertyValue(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"empty bytes", "", ""}, + {"string trimmed", `" hello "`, `"hello"`}, + {"string unchanged", `"hello"`, `"hello"`}, + {"string all whitespace", `" "`, `""`}, + {"string already empty", `""`, `""`}, + {"string array trimmed and filtered", `[" a ", "", " ", "b"]`, `["a","b"]`}, + {"string array unchanged", `["a","b"]`, `["a","b"]`}, + {"string array all empty", `["", " ", ""]`, `[]`}, + {"number passthrough", `42`, `42`}, + {"boolean passthrough", `true`, `true`}, + {"null passthrough", `null`, `null`}, + {"object passthrough", `{"key":" val "}`, `{"key":" val "}`}, + {"nested array passthrough", `[["a","b"]]`, `[["a","b"]]`}, + {"mixed array passthrough", `["a",1]`, `["a",1]`}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := SanitizePropertyValue(json.RawMessage(tc.in)) + assert.Equal(t, tc.want, string(got)) + }) + } + + t.Run("returns identity when no change", func(t *testing.T) { + raw := json.RawMessage(`"hello"`) + got := SanitizePropertyValue(raw) + assert.Equal(t, &raw[0], &got[0], "expected same backing array when unchanged") + }) +} diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx index 8176ddd649b..c448a7e2325 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.test.tsx @@ -177,6 +177,38 @@ describe('UserPropertyDotMenu', () => { expect(screen.getByText('Edit SAML link')).toBeInTheDocument(); }); + it('clears admin-managed by setting managed to empty string, not by removing the key', async () => { + const adminManagedField: UserPropertyField = { + ...baseField, + id: 'admin-managed-field', + attrs: { + ...baseField.attrs, + managed: 'admin', + }, + }; + + renderComponent(adminManagedField); + + const menuButton = screen.getByTestId(`user-property-field_dotmenu-${adminManagedField.id}`); + await userEvent.click(menuButton); + + const editableToggle = screen.getByRole('menuitemcheckbox', {name: /Editable by users/}); + await userEvent.click(editableToggle); + + // The server PATCH uses merge semantics: omitted keys are preserved. Toggling off + // admin-managed must send managed: '' explicitly; deleting the key would silently + // leave the field admin-managed on the server. + expect(updateField).toHaveBeenCalledWith({ + ...adminManagedField, + attrs: { + sort_order: 0, + visibility: 'when_set', + value_type: '', + managed: '', + }, + }); + }); + it('handles field duplication', async () => { renderComponent(); diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx index 9e4c99b5419..162e2d54544 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_dot_menu.tsx @@ -146,7 +146,10 @@ const DotMenu = ({ const newAttrs = {...field.attrs}; if (field.attrs.managed === 'admin') { - Reflect.deleteProperty(newAttrs, 'managed'); + // Server PATCH merges attrs and preserves keys absent from the body, so we + // assign '' rather than deleting the key — otherwise managed='admin' would + // silently persist on the server. + newAttrs.managed = ''; } else { newAttrs.managed = 'admin'; } diff --git a/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts b/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts index ba0590238ef..dd37a7f3094 100644 --- a/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts +++ b/webapp/channels/src/components/admin_console/system_properties/user_properties_utils.ts @@ -86,15 +86,7 @@ export const useUserPropertyFields = () => { // update await Promise.all(process.edit.map(async (pendingItem) => { const {id, name, type, attrs} = pendingItem; - let patch = {name, type, attrs}; - - // clear options if not select/multiselect - if (type !== 'select' && type !== 'multiselect') { - const attrs = {...patch.attrs}; - Reflect.deleteProperty(attrs, 'options'); - - patch = {...patch, attrs}; - } + const patch = {name, type, attrs}; return Client4.patchCustomProfileAttributeField(id, patch). then((nextItem) => { From f604ec7a5ca540dd7a94a99781276fde34427ad8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Thu, 14 May 2026 18:59:18 +0200 Subject: [PATCH 05/80] MM-68662: Add Azure Blob Storage filestore backend (#36498) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Generalize file backend error types Replace S3FileBackendAuthError and S3FileBackendNoBucketError with backend-agnostic FileBackendAuthError and FileBackendNoBucketError so non-S3 drivers can return them and the admin "Test Connection" flow keeps surfacing useful messages. The old S3-prefixed names are kept as type aliases of the generic types so external code (plugins, historical consumers) continues to compile, and so existing S3 construction sites stay untouched. The type switch in connectionTestErrorToAppError now matches the generic types, with new i18n keys (test_connection_auth.app_error and test_connection_no_bucket.app_error) whose wording does not name S3. The old S3-specific i18n keys are dropped via `make i18n-extract` since they are no longer referenced from code; the api4 test that asserted on those keys is updated, and the Cypress `MM-T996 Amazon S3 connection error messaging` spec that asserted on the old user-facing string is updated to the new wording. ------ AI assisted commit * Pull in Azure SDK and uuid dependencies Bring in github.com/Azure/azure-sdk-for-go/sdk/azcore and .../sdk/storage/azblob (with .../sdk/internal as their indirect dependency). The two are needed by the upcoming Azure Blob Storage filestore backend and its lazy-Range-backed reader. The bump of golang.org/x/{crypto,net,sys,term,text} comes transitively from azblob's minimum versions. Also promotes github.com/google/uuid from indirect to direct, since the Azure backend uses it to generate block IDs that share the same wire format the SDK itself produces in UploadStream. ------ AI assisted commit * Add azureRangeReader, a seekable Range-backed blob reader A small standalone type that satisfies the FileBackend interface's ReadCloseSeeker + the broader io.ReaderAt contract on top of Azure Blob Storage HTTP Range requests. Lands as its own commit because the upcoming Azure FileBackend driver builds on it, and the reader itself is independently useful — and independently testable against a fake downloader without standing up an Azure client. Design notes: * Read opens an HTTP Range stream lazily at the current offset and reuses it for sequential reads. Seek to a different offset closes the open stream; the next Read re-opens it. * Seek to the same offset is a no-op and does not close the open stream, so callers like zip.NewReader that probe with redundant seeks don't kick off a fresh download. * ReadAt issues a dedicated ranged DownloadStream per call and does not touch the streaming cursor — matches the io.ReaderAt contract the bulk-import worker's zip.NewReader path relies on. * Close cancels the context (which any in-flight Azure call will observe and abort), stops the deadline timer, and closes the current body if any. It is safe to call when no body was ever opened. * CancelTimeout lets long-running consumers like the import worker opt out of the per-operation deadline that would otherwise kill multi-minute downloads partway through. The implementation talks to a small blobDownloader interface rather than *blob.Client directly so the unit tests can substitute a fake downloader that records every requested Range and tracks Close calls on the bodies it hands out. ------ AI assisted commit * Add Azure Blob Storage filestore driver Implements the FileBackend interface against Azure Blob Storage in a new azurestore.go (~520 LOC). The driver is not yet selectable via NewFileBackend's switch — that wiring lands in the next commit together with the admin config surface — but the driver itself is complete and self-contained behind the FileBackendSettings struct. Filesstore.go grows three pieces of supporting infrastructure that the driver consumes: * a `driverAzure = "azureblob"` constant alongside the existing driverS3 and driverLocal, * an Azure-specific block on FileBackendSettings (storage account, access key, container, path prefix, endpoint, SSL flag, request timeout), * a CheckMandatoryAzureFields validator that mirrors CheckMandatoryS3Fields. Behavioural notes that warrant calling out: * Reader returns the previously-added azureRangeReader, so reads stream lazily over HTTP Range and ReadAt is available for the bulk-import worker's zip.NewReader path. The deadline timer is armed before the initial GetProperties call so the HEAD itself is bounded. * WriteFile and AppendFile both go through StageBlock + CommitBlockList via a shared stageBlocks helper, never the SDK's UploadStream. UploadStream's small-payload fast path falls back to single-shot PutBlob, which leaves the resulting blob with no committed block list; a subsequent AppendFile that calls CommitBlockList on that blob would then clobber its content. Routing every write through the block-list mechanism keeps AppendFile correct regardless of payload size. * AppendFile stages the new chunk as one or more blocks and commits the existing committed block list plus the newly staged IDs. The new bytes go up exactly once — no re-download, no re-concatenate, no re-upload of the prior contents. * WriteFileContext does not wrap the caller-supplied context with its own timeout — that timeout is applied in WriteFile only, matching the S3 driver, so long-running TryWriteFileContext callers (like message-export bulk writes) opt out of the per-operation timeout the way the abstraction documents. Authentication is shared-key only for this drop; Microsoft Entra ID / managed identity is deferred to a follow-up. The endpoint is configurable so the same code targets the production Azure host (vhost style — {account}.blob.core.windows.net) or Azurite / Azure Government / sovereign clouds (path style — host[:port]/{account}). ------ AI assisted commit * Wire Azure backend into config, validation, and driver selection This commit registers the previously-added AzureFileBackend driver with the rest of the system. Until now the driver was usable only via direct construction; after this commit, `DriverName: "azureblob"` in config.json is a fully-supported deployment configuration. Five integration sites are touched: * `newFileBackend` in filesstore.go now dispatches `driverAzure` to NewAzureFileBackend, alongside the existing s3 and local cases. NewFileBackendSettingsFromConfig (and its export counterpart) gain an Azure branch that maps the model.FileSettings fields onto the Azure-specific FileBackendSettings fields. * `model.FileSettings` grows the user-facing Azure config schema: storage account, access key, container, path prefix, endpoint, SSL flag, request timeout, plus matching Export* fields for the dedicated export store. SetDefaults populates them so deployments that never opted into Azure don't carry nil pointers. `isValid` accepts the new ImageDriverAzure constant. * `Config.Sanitize()` masks AzureAccessKey and ExportAzureAccessKey the same way it masks AmazonS3SecretAccessKey, so the shared key never reaches an API consumer in plain text. * `desanitize()` restores the masked keys on a config write so a PATCH that doesn't touch the key doesn't clobber it with the FakeSetting placeholder. * `configSensitivePaths` covers both Azure key paths so audit diffs don't include them either. * `ConfigToFileBackendSettings` in the `mattermost db` CLI helper gets the Azure branch its production counterpart already has — without it, `mattermost db migrate` / `db downgrade` would fail on Azure-configured deployments with "missing azure storage account setting". Finally, the shared FileBackendTestSuite is now wired against Azurite via TestAzureFileBackendTestSuite, which skips when CI_AZURITE_HOST is unreachable. The test-infra wiring (the docker service, the env vars, the start_dependencies entry) landed in a previous PR; this commit is what makes the suite actually exercise the Azure driver end to end. ------ AI assisted commit * Validate Azure timeout and path prefix in Config.IsValid Parity with the S3-side checks that already cover AmazonS3RequestTimeoutMilliseconds and AmazonS3PathPrefix. Without these, a zero/negative AzureRequestTimeoutMilliseconds passes validation and later creates immediately-expired request contexts, and leading/trailing whitespace in AzurePathPrefix produces blob keys that don't match what the admin configured. Same checks added for the Export* counterparts. The file_driver.app_error translation is updated to mention the new 'azureblob' option alongside 'local' and 'amazons3'. ------ AI assisted commit * Stream zip entries from the Azure backend writeZipEntry was calling ReadFile, which loads the entire blob into memory before writing it to the archive. For large blobs or deep directories this spikes RSS or OOMs the goroutine. Switch to Reader (the streaming azureRangeReader) and io.Copy into the zip entry so memory stays bounded regardless of blob size. ------ AI assisted commit * Use a backend-agnostic fallback for FileBackendNoBucketError The fallback Error() message was "no such bucket", which leaks S3 terminology when an Azure caller returns the type with no wrapped Err. Use "no such bucket or container" so logs and external error handling stay neutral across backends. ------ AI assisted commit * Defend Azure path prefix against directory traversal Reject ".." in AzurePathPrefix and ExportAzurePathPrefix at config validation time, since path.Join collapses traversal segments and a prefix like "../other-tenant" would otherwise escape the configured isolation boundary. Harden the prefix helper as a second line of defense: if the joined path no longer sits inside pathPrefix, fall back to joining the prefix with the base name of the caller-supplied path. That preserves the prefix invariant for plugin and import paths that the upload code does not sanitize uniformly. ------ AI assisted commit * Honor SkipVerify when constructing the Azure client FileBackendSettings.SkipVerify is plumbed through from the System Console the same way it is for S3, so admins toggling the flag for self-signed endpoints (Azurite, sovereign clouds) get the behavior they expect without having to drop SSL entirely and send the shared key in clear text. ------ AI assisted commit * Warn when the Azure request timeout falls back to its default Config.IsValid already rejects non-positive AzureRequestTimeoutMilliseconds for any path that goes through config validation, so this warn only fires for direct callers that bypass validation (tests, helpers). Logging the substitution turns a silent coercion into something an operator can correlate against unexpected request behavior. ------ AI assisted commit * Cap Azure request timeout at 10 minutes Reject AzureRequestTimeoutMilliseconds values above the ceiling so an operator (or someone who has admin access) cannot effectively disable timeouts by setting the value to math.MaxInt64. A hung Azure call then holds a goroutine open until the OS gives up. Applies the same bound to ExportAzureRequestTimeoutMilliseconds. S3 has the same gap; treating it is out of scope here but worth a follow-up. ------ AI assisted commit * Refuse AppendFile on blobs without a committed block list A blob written by another tool (Azure portal, azcopy, a migration script, a plugin using Put Blob) has its content in the blob but an empty committed-block list. Committing a new block list against such a blob silently replaces the existing content with only the appended bytes. Check the blob's properties before staging when the committed-block list is empty, and refuse with a clear error if the blob has content. Same hazard for an admin pointing the backend at an existing container with pre-existing files. Adds an integration test against Azurite to lock the behavior in. ------ AI assisted commit * Surface truncated reads from azureRangeReader Read closed the body cleanly and returned io.EOF even when the remote stream terminated before the blob's content length. Callers (and any retry layer above) then accepted a partial blob as complete. ReadAt unconditionally rewrote io.ErrUnexpectedEOF to io.EOF, which made truncated downloads indistinguishable from clean reads. That is exactly what zip.NewReader consumes for archive readers, so the bulk-import worker would silently import partial archives. Read now closes the body, nils it, and returns io.ErrUnexpectedEOF when EOF arrives before offset reaches size. ReadAt only collapses ErrUnexpectedEOF to EOF when the full count was delivered and the stream was consumed to the end of the blob. Otherwise the truncation propagates with context. Both code paths are exercised by new fakeDownloader-backed tests. ------ AI assisted commit * Move container provisioning out of Azure TestConnection Auto-creating the container inside TestConnection meant a typo in the System Console (mattermosst instead of mattermost) silently provisioned an unwanted container in the admin's Azure subscription, with no audit log and no warning. They'd discover it later when uploads landed somewhere unexpected. TestConnection now returns FileBackendNoBucketError when the container is missing, mirroring the S3 contract. A new MakeContainer method mirrors S3FileBackend.MakeBucket, and Server.Start dispatches via two capability interfaces (bucketMaker / containerMaker) instead of a hard S3 type assertion — so the NoBucket error is no longer silently swallowed for backends Server.Start has not been taught about. ------ AI assisted commit * Carry file backend auth detail through to AppError The Test Connection button collapsed every typed backend failure into the same generic i18n message. Operators trying to debug bad credentials or a missing bucket only saw "Unable to authenticate against the file storage backend" with no SDK code to grep for in their logs. Use errors.As so the typed checks survive future wrapping, and pass the underlying error string through the NewAppError details argument. The AppError serializer surfaces that detail to the admin console alongside the translated message, so a bad S3 InvalidAccessKeyId or an Azure AuthenticationFailed shows up in the toast without an i18n schema change. ------ AI assisted commit * Remove non-ascii characters from comments ------ AI assisted commit * Make linter happy ------ AI assisted commit * Harden Azure prefix boundary check strings.HasPrefix on the joined path is a string-level check, not a path-level one, so a configured prefix of "mattermost" accepts a joined result of "mattermost-evil/...". A crafted caller path like "../mattermost-evil/secrets" would collapse via path.Join to that exact sibling and slip through the boundary check, escaping the configured prefix scope. Require the joined path to be the cleaned prefix itself or to start with the prefix followed by a path separator. The fallback path.Join uses the same cleaned prefix for consistency. ------ AI assisted commit * Provision Azurite container in standalone test setup The shared FileBackendTestSuite's SetupTest already handles a missing container by detecting FileBackendNoBucketError from TestConnection and calling MakeContainer, but TestAzureFileBackendAppendRefusesNonBlockBlob bypasses SetupTest and calls TestConnection directly. On a fresh Azurite instance the test would fail before exercising the append-refusal logic. Extract a newAzuriteBackend(t) helper alongside azuriteSettings(t) that builds the backend and ensures the container exists, mirroring the suite's setup. Use errors.As for forward compatibility with future wrapping. ------ AI assisted commit * Fix grammar in email-settings i18n string "Email settings has unset values." -> "Email settings have unset values." ------ AI assisted commit * Make Azure MakeContainer idempotent Treat a ContainerAlreadyExists response as success so that two nodes racing through TestConnection plus MakeContainer at boot both converge instead of having the loser fail. Mirrors how the S3 backend handles the equivalent BucketAlreadyOwnedByYou case. ------ AI assisted commit * Narrow AzureEndpoint comment to path-style only The setting only builds path-style URLs, so it cannot reach sovereign clouds like Azure Government or Azure China, which require vhost-style endpoints. Update the comment to reflect what the code actually does and document that sovereign-cloud support is out of scope. ------ AI assisted commit --- .../system_console/environment_spec.js | 2 +- server/channels/api4/system_test.go | 6 +- server/channels/app/file.go | 22 +- server/channels/app/server.go | 22 +- server/cmd/mattermost/commands/db.go | 13 + server/config/diff.go | 2 + server/config/utils.go | 6 + server/go.mod | 15 +- server/go.sum | 34 +- server/i18n/en.json | 28 +- .../platform/shared/filestore/azurestore.go | 638 ++++++++++++++++++ .../filestore/azurestore_rangereader.go | 160 +++++ .../filestore/azurestore_rangereader_test.go | 361 ++++++++++ .../shared/filestore/azurestore_test.go | 137 ++++ server/platform/shared/filestore/errors.go | 44 ++ .../platform/shared/filestore/filesstore.go | 53 ++ .../shared/filestore/filesstore_test.go | 18 +- server/platform/shared/filestore/s3store.go | 21 +- server/public/model/config.go | 111 ++- server/public/model/config_test.go | 53 ++ 20 files changed, 1688 insertions(+), 58 deletions(-) create mode 100644 server/platform/shared/filestore/azurestore.go create mode 100644 server/platform/shared/filestore/azurestore_rangereader.go create mode 100644 server/platform/shared/filestore/azurestore_rangereader_test.go create mode 100644 server/platform/shared/filestore/azurestore_test.go create mode 100644 server/platform/shared/filestore/errors.go diff --git a/e2e-tests/cypress/tests/integration/channels/system_console/environment_spec.js b/e2e-tests/cypress/tests/integration/channels/system_console/environment_spec.js index 96e5ac73300..19eff95cda5 100644 --- a/e2e-tests/cypress/tests/integration/channels/system_console/environment_spec.js +++ b/e2e-tests/cypress/tests/integration/channels/system_console/environment_spec.js @@ -256,7 +256,7 @@ describe('Environment', () => { cy.get('#TestS3Connection').scrollIntoView().should('be.visible').within(() => { cy.findByText('Test Connection').should('be.visible').click().wait(TIMEOUTS.ONE_SEC); - waitForAlert('Connection unsuccessful: Unable to connect to S3. Verify your Amazon S3 connection authorization parameters and authentication settings.'); + waitForAlert('Connection unsuccessful: Unable to authenticate against the file storage backend. Verify your credentials and authentication settings.'); }); }); diff --git a/server/channels/api4/system_test.go b/server/channels/api4/system_test.go index cd8c58c9285..f0aef8bd91b 100644 --- a/server/channels/api4/system_test.go +++ b/server/channels/api4/system_test.go @@ -678,12 +678,12 @@ func TestS3TestConnection(t *testing.T) { config.FileSettings.AmazonS3Bucket = new("Wrong_bucket") resp, err = th.SystemAdminClient.TestS3Connection(context.Background(), &config) CheckInternalErrorStatus(t, resp) - CheckErrorID(t, err, "api.file.test_connection_s3_bucket_does_not_exist.app_error") + CheckErrorID(t, err, "api.file.test_connection_no_bucket.app_error") *config.FileSettings.AmazonS3Bucket = "shouldnotcreatenewbucket" resp, err = th.SystemAdminClient.TestS3Connection(context.Background(), &config) CheckInternalErrorStatus(t, resp) - CheckErrorID(t, err, "api.file.test_connection_s3_bucket_does_not_exist.app_error") + CheckErrorID(t, err, "api.file.test_connection_no_bucket.app_error") }) t.Run("with incorrect credentials", func(t *testing.T) { @@ -691,7 +691,7 @@ func TestS3TestConnection(t *testing.T) { *configCopy.FileSettings.AmazonS3AccessKeyId = "invalidaccesskey" resp, err := th.SystemAdminClient.TestS3Connection(context.Background(), &configCopy) CheckInternalErrorStatus(t, resp) - CheckErrorID(t, err, "api.file.test_connection_s3_auth.app_error") + CheckErrorID(t, err, "api.file.test_connection_auth.app_error") }) t.Run("empty file settings", func(t *testing.T) { diff --git a/server/channels/app/file.go b/server/channels/app/file.go index b6ab0a8e97c..f1d3a91e87d 100644 --- a/server/channels/app/file.go +++ b/server/channels/app/file.go @@ -75,14 +75,22 @@ func (a *App) CheckMandatoryS3Fields(settings *model.FileSettings) *model.AppErr } func connectionTestErrorToAppError(connTestErr error) *model.AppError { - switch err := connTestErr.(type) { - case *filestore.S3FileBackendAuthError: - return model.NewAppError("TestConnection", "api.file.test_connection_s3_auth.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - case *filestore.S3FileBackendNoBucketError: - return model.NewAppError("TestConnection", "api.file.test_connection_s3_bucket_does_not_exist.app_error", nil, "", http.StatusInternalServerError).Wrap(err) - default: - return model.NewAppError("TestConnection", "api.file.test_connection.app_error", nil, "", http.StatusInternalServerError).Wrap(connTestErr) + // errors.As (rather than a type switch) so that future wrapping of + // the backend's typed errors does not silently fall through to the + // generic "test_connection" message. + var authErr *filestore.FileBackendAuthError + if errors.As(connTestErr, &authErr) { + // Carry the underlying SDK detail (S3 InvalidAccessKeyId, + // Azure AuthenticationFailed, clock-skew, etc.) into the + // AppError's detail string so the Test Connection toast + // shows admins what actually failed. + return model.NewAppError("TestConnection", "api.file.test_connection_auth.app_error", nil, authErr.Error(), http.StatusInternalServerError).Wrap(authErr) } + var noBucketErr *filestore.FileBackendNoBucketError + if errors.As(connTestErr, &noBucketErr) { + return model.NewAppError("TestConnection", "api.file.test_connection_no_bucket.app_error", nil, noBucketErr.Error(), http.StatusInternalServerError).Wrap(noBucketErr) + } + return model.NewAppError("TestConnection", "api.file.test_connection.app_error", nil, connTestErr.Error(), http.StatusInternalServerError).Wrap(connTestErr) } func (a *App) TestFileStoreConnection() *model.AppError { diff --git a/server/channels/app/server.go b/server/channels/app/server.go index 06528291fc1..518fe6930a1 100644 --- a/server/channels/app/server.go +++ b/server/channels/app/server.go @@ -948,8 +948,26 @@ func (s *Server) Start() error { err := s.FileBackend().TestConnection() if err != nil { - if _, ok := err.(*filestore.S3FileBackendNoBucketError); ok { - err = s.FileBackend().(*filestore.S3FileBackend).MakeBucket() + var noBucket *filestore.FileBackendNoBucketError + if errors.As(err, &noBucket) { + // Each backend exposes its own provisioning entry point, so + // dispatch by capability rather than concrete type. New + // backends opt in by implementing this interface; backends + // that do not are reported with the original error so the + // missing-bucket condition surfaces in logs instead of being + // silently swallowed. + type bucketMaker interface { + MakeBucket() error + } + type containerMaker interface { + MakeContainer() error + } + switch b := s.FileBackend().(type) { + case bucketMaker: + err = b.MakeBucket() + case containerMaker: + err = b.MakeContainer() + } } if err != nil { mlog.Error("Problem with file storage settings", mlog.Err(err)) diff --git a/server/cmd/mattermost/commands/db.go b/server/cmd/mattermost/commands/db.go index 17c96c0c45a..f7eb04d8ba2 100644 --- a/server/cmd/mattermost/commands/db.go +++ b/server/cmd/mattermost/commands/db.go @@ -329,6 +329,19 @@ func ConfigToFileBackendSettings(s *model.FileSettings, enableComplianceFeature Directory: *s.Directory, } } + if *s.DriverName == model.ImageDriverAzure { + return filestore.FileBackendSettings{ + DriverName: *s.DriverName, + AzureStorageAccount: *s.AzureStorageAccount, + AzureAccessKey: *s.AzureAccessKey, + AzureContainer: *s.AzureContainer, + AzurePathPrefix: *s.AzurePathPrefix, + AzureEndpoint: *s.AzureEndpoint, + AzureSSL: s.AzureSSL == nil || *s.AzureSSL, + AzureRequestTimeoutMilliseconds: *s.AzureRequestTimeoutMilliseconds, + SkipVerify: skipVerify, + } + } return filestore.FileBackendSettings{ DriverName: *s.DriverName, AmazonS3AccessKeyId: *s.AmazonS3AccessKeyId, diff --git a/server/config/diff.go b/server/config/diff.go index 8e140eb0b22..ea1b8ed9e9d 100644 --- a/server/config/diff.go +++ b/server/config/diff.go @@ -40,6 +40,8 @@ var configSensitivePaths = map[string]bool{ "LdapSettings.BindPassword": true, "FileSettings.PublicLinkSalt": true, "FileSettings.AmazonS3SecretAccessKey": true, + "FileSettings.AzureAccessKey": true, + "FileSettings.ExportAzureAccessKey": true, "SqlSettings.DataSource": true, "SqlSettings.AtRestEncryptKey": true, "SqlSettings.DataSourceReplicas": true, diff --git a/server/config/utils.go b/server/config/utils.go index 1577522564b..63f486fa68a 100644 --- a/server/config/utils.go +++ b/server/config/utils.go @@ -33,6 +33,12 @@ func desanitize(actual, target *model.Config) { if *target.FileSettings.AmazonS3SecretAccessKey == model.FakeSetting { target.FileSettings.AmazonS3SecretAccessKey = actual.FileSettings.AmazonS3SecretAccessKey } + if target.FileSettings.AzureAccessKey != nil && *target.FileSettings.AzureAccessKey == model.FakeSetting { + target.FileSettings.AzureAccessKey = actual.FileSettings.AzureAccessKey + } + if target.FileSettings.ExportAzureAccessKey != nil && *target.FileSettings.ExportAzureAccessKey == model.FakeSetting { + target.FileSettings.ExportAzureAccessKey = actual.FileSettings.ExportAzureAccessKey + } if *target.EmailSettings.SMTPPassword == model.FakeSetting { target.EmailSettings.SMTPPassword = actual.EmailSettings.SMTPPassword diff --git a/server/go.mod b/server/go.mod index 90e70cd2f3d..c9a998b1530 100644 --- a/server/go.mod +++ b/server/go.mod @@ -4,6 +4,8 @@ go 1.26.2 require ( code.sajari.com/docconv/v2 v2.0.0-pre.4 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 github.com/Masterminds/semver/v3 v3.4.0 github.com/avct/uasurfer v0.0.0-20250915105040-a942f6fb6edc github.com/aws/aws-sdk-go-v2 v1.41.5 @@ -25,6 +27,7 @@ require ( github.com/golang-migrate/migrate/v4 v4.19.1 github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 github.com/golang/mock v1.6.0 + github.com/google/uuid v1.6.0 github.com/gorilla/handlers v1.5.2 github.com/gorilla/mux v1.8.1 github.com/gorilla/schema v1.4.1 @@ -71,18 +74,19 @@ require ( github.com/wiggin77/merror v1.0.5 github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c github.com/yuin/goldmark v1.8.2 - golang.org/x/crypto v0.49.0 + golang.org/x/crypto v0.50.0 golang.org/x/image v0.38.0 - golang.org/x/net v0.52.0 + golang.org/x/net v0.53.0 golang.org/x/sync v0.20.0 - golang.org/x/sys v0.42.0 - golang.org/x/term v0.41.0 - golang.org/x/text v0.35.0 + golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 + golang.org/x/text v0.36.0 gopkg.in/mail.v2 v2.3.1 ) require ( filippo.io/edwards25519 v1.2.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect github.com/JalfResi/justext v0.0.0-20221106200834-be571e3e3052 // indirect github.com/PuerkitoBio/goquery v1.12.0 // indirect github.com/STARRY-S/zip v0.2.3 // indirect @@ -135,7 +139,6 @@ require ( github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/google/btree v1.1.3 // indirect github.com/google/jsonschema-go v0.4.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-hclog v1.6.3 // indirect diff --git a/server/go.sum b/server/go.sum index e7c99b8a9e3..117aa7c1369 100644 --- a/server/go.sum +++ b/server/go.sum @@ -12,6 +12,18 @@ filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4 filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 h1:jWQK1GI+LeGGUKBADtcH2rRqPxYB1Ljwms5gFA2LqrM= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4/go.mod h1:8mwH4klAm9DUgR2EEHyEEAQlRDvLPyg5fQry3y+cDew= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/JalfResi/justext v0.0.0-20221106200834-be571e3e3052 h1:8T2zMbhLBbH9514PIQVHdsGhypMrsB4CxwbldKA9sBA= @@ -470,6 +482,8 @@ github.com/pierrec/lz4/v4 v4.1.26 h1:GrpZw1gZttORinvzBdXPUXATeqlJjqUG/D87TKMnhjY github.com/pierrec/lz4/v4 v4.1.26/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -689,8 +703,8 @@ golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= -golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA= golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ= @@ -733,8 +747,8 @@ golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -792,8 +806,8 @@ golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -804,8 +818,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= -golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -817,8 +831,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= diff --git a/server/i18n/en.json b/server/i18n/en.json index e3cf0f0c27c..09f24cf195a 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -2372,17 +2372,17 @@ "id": "api.file.test_connection.app_error", "translation": "Unable to access the file storage." }, + { + "id": "api.file.test_connection_auth.app_error", + "translation": "Unable to authenticate against the file storage backend. Verify your credentials and authentication settings." + }, { "id": "api.file.test_connection_email_settings_nil.app_error", - "translation": "Email settings has unset values." + "translation": "Email settings have unset values." }, { - "id": "api.file.test_connection_s3_auth.app_error", - "translation": "Unable to connect to S3. Verify your Amazon S3 connection authorization parameters and authentication settings." - }, - { - "id": "api.file.test_connection_s3_bucket_does_not_exist.app_error", - "translation": "Ensure your Amazon S3 bucket is available, and verify your bucket permissions." + "id": "api.file.test_connection_no_bucket.app_error", + "translation": "The configured bucket or container does not exist. Verify your file storage configuration and permissions." }, { "id": "api.file.test_connection_s3_settings_nil.app_error", @@ -11042,6 +11042,10 @@ "id": "model.config.is_valid.autotranslation.workers.app_error", "translation": "Workers must be between 1 and 64." }, + { + "id": "model.config.is_valid.azure_timeout.app_error", + "translation": "Invalid timeout value {{.Value}}. Should be a positive number." + }, { "id": "model.config.is_valid.cache_type.app_error", "translation": "Cache type must be either lru or redis." @@ -11118,6 +11122,10 @@ "id": "model.config.is_valid.directory.app_error", "translation": "Invalid Local Storage Directory. Must be a non-empty string." }, + { + "id": "model.config.is_valid.directory_traversal.app_error", + "translation": "Path traversal sequences (\"..\") are not allowed in {{.Setting}}. Found \"{{.Value}}\"." + }, { "id": "model.config.is_valid.directory_whitespace.app_error", "translation": "Leading or trailing whitespace detected for {{.Setting}}. Found \"{{.Value}}\"." @@ -11222,9 +11230,13 @@ "id": "model.config.is_valid.export.retention_days_too_low.app_error", "translation": "Invalid value for RetentionDays. Value should be greater than 0" }, + { + "id": "model.config.is_valid.export_azure_timeout.app_error", + "translation": "Invalid timeout value {{.Value}}. Should be a positive number." + }, { "id": "model.config.is_valid.file_driver.app_error", - "translation": "Invalid driver name for file settings. Must be 'local' or 'amazons3'." + "translation": "Invalid driver name for file settings. Must be 'local', 'amazons3', or 'azureblob'." }, { "id": "model.config.is_valid.file_salt.app_error", diff --git a/server/platform/shared/filestore/azurestore.go b/server/platform/shared/filestore/azurestore.go new file mode 100644 index 00000000000..043de68624c --- /dev/null +++ b/server/platform/shared/filestore/azurestore.go @@ -0,0 +1,638 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package filestore + +import ( + "archive/zip" + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "net/http" + "path" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + "github.com/google/uuid" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + pkgerr "github.com/pkg/errors" +) + +// azureBlockSize is the chunk size used when staging block blob uploads. +// Matches the Azure SDK's default block size for UploadStream and keeps each +// StageBlock call well under the per-block REST limit (4000 MiB). +const azureBlockSize = 4 * 1024 * 1024 + +// AzureFileBackend stores files in Azure Blob Storage. Connections are +// authenticated with a shared key today; Microsoft Entra ID is a follow-up. +type AzureFileBackend struct { + client *azblob.Client + container string + pathPrefix string + timeout time.Duration +} + +func NewAzureFileBackend(settings FileBackendSettings) (*AzureFileBackend, error) { + if err := settings.CheckMandatoryAzureFields(); err != nil { + return nil, err + } + + credential, err := azblob.NewSharedKeyCredential(settings.AzureStorageAccount, settings.AzureAccessKey) + if err != nil { + return nil, pkgerr.Wrap(err, "failed to create azure shared key credential") + } + + scheme := "https" + if !settings.AzureSSL { + scheme = "http" + } + + var serviceURL string + if settings.AzureEndpoint == "" { + // vhost-style production endpoint (Azure commercial cloud). + serviceURL = fmt.Sprintf("%s://%s.blob.core.windows.net/", scheme, settings.AzureStorageAccount) + } else { + // Path-style endpoint where the account is part of the URL path + // rather than the hostname. This covers Azurite and custom hosts + // (reverse proxies, gateways) that expose Azure Blob Storage + // without per-account DNS. Sovereign clouds (Azure Government, + // Azure China) use vhost-style URLs and are not supported via + // this setting; they require their own endpoint plumbing. + serviceURL = fmt.Sprintf("%s://%s/%s/", scheme, strings.Trim(settings.AzureEndpoint, "/"), settings.AzureStorageAccount) + } + + var clientOptions *azblob.ClientOptions + if settings.SkipVerify { + // Mirror the S3 backend: when the admin opts into skipping TLS + // verification, plumb a custom transport into the SDK so the toggle + // actually takes effect for Azure too. + clientOptions = &azblob.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + Transport: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, + }, + } + } + + client, err := azblob.NewClientWithSharedKeyCredential(serviceURL, credential, clientOptions) + if err != nil { + return nil, pkgerr.Wrap(err, "failed to create azure blob client") + } + + // Config.IsValid rejects non-positive timeouts before they reach this + // constructor, but direct callers (tests, library users that build a + // FileBackendSettings by hand) can still slip a zero or negative value + // in. Fall back to a sane default in that case, and log loudly enough + // for the substitution to show up if it ever happens in production. + timeout := time.Duration(settings.AzureRequestTimeoutMilliseconds) * time.Millisecond + if timeout <= 0 { + mlog.Warn("AzureRequestTimeoutMilliseconds is non-positive; falling back to 30s default", + mlog.Int("value", int(settings.AzureRequestTimeoutMilliseconds))) + timeout = 30 * time.Second + } + + return &AzureFileBackend{ + client: client, + container: settings.AzureContainer, + pathPrefix: settings.AzurePathPrefix, + timeout: timeout, + }, nil +} + +func (b *AzureFileBackend) DriverName() string { + return driverAzure +} + +// prefix joins the configured pathPrefix and the caller-supplied path. +// Using a plain path.Join, a value like "foo/../../secret" can escape +// the prefix entirely, so we compute the join and verify the result is +// the prefix directory itself or a descendant of it. The descendant check +// requires a path-separator boundary so a prefix of "mattermost" does not +// match a sibling like "mattermost-evil/...". If the joined path escapes, +// we fall back to joining the prefix with path.Base, which may drop any +// intermediate directories the caller intended. +func (b *AzureFileBackend) prefix(p string) string { + joined := path.Join(b.pathPrefix, p) + if b.pathPrefix == "" { + return joined + } + + cleanPrefix := strings.TrimSuffix(path.Clean(b.pathPrefix), "/") + if joined == cleanPrefix || strings.HasPrefix(joined, cleanPrefix+"/") { + return joined + } + return path.Join(cleanPrefix, path.Base(p)) +} + +func (b *AzureFileBackend) newBlobClient(p string) *blob.Client { + return b.client.ServiceClient().NewContainerClient(b.container).NewBlobClient(b.prefix(p)) +} + +func (b *AzureFileBackend) newBlockBlobClient(p string) *blockblob.Client { + return b.client.ServiceClient().NewContainerClient(b.container).NewBlockBlobClient(b.prefix(p)) +} + +func (b *AzureFileBackend) newContainerClient() *container.Client { + return b.client.ServiceClient().NewContainerClient(b.container) +} + +// TestConnection probes the configured container and reports the outcome +// using the typed errors shared with the other backends. Container +// creation is deliberately out of scope here - callers (Server.Start) +// decide whether to provision a missing container via MakeContainer. +// That separation keeps a typo in the System Console from silently +// provisioning an unwanted container, and matches the S3 contract where +// TestConnection returns FileBackendNoBucketError and MakeBucket is an +// explicit call. +func (b *AzureFileBackend) TestConnection() error { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + _, err := b.newContainerClient().GetProperties(ctx, nil) + if err == nil { + return nil + } + if bloberror.HasCode(err, bloberror.ContainerNotFound) { + return &FileBackendNoBucketError{Err: pkgerr.Wrapf(err, "azure container %q does not exist", b.container)} + } + if isAzureAuthError(err) { + return &FileBackendAuthError{Err: pkgerr.Wrap(err, "unable to authenticate against azure blob storage")} + } + return pkgerr.Wrap(err, "unable to connect to azure blob storage") +} + +// MakeContainer creates the configured container. Mirrors S3FileBackend.MakeBucket +// so callers can opt into container provisioning explicitly. An already-existing +// container is treated as success so that concurrent boots (two nodes racing +// through TestConnection plus MakeContainer) both converge cleanly. +func (b *AzureFileBackend) MakeContainer() error { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + if _, err := b.newContainerClient().Create(ctx, nil); err != nil { + if bloberror.HasCode(err, bloberror.ContainerAlreadyExists) { + return nil + } + return pkgerr.Wrapf(err, "unable to create azure container %q", b.container) + } + return nil +} + +func (b *AzureFileBackend) Reader(p string) (ReadCloseSeeker, error) { + // Arm the deadline *before* the first network call, then hand the same + // timer to the returned reader on success. The previous code only set up + // the timer on the happy path, which left GetProperties running against a + // no-deadline context. + ctx, cancel := context.WithCancel(context.Background()) + timer := time.AfterFunc(b.timeout, cancel) + blobClient := b.newBlobClient(p) + + props, err := blobClient.GetProperties(ctx, nil) + if err != nil { + timer.Stop() + cancel() + return nil, pkgerr.Wrapf(err, "unable to read file %q", p) + } + if props.ContentLength == nil { + timer.Stop() + cancel() + return nil, pkgerr.Errorf("missing content length for %q", p) + } + + return &azureRangeReader{ + ctx: ctx, + cancel: cancel, + timer: timer, + blobClient: blobClient, + size: *props.ContentLength, + }, nil +} + +func (b *AzureFileBackend) ReadFile(p string) ([]byte, error) { + r, err := b.Reader(p) + if err != nil { + return nil, err + } + defer r.Close() + return io.ReadAll(r) +} + +func (b *AzureFileBackend) FileExists(p string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + _, err := b.newBlobClient(p).GetProperties(ctx, nil) + if err != nil { + if bloberror.HasCode(err, bloberror.BlobNotFound) { + return false, nil + } + return false, pkgerr.Wrapf(err, "unable to check existence of %q", p) + } + return true, nil +} + +func (b *AzureFileBackend) FileSize(p string) (int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + props, err := b.newBlobClient(p).GetProperties(ctx, nil) + if err != nil { + return 0, pkgerr.Wrapf(err, "unable to get size of %q", p) + } + + return model.SafeDereference(props.ContentLength), nil +} + +func (b *AzureFileBackend) FileModTime(p string) (time.Time, error) { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + props, err := b.newBlobClient(p).GetProperties(ctx, nil) + if err != nil { + return time.Time{}, pkgerr.Wrapf(err, "unable to get modification time of %q", p) + } + + return model.SafeDereference(props.LastModified), nil +} + +// CopyFile copies via StartCopyFromURL and polls the resulting blob's copy +// status until it succeeds, matching the synchronous semantics that the +// FileBackend interface (and the S3 driver via ComposeObject) provides. +func (b *AzureFileBackend) CopyFile(oldPath, newPath string) error { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + src := b.newBlobClient(oldPath).URL() + dst := b.newBlockBlobClient(newPath) + if _, err := dst.StartCopyFromURL(ctx, src, nil); err != nil { + return pkgerr.Wrapf(err, "unable to copy %q to %q", oldPath, newPath) + } + + // Poll until the copy reports success. For server-to-server copies within + // the same account this is typically synchronous, but the API is + // asynchronous in general, so we wait. + for { + props, err := dst.GetProperties(ctx, nil) + if err != nil { + return pkgerr.Wrapf(err, "unable to read copy status for %q", newPath) + } + if props.CopyStatus == nil { + return nil + } + switch *props.CopyStatus { + case blob.CopyStatusTypeSuccess: + return nil + case blob.CopyStatusTypeFailed, blob.CopyStatusTypeAborted: + desc := model.SafeDereference(props.CopyStatusDescription) + return pkgerr.Errorf("azure copy from %q to %q ended in status %q: %q", oldPath, newPath, *props.CopyStatus, desc) + } + select { + case <-ctx.Done(): + return pkgerr.Wrapf(ctx.Err(), "azure copy from %q to %q did not complete in time", oldPath, newPath) + case <-time.After(50 * time.Millisecond): + } + } +} + +func (b *AzureFileBackend) MoveFile(oldPath, newPath string) error { + if err := b.CopyFile(oldPath, newPath); err != nil { + return err + } + return b.RemoveFile(oldPath) +} + +func (b *AzureFileBackend) WriteFile(fr io.Reader, p string) (int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + return b.WriteFileContext(ctx, fr, p) +} + +// stageBlocks reads fr in azureBlockSize chunks and stages each chunk as a +// block under a fresh ID. Returns the IDs of the newly staged blocks (in +// order) and the total byte count. The caller is responsible for committing +// the block list. +func (b *AzureFileBackend) stageBlocks(ctx context.Context, bb *blockblob.Client, fr io.Reader, p string) ([]string, int64, error) { + buf := make([]byte, azureBlockSize) + var ids []string + var total int64 + + for { + n, err := io.ReadFull(fr, buf) + if n > 0 { + id, idErr := newAzureBlockID() + if idErr != nil { + return nil, 0, pkgerr.Wrap(idErr, "failed to generate azure block id") + } + if _, sbErr := bb.StageBlock(ctx, id, &readSeekNopCloser{Reader: bytes.NewReader(buf[:n])}, nil); sbErr != nil { + return nil, 0, pkgerr.Wrapf(sbErr, "unable to stage block for %q", p) + } + ids = append(ids, id) + total += int64(n) + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } + if err != nil { + return nil, 0, pkgerr.Wrap(err, "failed to read input") + } + } + return ids, total, nil +} + +// WriteFileContext stages the body in fixed-size blocks and commits a fresh +// block list. It deliberately does not use the SDK's UploadStream helper: +// UploadStream's small-payload fast path falls back to single-shot PutBlob, +// which leaves the resulting blob with no committed block list. A subsequent +// AppendFile that calls CommitBlockList on that blob would then clobber its +// content. Routing every WriteFile through StageBlock + CommitBlockList keeps +// AppendFile correct regardless of payload size. +// +// The caller's context governs the entire upload - no inner timeout is added. +// TryWriteFileContext (filesstore.go) relies on this to let long-running +// callers like message-export bulk writes opt out of the per-operation +// timeout that WriteFile applies by default. +func (b *AzureFileBackend) WriteFileContext(ctx context.Context, fr io.Reader, p string) (int64, error) { + bb := b.newBlockBlobClient(p) + blockIDs, total, err := b.stageBlocks(ctx, bb, fr, p) + if err != nil { + return 0, err + } + + if len(blockIDs) == 0 { + // Empty input - still need to materialize an empty blob with a + // committed block list so AppendFile can target it. + id, idErr := newAzureBlockID() + if idErr != nil { + return 0, pkgerr.Wrap(idErr, "failed to generate azure block id") + } + if _, sbErr := bb.StageBlock(ctx, id, &readSeekNopCloser{Reader: bytes.NewReader(nil)}, nil); sbErr != nil { + return 0, pkgerr.Wrapf(sbErr, "unable to stage empty block for %q", p) + } + blockIDs = append(blockIDs, id) + } + + if _, err := bb.CommitBlockList(ctx, blockIDs, nil); err != nil { + return 0, pkgerr.Wrapf(err, "unable to commit block list for %q", p) + } + return total, nil +} + +// AppendFile stages the new chunk as one or more blocks and commits the +// existing committed block list plus the newly staged IDs. Each AppendFile +// call uploads the new bytes exactly once - no re-download, no +// re-concatenate, no re-upload of the prior contents. The S3-style contract +// is preserved: returns an error if the target blob does not yet exist; +// returns the number of bytes appended (not the resulting total size). +// +// Refuses to append to a blob that has content but no committed block list +// (i.e. was uploaded via Put Blob by another tool - Azure portal, azcopy, +// a migration script). Committing a new block list against such a blob +// would replace the existing content with only the appended bytes, so +// failing loud beats silent data loss. +func (b *AzureFileBackend) AppendFile(fr io.Reader, p string) (int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + bb := b.newBlockBlobClient(p) + + listResp, err := bb.GetBlockList(ctx, blockblob.BlockListTypeCommitted, nil) + if err != nil { + return 0, pkgerr.Wrapf(err, "unable to find file %q to append data", p) + } + + var existingIDs []string + if listResp.BlockList.CommittedBlocks != nil { + for _, blk := range listResp.BlockList.CommittedBlocks { + if blk.Name != nil { + existingIDs = append(existingIDs, *blk.Name) + } + } + } + + if len(existingIDs) == 0 { + props, propsErr := bb.GetProperties(ctx, nil) + if propsErr != nil { + return 0, pkgerr.Wrapf(propsErr, "unable to inspect %q before append", p) + } + if model.SafeDereference(props.ContentLength) > 0 { + return 0, pkgerr.Errorf("refusing to append to %q: blob has content but no committed block list (likely written via Put Blob by another tool)", p) + } + } + + newIDs, total, err := b.stageBlocks(ctx, bb, fr, p) + if err != nil { + return 0, err + } + + if _, err := bb.CommitBlockList(ctx, append(existingIDs, newIDs...), nil); err != nil { + return 0, pkgerr.Wrapf(err, "unable to commit block list for %q", p) + } + return total, nil +} + +func (b *AzureFileBackend) RemoveFile(p string) error { + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + _, err := b.newBlobClient(p).Delete(ctx, nil) + if err != nil && !bloberror.HasCode(err, bloberror.BlobNotFound) { + return pkgerr.Wrapf(err, "unable to remove file %q", p) + } + return nil +} + +func (b *AzureFileBackend) ListDirectory(p string) ([]string, error) { + prefix := b.prefix(p) + if prefix != "" && !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + pager := b.newContainerClient().NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{ + Prefix: &prefix, + }) + + var entries []string + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, pkgerr.Wrapf(err, "unable to list directory %q", p) + } + for _, item := range page.Segment.BlobItems { + if item.Name == nil { + continue + } + name := strings.TrimPrefix(*item.Name, b.pathPrefix) + name = strings.TrimPrefix(name, "/") + entries = append(entries, name) + } + for _, item := range page.Segment.BlobPrefixes { + if item.Name == nil { + continue + } + name := strings.TrimPrefix(*item.Name, b.pathPrefix) + name = strings.TrimPrefix(name, "/") + name = strings.TrimSuffix(name, "/") + entries = append(entries, name) + } + } + return entries, nil +} + +func (b *AzureFileBackend) ListDirectoryRecursively(p string) ([]string, error) { + prefix := b.prefix(p) + if prefix != "" && !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + + ctx, cancel := context.WithTimeout(context.Background(), b.timeout) + defer cancel() + + pager := b.newContainerClient().NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + Prefix: &prefix, + }) + + var entries []string + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, pkgerr.Wrapf(err, "unable to list directory %q recursively", p) + } + for _, item := range page.Segment.BlobItems { + if item.Name == nil { + continue + } + name := strings.TrimPrefix(*item.Name, b.pathPrefix) + name = strings.TrimPrefix(name, "/") + entries = append(entries, name) + } + } + return entries, nil +} + +func (b *AzureFileBackend) RemoveDirectory(p string) error { + files, err := b.ListDirectoryRecursively(p) + if err != nil { + return err + } + for _, f := range files { + if err := b.RemoveFile(f); err != nil { + return err + } + } + return nil +} + +func (b *AzureFileBackend) ZipReader(p string, deflate bool) (io.ReadCloser, error) { + method := zip.Store + if deflate { + method = zip.Deflate + } + + pr, pw := io.Pipe() + go func() { + zw := zip.NewWriter(pw) + err := b.writeZip(zw, p, method) + if cerr := zw.Close(); err == nil { + err = cerr + } + pw.CloseWithError(err) + }() + return pr, nil +} + +func (b *AzureFileBackend) writeZip(zw *zip.Writer, p string, method uint16) error { + exists, err := b.FileExists(p) + if err != nil { + return err + } + if exists { + return b.writeZipEntry(zw, p, path.Base(p), method) + } + + files, err := b.ListDirectoryRecursively(p) + if err != nil { + return err + } + prefix := strings.TrimSuffix(p, "/") + "/" + for _, f := range files { + rel := strings.TrimPrefix(f, prefix) + if err := b.writeZipEntry(zw, f, rel, method); err != nil { + return err + } + } + return nil +} + +func (b *AzureFileBackend) writeZipEntry(zw *zip.Writer, blobPath, name string, method uint16) error { + r, err := b.Reader(blobPath) + if err != nil { + return err + } + defer r.Close() + header := &zip.FileHeader{Name: name, Method: method} + header.SetMode(0644) + w, err := zw.CreateHeader(header) + if err != nil { + return err + } + _, err = io.Copy(w, r) + return err +} + +// readSeekNopCloser adapts a Reader+Seeker into a ReadSeekCloser without +// closing the underlying source. The Azure SDK's StageBlock signature +// requires a ReadSeekCloser. +type readSeekNopCloser struct { + io.Reader +} + +func (r *readSeekNopCloser) Seek(offset int64, whence int) (int64, error) { + return r.Reader.(io.Seeker).Seek(offset, whence) +} + +func (r *readSeekNopCloser) Close() error { return nil } + +// newAzureBlockID returns a fresh base64-encoded 16-byte random block ID, +// generated with github.com/google/uuid - the same library azblob uses +// internally for the block IDs it produces in UploadStream. All committed +// blocks in a single blob must share the same decoded length, so callers +// must use this for both WriteFile and AppendFile staging. +// +// Per https://learn.microsoft.com/en-us/rest/api/storageservices/put-block: +// +// For a given blob, all block IDs must be the same length. If a block is +// uploaded with a block ID of a different length than the block IDs for any +// existing uncommitted blocks, the service returns error response code 400 +// (Bad Request). +func newAzureBlockID() (string, error) { + u, err := uuid.NewRandom() + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(u[:]), nil +} + +func isAzureAuthError(err error) bool { + if err == nil { + return false + } + return bloberror.HasCode(err, bloberror.AuthenticationFailed) || + bloberror.HasCode(err, bloberror.AuthorizationFailure) || + bloberror.HasCode(err, bloberror.InvalidAuthenticationInfo) +} diff --git a/server/platform/shared/filestore/azurestore_rangereader.go b/server/platform/shared/filestore/azurestore_rangereader.go new file mode 100644 index 00000000000..7e97a19d012 --- /dev/null +++ b/server/platform/shared/filestore/azurestore_rangereader.go @@ -0,0 +1,160 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package filestore + +import ( + "context" + "io" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + pkgerr "github.com/pkg/errors" +) + +// blobDownloader is the subset of *blob.Client used by azureRangeReader. +// Defined as an interface so tests can substitute a fake without standing up +// a real Azure client. +type blobDownloader interface { + DownloadStream(ctx context.Context, opts *blob.DownloadStreamOptions) (blob.DownloadStreamResponse, error) +} + +// azureRangeReader is a seekable reader over an Azure blob, backed by HTTP +// Range requests. A stream is opened lazily on the first Read at the current +// offset; Seek closes any open stream so the next Read re-opens it from the +// new offset. The context is cancelled either by Close or by a timer set to +// the backend's configured timeout, matching the S3 driver's behavior. +// +// Callers constructing this struct directly must set ctx, cancel and timer; +// the methods below assume all three are non-nil. +type azureRangeReader struct { + ctx context.Context + cancel context.CancelFunc + timer *time.Timer + blobClient blobDownloader + size int64 + offset int64 + body io.ReadCloser +} + +// Compile-time guarantees that azureRangeReader satisfies the interfaces the +// app layer relies on. zip.NewReader requires io.ReaderAt for archive +// readers (e.g. the bulk-import worker), and the import worker also +// type-asserts to a CancelTimeout interface for long-running operations. +var ( + _ ReadCloseSeeker = (*azureRangeReader)(nil) + _ io.ReaderAt = (*azureRangeReader)(nil) +) + +func (r *azureRangeReader) Read(p []byte) (int, error) { + if r.offset >= r.size { + return 0, io.EOF + } + if r.body == nil { + resp, err := r.blobClient.DownloadStream(r.ctx, &blob.DownloadStreamOptions{ + Range: blob.HTTPRange{Offset: r.offset, Count: 0}, + }) + if err != nil { + return 0, pkgerr.Wrap(err, "failed to open azure range stream") + } + r.body = resp.Body + } + n, err := r.body.Read(p) + r.offset += int64(n) + if err == nil { + return n, nil + } + // Close+drop the body so the caller (or a retry) doesn't read more + // from a half-consumed stream, and so Close stays idempotent. + r.body.Close() + r.body = nil + if err == io.EOF && r.offset < r.size { + // The remote stream ended before we reached the blob's content + // length. Surface that as a truncation rather than a clean EOF + // so the caller doesn't accept a partial blob as complete. + return n, io.ErrUnexpectedEOF + } + return n, err +} + +func (r *azureRangeReader) Seek(offset int64, whence int) (int64, error) { + var abs int64 + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs = r.offset + offset + case io.SeekEnd: + abs = r.size + offset + default: + return 0, pkgerr.Errorf("invalid whence: %d", whence) + } + if abs < 0 { + return 0, pkgerr.Errorf("negative position: %d", abs) + } + if abs == r.offset { + return abs, nil + } + if r.body != nil { + r.body.Close() + r.body = nil + } + r.offset = abs + return abs, nil +} + +// ReadAt reads len(p) bytes starting at offset off. Each call issues a +// dedicated ranged DownloadStream - calls do not affect the cursor that Read +// uses, matching the io.ReaderAt contract. This is what the bulk-import +// worker needs to feed zip.NewReader on Azure-backed deployments. +func (r *azureRangeReader) ReadAt(p []byte, off int64) (int, error) { + if off < 0 { + return 0, pkgerr.Errorf("negative offset: %d", off) + } + if off >= r.size { + return 0, io.EOF + } + count := int64(len(p)) + if remaining := r.size - off; count > remaining { + count = remaining + } + resp, err := r.blobClient.DownloadStream(r.ctx, &blob.DownloadStreamOptions{ + Range: blob.HTTPRange{Offset: off, Count: count}, + }) + if err != nil { + return 0, pkgerr.Wrap(err, "failed to open azure range stream") + } + defer resp.Body.Close() + n, err := io.ReadFull(resp.Body, p[:count]) + // io.ReadFull returns ErrUnexpectedEOF when the stream terminates + // before count bytes arrive. Only collapse it to io.EOF when we + // actually filled the buffer and consumed the blob to the end - + // otherwise it is a real truncation that needs to surface so + // callers like zip.NewReader do not accept partial content. + if err == io.ErrUnexpectedEOF && int64(n) == count && off+int64(n) == r.size { + return n, io.EOF + } + if err == nil && off+int64(n) == r.size { + return n, io.EOF + } + return n, err +} + +// CancelTimeout stops the timer that bounds this reader's lifetime, so +// long-running consumers (e.g. the bulk-import worker, which can run far +// past the default per-operation timeout) can opt out of the automatic +// cancellation. Returns false if the timer has already fired. +func (r *azureRangeReader) CancelTimeout() bool { + return r.timer.Stop() +} + +func (r *azureRangeReader) Close() error { + if r.timer != nil { + r.timer.Stop() + } + r.cancel() + if r.body != nil { + return r.body.Close() + } + return nil +} diff --git a/server/platform/shared/filestore/azurestore_rangereader_test.go b/server/platform/shared/filestore/azurestore_rangereader_test.go new file mode 100644 index 00000000000..8032fb8062c --- /dev/null +++ b/server/platform/shared/filestore/azurestore_rangereader_test.go @@ -0,0 +1,361 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package filestore + +import ( + "bytes" + "context" + "errors" + "io" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/stretchr/testify/require" +) + +// trackingReadCloser wraps a Reader and records whether Close was called. +type trackingReadCloser struct { + io.Reader + closed bool +} + +func (t *trackingReadCloser) Close() error { + t.closed = true + return nil +} + +// fakeDownloader serves bytes from an in-memory blob, records every +// DownloadStream call's Range, and hands out trackingReadClosers so tests +// can assert close-on-Seek behavior. An optional err short-circuits responses. +type fakeDownloader struct { + data []byte + calls []blob.HTTPRange + bodies []*trackingReadCloser + err error +} + +func (f *fakeDownloader) DownloadStream(_ context.Context, opts *blob.DownloadStreamOptions) (blob.DownloadStreamResponse, error) { + if f.err != nil { + return blob.DownloadStreamResponse{}, f.err + } + var rng blob.HTTPRange + if opts != nil { + rng = opts.Range + } + f.calls = append(f.calls, rng) + + start := min(max(rng.Offset, 0), int64(len(f.data))) + end := int64(len(f.data)) + if rng.Count > 0 && start+rng.Count < end { + end = start + rng.Count + } + body := &trackingReadCloser{Reader: bytes.NewReader(f.data[start:end])} + f.bodies = append(f.bodies, body) + + return blob.DownloadStreamResponse{ + DownloadResponse: blob.DownloadResponse{Body: body}, + }, nil +} + +// newTestReader returns an azureRangeReader wired to the given fake, with a +// long-lived timer so it never fires during the test. Caller must Close it. +func newTestReader(t *testing.T, fake *fakeDownloader, size int64) *azureRangeReader { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + timer := time.AfterFunc(time.Hour, cancel) + return &azureRangeReader{ + ctx: ctx, + cancel: cancel, + timer: timer, + blobClient: fake, + size: size, + } +} + +func TestRead(t *testing.T) { + t.Run("returns EOF at end of blob without downloading", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("hello")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := r.Seek(0, io.SeekEnd) + require.NoError(t, err) + + n, err := r.Read(make([]byte, 4)) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) + require.Empty(t, fake.calls, "no download should be issued past end of blob") + }) + + t.Run("opens stream at current offset", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("hello world")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := r.Seek(6, io.SeekStart) + require.NoError(t, err) + + buf := make([]byte, 5) + n, err := io.ReadFull(r, buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf)) + + require.Len(t, fake.calls, 1) + require.Equal(t, blob.HTTPRange{Offset: 6, Count: 0}, fake.calls[0]) + }) + + t.Run("sequential reads reuse the open stream", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdefghij")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + require.NoError(t, err) + require.Equal(t, "abcd", string(buf)) + + _, err = io.ReadFull(r, buf) + require.NoError(t, err) + require.Equal(t, "efgh", string(buf)) + + require.Len(t, fake.calls, 1, "sequential reads must reuse the open stream") + }) + + t.Run("propagates download errors", func(t *testing.T) { + wantErr := errors.New("boom") + fake := &fakeDownloader{data: []byte("xyz"), err: wantErr} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := r.Read(make([]byte, 1)) + require.ErrorIs(t, err, wantErr) + }) + + t.Run("surfaces truncation when stream EOFs before the blob ends", func(t *testing.T) { + // Promised size is larger than what the fake actually serves, + // so the body eventually returns io.EOF while r.offset < r.size. + // bytes.Reader returns its content + nil first, then 0 + EOF on + // the next call, so we drain the bytes before the truncation + // is observable. + fake := &fakeDownloader{data: []byte("hello")} + r := newTestReader(t, fake, int64(len(fake.data))+10) + defer r.Close() + + buf := make([]byte, 16) + n, err := r.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Second call hits EOF from the body before we've reached r.size, + // so the reader must surface that as a truncation. + n, err = r.Read(buf) + require.Equal(t, 0, n) + require.ErrorIs(t, err, io.ErrUnexpectedEOF) + require.Nil(t, r.body, "body must be released after a truncation error") + }) +} + +func TestReadAt(t *testing.T) { + t.Run("reads at the given offset without disturbing the cursor", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdefghij")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + // Advance the streaming cursor first. + _, err := io.ReadFull(r, make([]byte, 3)) + require.NoError(t, err) + require.Equal(t, int64(3), r.offset) + + buf := make([]byte, 4) + n, err := r.ReadAt(buf, 5) + require.NoError(t, err) + require.Equal(t, 4, n) + require.Equal(t, "fghi", string(buf)) + require.Equal(t, int64(3), r.offset, "ReadAt must not touch the streaming offset") + }) + + t.Run("returns io.EOF when the read lands exactly at the end of the blob", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdefghij")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + buf := make([]byte, 3) + n, err := r.ReadAt(buf, 7) + require.Equal(t, io.EOF, err) + require.Equal(t, 3, n) + require.Equal(t, "hij", string(buf)) + }) + + t.Run("returns io.EOF when off is past the size", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdefghij")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + n, err := r.ReadAt(make([]byte, 4), 100) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) + require.Empty(t, fake.calls, "no download should be issued past end of blob") + }) + + t.Run("rejects negative offsets", func(t *testing.T) { + r := newTestReader(t, &fakeDownloader{}, 10) + defer r.Close() + + _, err := r.ReadAt(make([]byte, 1), -1) + require.Error(t, err) + }) + + t.Run("propagates download errors", func(t *testing.T) { + wantErr := errors.New("boom") + fake := &fakeDownloader{data: []byte("xyz"), err: wantErr} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := r.ReadAt(make([]byte, 1), 0) + require.ErrorIs(t, err, wantErr) + }) + + t.Run("surfaces truncation when stream falls short of the requested count", func(t *testing.T) { + // Promised size exceeds the fake's actual data so ReadFull + // sees the body terminate before count bytes arrived. That + // must surface as ErrUnexpectedEOF, not a clean EOF. + fake := &fakeDownloader{data: []byte("hello")} + r := newTestReader(t, fake, int64(len(fake.data))+5) + defer r.Close() + + buf := make([]byte, 10) + n, err := r.ReadAt(buf, 0) + require.Equal(t, 5, n) + require.ErrorIs(t, err, io.ErrUnexpectedEOF) + }) +} + +func TestCancelTimeout(t *testing.T) { + fake := &fakeDownloader{data: []byte("abc")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + require.True(t, r.CancelTimeout(), "first stop should succeed") + require.False(t, r.CancelTimeout(), "second stop must report the timer was already stopped") +} + +func TestSeek(t *testing.T) { + t.Run("absolute from start", func(t *testing.T) { + fake := &fakeDownloader{data: bytes.Repeat([]byte("x"), 32)} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + pos, err := r.Seek(10, io.SeekStart) + require.NoError(t, err) + require.Equal(t, int64(10), pos) + }) + + t.Run("relative to current position", func(t *testing.T) { + fake := &fakeDownloader{data: bytes.Repeat([]byte("x"), 32)} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := r.Seek(10, io.SeekStart) + require.NoError(t, err) + + pos, err := r.Seek(5, io.SeekCurrent) + require.NoError(t, err) + require.Equal(t, int64(15), pos) + }) + + t.Run("relative to end", func(t *testing.T) { + fake := &fakeDownloader{data: bytes.Repeat([]byte("x"), 32)} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + pos, err := r.Seek(-4, io.SeekEnd) + require.NoError(t, err) + require.Equal(t, int64(28), pos) + }) + + t.Run("rejects invalid whence", func(t *testing.T) { + r := newTestReader(t, &fakeDownloader{}, 0) + defer r.Close() + + _, err := r.Seek(0, 99) + require.Error(t, err) + }) + + t.Run("rejects negative absolute position", func(t *testing.T) { + r := newTestReader(t, &fakeDownloader{}, 10) + defer r.Close() + + _, err := r.Seek(-1, io.SeekStart) + require.Error(t, err) + + _, err = r.Seek(-20, io.SeekEnd) + require.Error(t, err) + }) + + t.Run("same offset leaves the open stream untouched", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdefgh")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := io.ReadFull(r, make([]byte, 3)) + require.NoError(t, err) + require.Len(t, fake.bodies, 1) + openBody := fake.bodies[0] + + pos, err := r.Seek(3, io.SeekStart) + require.NoError(t, err) + require.Equal(t, int64(3), pos) + require.False(t, openBody.closed, "same-offset seek must not close the open stream") + + _, err = io.ReadFull(r, make([]byte, 3)) + require.NoError(t, err) + require.Len(t, fake.calls, 1, "same-offset seek must not trigger a new download") + }) + + t.Run("different offset closes the open stream and the next read reopens", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdefghij")} + r := newTestReader(t, fake, int64(len(fake.data))) + defer r.Close() + + _, err := io.ReadFull(r, make([]byte, 2)) + require.NoError(t, err) + require.Len(t, fake.bodies, 1) + firstBody := fake.bodies[0] + + _, err = r.Seek(7, io.SeekStart) + require.NoError(t, err) + require.True(t, firstBody.closed, "seek to a new offset must close the open stream") + + buf := make([]byte, 3) + _, err = io.ReadFull(r, buf) + require.NoError(t, err) + require.Equal(t, "hij", string(buf)) + + require.Len(t, fake.calls, 2) + require.Equal(t, int64(7), fake.calls[1].Offset) + }) +} + +func TestClose(t *testing.T) { + t.Run("cancels context and closes the open body", func(t *testing.T) { + fake := &fakeDownloader{data: []byte("abcdef")} + r := newTestReader(t, fake, int64(len(fake.data))) + + _, err := io.ReadFull(r, make([]byte, 3)) + require.NoError(t, err) + require.Len(t, fake.bodies, 1) + + require.NoError(t, r.Close()) + require.True(t, fake.bodies[0].closed) + require.ErrorIs(t, r.ctx.Err(), context.Canceled) + }) + + t.Run("works when no stream was opened", func(t *testing.T) { + r := newTestReader(t, &fakeDownloader{}, 10) + require.NoError(t, r.Close()) + require.ErrorIs(t, r.ctx.Err(), context.Canceled) + }) +} diff --git a/server/platform/shared/filestore/azurestore_test.go b/server/platform/shared/filestore/azurestore_test.go new file mode 100644 index 00000000000..a6dea58aed5 --- /dev/null +++ b/server/platform/shared/filestore/azurestore_test.go @@ -0,0 +1,137 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package filestore + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestAzureFileBackendPrefix(t *testing.T) { + tests := []struct { + name string + prefix string + input string + expected string + }{ + {name: "no prefix, plain path", prefix: "", input: "team/channel/file", expected: "team/channel/file"}, + {name: "no prefix, with dot-dot", prefix: "", input: "../escape", expected: "../escape"}, + {name: "prefix, plain path", prefix: "mattermost", input: "team/channel/file", expected: "mattermost/team/channel/file"}, + {name: "prefix, exact root", prefix: "mattermost", input: "", expected: "mattermost"}, + {name: "prefix, dot-dot escapes", prefix: "mattermost", input: "../escape", expected: "mattermost/escape"}, + {name: "prefix, nested dot-dot escapes", prefix: "mattermost", input: "sub/../../escape", expected: "mattermost/escape"}, + {name: "prefix, dot-dot in middle stays inside", prefix: "mattermost", input: "a/../b", expected: "mattermost/b"}, + {name: "prefix with trailing slash, dot-dot escapes", prefix: "mattermost/", input: "../escape", expected: "mattermost/escape"}, + {name: "prefix boundary collision must not escape", prefix: "mattermost", input: "../mattermost-evil/file", expected: "mattermost/file"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &AzureFileBackend{pathPrefix: tt.prefix} + require.Equal(t, tt.expected, b.prefix(tt.input)) + }) + } +} + +// azuriteWellKnownAccount and azuriteWellKnownKey are Azurite's published +// development credentials. They are not secrets - they are documented in the +// Azurite README and ship hardcoded in every Azurite distribution. +const ( + azuriteWellKnownAccount = "devstoreaccount1" + azuriteWellKnownKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" +) + +// TestAzureFileBackendAppendRefusesNonBlockBlob exercises the safety +// check in AppendFile: when a blob exists with content but no committed +// block list (i.e. it was uploaded via Put Blob by another tool), the +// backend must refuse the append rather than silently destroy the +// existing content. +func TestAzureFileBackendAppendRefusesNonBlockBlob(t *testing.T) { + be := newAzuriteBackend(t) + + path := "append-refusal-test.bin" + t.Cleanup(func() { _ = be.RemoveFile(path) }) + + // Write the blob via the high-level Upload helper, which calls the + // Put Blob REST endpoint and leaves the committed-block list empty. + original := []byte("planted-by-another-tool") + bb := be.newBlockBlobClient(path) + _, err := bb.Upload(context.Background(), nopReadSeekCloser{bytes.NewReader(original)}, nil) + require.NoError(t, err) + + _, err = be.AppendFile(bytes.NewReader([]byte("would-overwrite")), path) + require.Error(t, err) + require.Contains(t, err.Error(), "no committed block list") + + // The original content must still be intact. + got, err := be.ReadFile(path) + require.NoError(t, err) + require.Equal(t, original, got) +} + +// TestAzureFileBackendMakeContainerIdempotent ensures that calling +// MakeContainer twice on the same backend is a no-op the second time. +// Two nodes can race through TestConnection plus MakeContainer at boot; +// the loser must converge instead of returning an error. +func TestAzureFileBackendMakeContainerIdempotent(t *testing.T) { + be := newAzuriteBackend(t) + + require.NoError(t, be.MakeContainer()) + require.NoError(t, be.MakeContainer()) +} + +type nopReadSeekCloser struct { + *bytes.Reader +} + +func (nopReadSeekCloser) Close() error { return nil } + +// newAzuriteBackend builds an Azure backend pointed at the Azurite emulator +// and ensures the container exists. Standalone Azure tests should use this +// instead of calling NewAzureFileBackend + TestConnection directly; the +// shared FileBackendTestSuite handles provisioning itself in SetupTest. +func newAzuriteBackend(t *testing.T) *AzureFileBackend { + t.Helper() + be, err := NewAzureFileBackend(azuriteSettings(t)) + require.NoError(t, err) + + var noBucket *FileBackendNoBucketError + if err := be.TestConnection(); errors.As(err, &noBucket) { + require.NoError(t, be.MakeContainer()) + } else { + require.NoError(t, err) + } + return be +} + +func azuriteSettings(t *testing.T) FileBackendSettings { + t.Helper() + host := os.Getenv("CI_AZURITE_HOST") + if host == "" { + host = "localhost" + } + port := os.Getenv("CI_AZURITE_PORT") + if port == "" { + port = "10000" + } + return FileBackendSettings{ + DriverName: driverAzure, + AzureStorageAccount: azuriteWellKnownAccount, + AzureAccessKey: azuriteWellKnownKey, + AzureContainer: "mattermost-test", + AzureEndpoint: fmt.Sprintf("%s:%s", host, port), + AzureSSL: false, + AzureRequestTimeoutMilliseconds: 30000, + } +} + +func TestAzureFileBackendTestSuite(t *testing.T) { + suite.Run(t, &FileBackendTestSuite{settings: azuriteSettings(t)}) +} diff --git a/server/platform/shared/filestore/errors.go b/server/platform/shared/filestore/errors.go new file mode 100644 index 00000000000..6d034ca6cb4 --- /dev/null +++ b/server/platform/shared/filestore/errors.go @@ -0,0 +1,44 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package filestore + +// FileBackendAuthError is returned when testing a connection and authentication +// against the file storage backend fails. Backends should wrap the underlying +// auth failure in this type so the admin Test Connection flow can surface a +// useful message regardless of which driver is configured. +type FileBackendAuthError struct { + // Err is the underlying driver error, if any. + Err error + // DetailedError is a human-readable message describing the failure. + // Kept for compatibility with the previous S3-specific type. + DetailedError string +} + +func (e *FileBackendAuthError) Error() string { + if e.DetailedError != "" { + return e.DetailedError + } + if e.Err != nil { + return e.Err.Error() + } + return "authentication failed" +} + +func (e *FileBackendAuthError) Unwrap() error { return e.Err } + +// FileBackendNoBucketError is returned when testing a connection and the +// configured bucket / container does not exist. +type FileBackendNoBucketError struct { + // Err is the underlying driver error, if any. + Err error +} + +func (e *FileBackendNoBucketError) Error() string { + if e.Err != nil { + return e.Err.Error() + } + return "no such bucket or container" +} + +func (e *FileBackendNoBucketError) Unwrap() error { return e.Err } diff --git a/server/platform/shared/filestore/filesstore.go b/server/platform/shared/filestore/filesstore.go index 46116579a4d..c7eacb1bb84 100644 --- a/server/platform/shared/filestore/filesstore.go +++ b/server/platform/shared/filestore/filesstore.go @@ -15,6 +15,7 @@ import ( const ( driverS3 = "amazons3" driverLocal = "local" + driverAzure = "azureblob" ) type ReadCloseSeeker interface { @@ -65,6 +66,13 @@ type FileBackendSettings struct { AmazonS3PresignExpiresSeconds int64 AmazonS3UploadPartSizeBytes int64 AmazonS3StorageClass string + AzureStorageAccount string + AzureAccessKey string + AzureContainer string + AzurePathPrefix string + AzureEndpoint string + AzureSSL bool + AzureRequestTimeoutMilliseconds int64 } func NewFileBackendSettingsFromConfig(fileSettings *model.FileSettings, enableComplianceFeature bool, skipVerify bool) FileBackendSettings { @@ -74,6 +82,19 @@ func NewFileBackendSettingsFromConfig(fileSettings *model.FileSettings, enableCo Directory: *fileSettings.Directory, } } + if *fileSettings.DriverName == model.ImageDriverAzure { + return FileBackendSettings{ + DriverName: *fileSettings.DriverName, + AzureStorageAccount: *fileSettings.AzureStorageAccount, + AzureAccessKey: *fileSettings.AzureAccessKey, + AzureContainer: *fileSettings.AzureContainer, + AzurePathPrefix: *fileSettings.AzurePathPrefix, + AzureEndpoint: *fileSettings.AzureEndpoint, + AzureSSL: fileSettings.AzureSSL == nil || *fileSettings.AzureSSL, + AzureRequestTimeoutMilliseconds: *fileSettings.AzureRequestTimeoutMilliseconds, + SkipVerify: skipVerify, + } + } return FileBackendSettings{ DriverName: *fileSettings.DriverName, AmazonS3AccessKeyId: *fileSettings.AmazonS3AccessKeyId, @@ -100,6 +121,19 @@ func NewExportFileBackendSettingsFromConfig(fileSettings *model.FileSettings, en Directory: *fileSettings.ExportDirectory, } } + if *fileSettings.ExportDriverName == model.ImageDriverAzure { + return FileBackendSettings{ + DriverName: *fileSettings.ExportDriverName, + AzureStorageAccount: *fileSettings.ExportAzureStorageAccount, + AzureAccessKey: *fileSettings.ExportAzureAccessKey, + AzureContainer: *fileSettings.ExportAzureContainer, + AzurePathPrefix: *fileSettings.ExportAzurePathPrefix, + AzureEndpoint: *fileSettings.ExportAzureEndpoint, + AzureSSL: fileSettings.ExportAzureSSL == nil || *fileSettings.ExportAzureSSL, + AzureRequestTimeoutMilliseconds: *fileSettings.ExportAzureRequestTimeoutMilliseconds, + SkipVerify: skipVerify, + } + } return FileBackendSettings{ DriverName: *fileSettings.ExportDriverName, AmazonS3AccessKeyId: *fileSettings.ExportAmazonS3AccessKeyId, @@ -133,6 +167,19 @@ func (settings *FileBackendSettings) CheckMandatoryS3Fields() error { return nil } +func (settings *FileBackendSettings) CheckMandatoryAzureFields() error { + if settings.AzureStorageAccount == "" { + return errors.New("missing azure storage account setting") + } + if settings.AzureContainer == "" { + return errors.New("missing azure container setting") + } + if settings.AzureAccessKey == "" { + return errors.New("missing azure access key setting") + } + return nil +} + // NewFileBackend creates a new file backend func NewFileBackend(settings FileBackendSettings) (FileBackend, error) { return newFileBackend(settings, true) @@ -159,6 +206,12 @@ func newFileBackend(settings FileBackendSettings, canBeCloud bool) (FileBackend, return &LocalFileBackend{ directory: settings.Directory, }, nil + case driverAzure: + backend, err := NewAzureFileBackend(settings) + if err != nil { + return nil, errors.Wrap(err, "unable to connect to the azure backend") + } + return backend, nil } return nil, errors.New("no valid filestorage driver found") } diff --git a/server/platform/shared/filestore/filesstore_test.go b/server/platform/shared/filestore/filesstore_test.go index ad5b5b3aa5e..f56182e1258 100644 --- a/server/platform/shared/filestore/filesstore_test.go +++ b/server/platform/shared/filestore/filesstore_test.go @@ -123,11 +123,17 @@ func (s *FileBackendTestSuite) SetupTest() { require.NoError(s.T(), err) s.backend = backend - // This is needed to create the bucket if it doesn't exist. + // This is needed to create the bucket / container if it doesn't exist. err = s.backend.TestConnection() - if _, ok := err.(*S3FileBackendNoBucketError); ok { - s3Backend := s.backend.(*S3FileBackend) - s.NoError(s3Backend.MakeBucket()) + if _, ok := err.(*FileBackendNoBucketError); ok { + switch b := s.backend.(type) { + case *S3FileBackend: + s.NoError(b.MakeBucket()) + case *AzureFileBackend: + s.NoError(b.MakeContainer()) + default: + s.NoError(err) + } } else { s.NoError(err) } @@ -699,7 +705,7 @@ func BenchmarkFileStore(b *testing.B) { // Create bucket if it doesn't exist err = s3Backend.TestConnection() - if _, ok := err.(*S3FileBackendNoBucketError); ok { + if _, ok := err.(*FileBackendNoBucketError); ok { require.NoError(b, s3Backend.(*S3FileBackend).MakeBucket()) } else { require.NoError(b, err) @@ -851,7 +857,7 @@ func BenchmarkS3WriteFile(b *testing.B) { // This is needed to create the bucket if it doesn't exist. err = backend.TestConnection() - if _, ok := err.(*S3FileBackendNoBucketError); ok { + if _, ok := err.(*FileBackendNoBucketError); ok { require.NoError(b, backend.(*S3FileBackend).MakeBucket()) } else { require.NoError(b, err) diff --git a/server/platform/shared/filestore/s3store.go b/server/platform/shared/filestore/s3store.go index 161a2cb3d05..420c6f890eb 100644 --- a/server/platform/shared/filestore/s3store.go +++ b/server/platform/shared/filestore/s3store.go @@ -50,12 +50,13 @@ type S3FileBackend struct { storageClass string } -type S3FileBackendAuthError struct { - DetailedError string -} - -// S3FileBackendNoBucketError is returned when testing a connection and no S3 bucket is found -type S3FileBackendNoBucketError struct{} +// S3FileBackendAuthError and S3FileBackendNoBucketError are aliases for the +// generic backend errors. They are kept so external code (plugins, +// historically-typed consumers) continues to compile. +type ( + S3FileBackendAuthError = FileBackendAuthError + S3FileBackendNoBucketError = FileBackendNoBucketError +) const ( // This is not exported by minio. See: https://github.com/minio/minio-go/issues/1339 @@ -77,14 +78,6 @@ func getContentType(ext string) string { return mimeType } -func (s *S3FileBackendAuthError) Error() string { - return s.DetailedError -} - -func (s *S3FileBackendNoBucketError) Error() string { - return "no such bucket" -} - // NewS3FileBackend returns an instance of an S3FileBackend and determine if we are in Mattermost cloud or not. func NewS3FileBackend(settings FileBackendSettings) (*S3FileBackend, error) { return newS3FileBackend(settings, os.Getenv("MM_CLOUD_FILESTORE_BIFROST") != "") diff --git a/server/public/model/config.go b/server/public/model/config.go index a586861c079..543d751875d 100644 --- a/server/public/model/config.go +++ b/server/public/model/config.go @@ -36,6 +36,7 @@ const ( ImageDriverLocal = "local" ImageDriverS3 = "amazons3" + ImageDriverAzure = "azureblob" DatabaseDriverPostgres = "postgres" @@ -137,6 +138,12 @@ const ( FileSettingsDefaultS3UploadPartSizeBytes = 5 * 1024 * 1024 // 5MB FileSettingsDefaultS3ExportUploadPartSizeBytes = 100 * 1024 * 1024 // 100MB + // maxAzureRequestTimeoutMilliseconds caps the per-request timeout so a + // hung Azure call cannot keep a goroutine open indefinitely. Ten minutes + // is well beyond any realistic single-request workload and matches the + // upper end of Azure SDK retry guidance. + maxAzureRequestTimeoutMilliseconds = 10 * 60 * 1000 + ImportSettingsDefaultDirectory = "./import" ImportSettingsDefaultRetentionDays = 30 @@ -1795,6 +1802,13 @@ type FileSettings struct { AmazonS3RequestTimeoutMilliseconds *int64 `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none AmazonS3UploadPartSizeBytes *int64 `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none AmazonS3StorageClass *string `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none + AzureStorageAccount *string `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none + AzureAccessKey *string `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none + AzureContainer *string `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none + AzurePathPrefix *string `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none + AzureEndpoint *string `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none + AzureSSL *bool `access:"environment_file_storage,write_restrictable,cloud_restrictable"` + AzureRequestTimeoutMilliseconds *int64 `access:"environment_file_storage,write_restrictable,cloud_restrictable"` // telemetry: none // Export store settings DedicatedExportStore *bool `access:"environment_file_storage,write_restrictable"` ExportDriverName *string `access:"environment_file_storage,write_restrictable"` @@ -1813,6 +1827,13 @@ type FileSettings struct { ExportAmazonS3PresignExpiresSeconds *int64 `access:"environment_file_storage,write_restrictable"` // telemetry: none ExportAmazonS3UploadPartSizeBytes *int64 `access:"environment_file_storage,write_restrictable"` // telemetry: none ExportAmazonS3StorageClass *string `access:"environment_file_storage,write_restrictable"` // telemetry: none + ExportAzureStorageAccount *string `access:"environment_file_storage,write_restrictable"` // telemetry: none + ExportAzureAccessKey *string `access:"environment_file_storage,write_restrictable"` // telemetry: none + ExportAzureContainer *string `access:"environment_file_storage,write_restrictable"` // telemetry: none + ExportAzurePathPrefix *string `access:"environment_file_storage,write_restrictable"` // telemetry: none + ExportAzureEndpoint *string `access:"environment_file_storage,write_restrictable"` // telemetry: none + ExportAzureSSL *bool `access:"environment_file_storage,write_restrictable"` + ExportAzureRequestTimeoutMilliseconds *int64 `access:"environment_file_storage,write_restrictable"` // telemetry: none } func (s *FileSettings) SetDefaults(isUpdate bool) { @@ -1929,6 +1950,34 @@ func (s *FileSettings) SetDefaults(isUpdate bool) { s.AmazonS3StorageClass = new("") } + if s.AzureStorageAccount == nil { + s.AzureStorageAccount = NewPointer("") + } + + if s.AzureAccessKey == nil { + s.AzureAccessKey = NewPointer("") + } + + if s.AzureContainer == nil { + s.AzureContainer = NewPointer("") + } + + if s.AzurePathPrefix == nil { + s.AzurePathPrefix = NewPointer("") + } + + if s.AzureEndpoint == nil { + s.AzureEndpoint = NewPointer("") + } + + if s.AzureSSL == nil { + s.AzureSSL = NewPointer(true) + } + + if s.AzureRequestTimeoutMilliseconds == nil { + s.AzureRequestTimeoutMilliseconds = NewPointer(int64(30000)) + } + if s.DedicatedExportStore == nil { s.DedicatedExportStore = new(false) } @@ -1998,6 +2047,34 @@ func (s *FileSettings) SetDefaults(isUpdate bool) { if s.ExportAmazonS3StorageClass == nil { s.ExportAmazonS3StorageClass = new("") } + + if s.ExportAzureStorageAccount == nil { + s.ExportAzureStorageAccount = NewPointer("") + } + + if s.ExportAzureAccessKey == nil { + s.ExportAzureAccessKey = NewPointer("") + } + + if s.ExportAzureContainer == nil { + s.ExportAzureContainer = NewPointer("") + } + + if s.ExportAzurePathPrefix == nil { + s.ExportAzurePathPrefix = NewPointer("") + } + + if s.ExportAzureEndpoint == nil { + s.ExportAzureEndpoint = NewPointer("") + } + + if s.ExportAzureSSL == nil { + s.ExportAzureSSL = NewPointer(true) + } + + if s.ExportAzureRequestTimeoutMilliseconds == nil { + s.ExportAzureRequestTimeoutMilliseconds = NewPointer(int64(30000)) + } } type EmailSettings struct { @@ -4396,7 +4473,7 @@ func (s *FileSettings) isValid() *AppError { return NewAppError("Config.IsValid", "model.config.is_valid.max_file_size.app_error", nil, "", http.StatusBadRequest) } - if !(*s.DriverName == ImageDriverLocal || *s.DriverName == ImageDriverS3) { + if !(*s.DriverName == ImageDriverLocal || *s.DriverName == ImageDriverS3 || *s.DriverName == ImageDriverAzure) { return NewAppError("Config.IsValid", "model.config.is_valid.file_driver.app_error", nil, "", http.StatusBadRequest) } @@ -4421,6 +4498,10 @@ func (s *FileSettings) isValid() *AppError { return NewAppError("Config.IsValid", "model.config.is_valid.amazons3_timeout.app_error", map[string]any{"Value": *s.MaxImageDecoderConcurrency}, "", http.StatusBadRequest) } + if *s.AzureRequestTimeoutMilliseconds <= 0 || *s.AzureRequestTimeoutMilliseconds > maxAzureRequestTimeoutMilliseconds { + return NewAppError("Config.IsValid", "model.config.is_valid.azure_timeout.app_error", map[string]any{"Value": *s.AzureRequestTimeoutMilliseconds}, "", http.StatusBadRequest) + } + if *s.AmazonS3StorageClass != "" && !slices.Contains([]string{StorageClassStandard, StorageClassReducedRedundancy, StorageClassStandardIA, StorageClassOnezoneIA, StorageClassIntelligentTiering, StorageClassGlacier, StorageClassDeepArchive, StorageClassOutposts, StorageClassGlacierIR, StorageClassSnow, StorageClassExpressOnezone}, *s.AmazonS3StorageClass) { return NewAppError("Config.IsValid", "model.config.is_valid.storage_class.app_error", map[string]any{"Value": *s.AmazonS3StorageClass}, "", http.StatusBadRequest) } @@ -4429,14 +4510,34 @@ func (s *FileSettings) isValid() *AppError { return NewAppError("Config.IsValid", "model.config.is_valid.directory_whitespace.app_error", map[string]any{"Setting": "FileSettings.AmazonS3PathPrefix", "Value": *s.AmazonS3PathPrefix}, "", http.StatusBadRequest) } + if strings.TrimSpace(*s.AzurePathPrefix) != *s.AzurePathPrefix { + return NewAppError("Config.IsValid", "model.config.is_valid.directory_whitespace.app_error", map[string]any{"Setting": "FileSettings.AzurePathPrefix", "Value": *s.AzurePathPrefix}, "", http.StatusBadRequest) + } + + if strings.Contains(*s.AzurePathPrefix, "..") { + return NewAppError("Config.IsValid", "model.config.is_valid.directory_traversal.app_error", map[string]any{"Setting": "FileSettings.AzurePathPrefix", "Value": *s.AzurePathPrefix}, "", http.StatusBadRequest) + } + if *s.ExportAmazonS3StorageClass != "" && !slices.Contains([]string{StorageClassStandard, StorageClassReducedRedundancy, StorageClassStandardIA, StorageClassOnezoneIA, StorageClassIntelligentTiering, StorageClassGlacier, StorageClassDeepArchive, StorageClassOutposts, StorageClassGlacierIR, StorageClassSnow, StorageClassExpressOnezone}, *s.ExportAmazonS3StorageClass) { return NewAppError("Config.IsValid", "model.config.is_valid.storage_class.app_error", map[string]any{"Value": *s.ExportAmazonS3StorageClass}, "", http.StatusBadRequest) } + if *s.ExportAzureRequestTimeoutMilliseconds <= 0 || *s.ExportAzureRequestTimeoutMilliseconds > maxAzureRequestTimeoutMilliseconds { + return NewAppError("Config.IsValid", "model.config.is_valid.export_azure_timeout.app_error", map[string]any{"Value": *s.ExportAzureRequestTimeoutMilliseconds}, "", http.StatusBadRequest) + } + if strings.TrimSpace(*s.ExportAmazonS3PathPrefix) != *s.ExportAmazonS3PathPrefix { return NewAppError("Config.IsValid", "model.config.is_valid.directory_whitespace.app_error", map[string]any{"Setting": "FileSettings.ExportAmazonS3PathPrefix", "Value": *s.ExportAmazonS3PathPrefix}, "", http.StatusBadRequest) } + if strings.TrimSpace(*s.ExportAzurePathPrefix) != *s.ExportAzurePathPrefix { + return NewAppError("Config.IsValid", "model.config.is_valid.directory_whitespace.app_error", map[string]any{"Setting": "FileSettings.ExportAzurePathPrefix", "Value": *s.ExportAzurePathPrefix}, "", http.StatusBadRequest) + } + + if strings.Contains(*s.ExportAzurePathPrefix, "..") { + return NewAppError("Config.IsValid", "model.config.is_valid.directory_traversal.app_error", map[string]any{"Setting": "FileSettings.ExportAzurePathPrefix", "Value": *s.ExportAzurePathPrefix}, "", http.StatusBadRequest) + } + if strings.TrimSpace(*s.ExportDirectory) != *s.ExportDirectory { return NewAppError("Config.IsValid", "model.config.is_valid.directory_whitespace.app_error", map[string]any{"Setting": "FileSettings.ExportDirectory", "Value": *s.ExportDirectory}, "", http.StatusBadRequest) } @@ -5061,6 +5162,14 @@ func (o *Config) Sanitize(pluginManifests []*Manifest, opts *SanitizeOptions) { *o.FileSettings.ExportAmazonS3SecretAccessKey = FakeSetting } + if o.FileSettings.AzureAccessKey != nil && *o.FileSettings.AzureAccessKey != "" { + *o.FileSettings.AzureAccessKey = FakeSetting + } + + if o.FileSettings.ExportAzureAccessKey != nil && *o.FileSettings.ExportAzureAccessKey != "" { + *o.FileSettings.ExportAzureAccessKey = FakeSetting + } + if o.EmailSettings.SMTPPassword != nil && *o.EmailSettings.SMTPPassword != "" { *o.EmailSettings.SMTPPassword = FakeSetting } diff --git a/server/public/model/config_test.go b/server/public/model/config_test.go index 4a63f61c7a0..19676351aca 100644 --- a/server/public/model/config_test.go +++ b/server/public/model/config_test.go @@ -296,6 +296,59 @@ func TestFileSettingsDirectoryWhitespaceValidation(t *testing.T) { } } +func TestFileSettingsAzureRequestTimeoutBounds(t *testing.T) { + cases := []struct { + name string + value int64 + configSetter func(*Config, *int64) + errID string + }{ + {"AzureRequestTimeoutMilliseconds zero", 0, func(cfg *Config, v *int64) { cfg.FileSettings.AzureRequestTimeoutMilliseconds = v }, "model.config.is_valid.azure_timeout.app_error"}, + {"AzureRequestTimeoutMilliseconds negative", -1, func(cfg *Config, v *int64) { cfg.FileSettings.AzureRequestTimeoutMilliseconds = v }, "model.config.is_valid.azure_timeout.app_error"}, + {"AzureRequestTimeoutMilliseconds above ceiling", maxAzureRequestTimeoutMilliseconds + 1, func(cfg *Config, v *int64) { cfg.FileSettings.AzureRequestTimeoutMilliseconds = v }, "model.config.is_valid.azure_timeout.app_error"}, + {"ExportAzureRequestTimeoutMilliseconds zero", 0, func(cfg *Config, v *int64) { cfg.FileSettings.ExportAzureRequestTimeoutMilliseconds = v }, "model.config.is_valid.export_azure_timeout.app_error"}, + {"ExportAzureRequestTimeoutMilliseconds above ceiling", maxAzureRequestTimeoutMilliseconds + 1, func(cfg *Config, v *int64) { cfg.FileSettings.ExportAzureRequestTimeoutMilliseconds = v }, "model.config.is_valid.export_azure_timeout.app_error"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{} + cfg.SetDefaults() + tc.configSetter(cfg, NewPointer(tc.value)) + + err := cfg.FileSettings.isValid() + require.NotNil(t, err) + assert.Equal(t, tc.errID, err.Id) + }) + } +} + +func TestFileSettingsAzurePathPrefixTraversal(t *testing.T) { + cases := []struct { + name string + configSetter func(*Config, *string) + }{ + { + "AzurePathPrefix", + func(cfg *Config, value *string) { cfg.FileSettings.AzurePathPrefix = value }, + }, + { + "ExportAzurePathPrefix", + func(cfg *Config, value *string) { cfg.FileSettings.ExportAzurePathPrefix = value }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{} + cfg.SetDefaults() + tc.configSetter(cfg, NewPointer("../escape")) + + err := cfg.FileSettings.isValid() + require.NotNil(t, err) + assert.Equal(t, "model.config.is_valid.directory_traversal.app_error", err.Id) + }) + } +} + func TestConfigDefaultSignatureAlgorithm(t *testing.T) { c1 := Config{} c1.SetDefaults() From d43dbe972ed35ee419d4eda0d66f259c5d5ff08f Mon Sep 17 00:00:00 2001 From: Julien Tant <785518+JulienTant@users.noreply.github.com> Date: Thu, 14 May 2026 12:37:14 -0700 Subject: [PATCH 06/80] Update Playbooks plugin to v2.9.0 (incl. FIPS) (#36570) --- server/Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/Makefile b/server/Makefile index 5730af0caa2..b6bbdecfe46 100644 --- a/server/Makefile +++ b/server/Makefile @@ -162,7 +162,7 @@ PLUGIN_PACKAGES += mattermost-plugin-calls-v1.11.4 PLUGIN_PACKAGES += mattermost-plugin-github-v2.7.1 PLUGIN_PACKAGES += mattermost-plugin-gitlab-v1.12.2 PLUGIN_PACKAGES += mattermost-plugin-jira-v4.7.0 -PLUGIN_PACKAGES += mattermost-plugin-playbooks-v2.8.1 +PLUGIN_PACKAGES += mattermost-plugin-playbooks-v2.9.0 PLUGIN_PACKAGES += mattermost-plugin-servicenow-v2.4.0 PLUGIN_PACKAGES += mattermost-plugin-zoom-v1.13.0 PLUGIN_PACKAGES += mattermost-plugin-agents-v2.0.3 @@ -178,7 +178,7 @@ PLUGIN_PACKAGES += mattermost-plugin-channel-export-v1.3.0 # download the package from to work. This will no longer be needed when we unify # the way we pre-package FIPS and non-FIPS plugins. ifeq ($(FIPS_ENABLED),true) - PLUGIN_PACKAGES = mattermost-plugin-playbooks-v2.8.1%2Bac0a223-fips + PLUGIN_PACKAGES = mattermost-plugin-playbooks-v2.9.0%2Bdfb5b30-fips PLUGIN_PACKAGES += mattermost-plugin-agents-v2.0.3%2Bcab391a-fips PLUGIN_PACKAGES += mattermost-plugin-boards-v9.2.4%2B5855fe1-fips endif From d4fc0ecb1c352280ecc07aa48386756cf234d6ad Mon Sep 17 00:00:00 2001 From: Jesse Hallam Date: Thu, 14 May 2026 18:29:37 -0300 Subject: [PATCH 07/80] MM-68150: Upgrade golangci-lint to v2.12.2 (#36554) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Simplify invite_people email parsing Replace backwards in-place mutation loop with a straightforward forward filter into a new slice. Extract into parseEmailList so the logic can be unit tested directly. * MM-68150: Upgrade golangci-lint to v2.12.2 Remove //go:fix inline from NewPointer, which is a generic function not yet supported by the inline analyzer, and fix 11 slicesbackward modernize issues flagged by the new version. * MM-68150: Enable all linters by default; disable those with >20 existing issues Switch from opt-in (default: none) to opt-out (default: all) so new linters added to golangci-lint are evaluated automatically. Explicitly disable every linter that has more than 20 pre-existing violations, deferring those for later cleanup. Also disable a handful of linters whose violations are intentional patterns in this codebase (nilerr, dogsled, sqlclosecheck, iotamixing, predeclared, containedctx, iface, gocheckcompilerdirectives, promlinter, goprintffuncname, gomoddirectives). * MM-68150: Fix mirror linter issues Replace Write([]byte(s)) with WriteString(s), and FindIndex([]byte(s)) with FindStringIndex(s), to avoid unnecessary allocations. * MM-68150: Fix nosprintfhostport linter issue Use net.JoinHostPort to construct host:port strings instead of fmt.Sprintf with a manually formatted pattern. * MM-68150: Fix rowserrcheck and sqlclosecheck linter issues Check rows.Err() after iteration loops in schema_dump.go. In the sqlx_wrapper test, defer rows.Close() rather than closing inline. * MM-68150: Fix nilnesserr linter issues — wrong variable in error handlers In 11 places, a stale variable (often the outer err from a prior assignment) was used instead of the freshly-checked error variable (appErr, rowErr, jsonErr, writeErr, esErr). Each produces a typed-nil wrapped in a non-nil interface, silently discarding the real error. * MM-68150: Add i18n string for app.compile_csv_chunks.write_error --------- Co-authored-by: Mattermost Build --- server/.golangci.yml | 94 ++++++++++++++++--- server/Makefile | 2 +- .../api4/outgoing_oauth_connection_test.go | 6 +- server/channels/app/channel.go | 2 +- server/channels/app/platform/web_hub.go | 2 +- server/channels/app/plugin_requests.go | 4 +- server/channels/app/reaction.go | 2 +- server/channels/app/report.go | 2 +- .../slashcommands/command_custom_status.go | 2 +- .../slashcommands/command_invite_people.go | 20 ++-- .../command_invite_people_test.go | 52 ++++++++++ server/channels/app/user.go | 2 +- server/channels/jobs/batch_report_worker.go | 2 +- .../export_users_to_csv.go | 2 +- server/channels/store/sqlstore/job_store.go | 2 +- server/channels/store/sqlstore/schema_dump.go | 9 ++ .../channels/store/sqlstore/session_store.go | 2 +- .../store/sqlstore/sqlx_wrapper_test.go | 5 +- server/channels/utils/license.go | 2 +- server/cmd/mmctl/commands/config_e2e_test.go | 4 +- server/cmd/mmctl/commands/config_test.go | 4 +- .../elasticsearch/opensearch/opensearch.go | 2 +- server/i18n/en.json | 4 + .../services/sharedchannel/sync_send.go | 4 +- server/platform/shared/mail/inbucket.go | 3 +- server/platform/shared/mail/mail_test.go | 4 +- server/public/model/builtin.go | 2 - server/public/model/config.go | 2 +- server/public/model/config_test.go | 6 +- server/public/model/utils_test.go | 2 +- server/public/shared/markdown/inspect.go | 34 +++---- server/public/shared/markdown/paragraph.go | 5 +- tools/mattermost-govet/Makefile | 2 +- 33 files changed, 215 insertions(+), 77 deletions(-) diff --git a/server/.golangci.yml b/server/.golangci.yml index d3f09acf71a..0476dcc59ce 100644 --- a/server/.golangci.yml +++ b/server/.golangci.yml @@ -1,20 +1,84 @@ version: "2" linters: - default: none - enable: - - bidichk - - errcheck - - govet - - ineffassign - - makezero - - misspell - - modernize - - revive - - staticcheck - - unconvert - - unqueryvet - - unused - - whitespace + default: all + disable: + - bodyclose + - canonicalheader + - containedctx # storing context.Context in a struct is an established pattern here + - contextcheck + - cyclop + - depguard + - dogsled # test helpers return many values; blank-heavy destructuring is idiomatic + - dupl + - dupword + - embeddedstructfieldcheck + - err113 + - errchkjson + - errname + - errorlint + - exhaustive + - exhaustruct + - forbidigo + - forcetypeassert + - funcorder + - funlen + - gocheckcompilerdirectives # //go:fix is a valid directive in Go 1.24+; linter doesn't know it yet + - gochecknoglobals + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godoclint + - godot + - godox + - gomoddirectives # replace directives in go.mod are intentional forks + - gomodguard # deprecated since v2.12.0; replaced by gomodguard_v2 (enabled via default: all) + - goprintffuncname # Ephemeral → Ephemeralf rename is a plugin API breaking change; deferred + - gosec + - gosmopolitan + - iface # identical job interfaces are intentional — type-safe scheduling without coupling + - inamedparam + - interfacebloat + - intrange + - iotamixing # const blocks intentionally mix iota with explicit values (ABI stability, ASCII) + - ireturn + - lll + - maintidx + - mnd + - musttag + - nakedret + - nestif + - nilerr # intentionally dropping errors is common here (graceful degradation, security non-disclosure, fallbacks) + - nilnil + - nlreturn + - noctx + - noinlineerr + - nolintlint + - nonamedreturns + - paralleltest + - perfsprint + - prealloc + - predeclared # variable named 'copy' is intentional; already suppressed for revive + - promlinter # metric renames are a breaking change; deferred + - protogetter + - recvcheck + - sqlclosecheck # wrapper functions return *sqlx.Rows to callers who close them; not a real leak + - tagalign + - tagliatelle + - testableexamples + - testifylint + - testpackage + - thelper + - tparallel + - unparam + - usestdlibvars + - usetesting + - varnamelen + - wastedassign + - wrapcheck + - wsl + - wsl_v5 settings: govet: disable: diff --git a/server/Makefile b/server/Makefile index b6bbdecfe46..81970e5e986 100644 --- a/server/Makefile +++ b/server/Makefile @@ -328,7 +328,7 @@ golang-versions: ## Install Golang versions used for compatibility testing (e.g. export GO_COMPATIBILITY_TEST_VERSIONS="${GO_COMPATIBILITY_TEST_VERSIONS}" golangci-lint: setup-go-work ## Run golangci-lint on codebase - $(GO) install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.4 + $(GO) install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2 ifeq ($(BUILD_ENTERPRISE_READY),true) $(GOBIN)/golangci-lint run ./... ./public/... $(BUILD_ENTERPRISE_DIR)/... else diff --git a/server/channels/api4/outgoing_oauth_connection_test.go b/server/channels/api4/outgoing_oauth_connection_test.go index f8df04c41f5..0dab7bf7fc1 100644 --- a/server/channels/api4/outgoing_oauth_connection_test.go +++ b/server/channels/api4/outgoing_oauth_connection_test.go @@ -942,7 +942,7 @@ func TestHandlerOutgoingOAuthConnectionUpdate(t *testing.T) { th.AddPermissionToRole(t, model.PermissionManageOutgoingOAuthConnections.Id, model.SystemUserRoleId) body := &bytes.Buffer{} - body.Write([]byte(`{/}`)) + body.WriteString(`{/}`) req, err := http.NewRequest("PUT", "/", body) if err != nil { @@ -990,7 +990,7 @@ func TestHandlerOutgoingOAuthConnectionUpdate(t *testing.T) { th.AddPermissionToRole(t, model.PermissionManageOutgoingOAuthConnections.Id, model.SystemUserRoleId) body := &bytes.Buffer{} - body.Write([]byte(`{"Id": "` + model.NewId() + `", "name": "changed name"}`)) + body.WriteString(`{"Id": "` + model.NewId() + `", "name": "changed name"}`) req, err := http.NewRequest("PUT", "/", body) if err != nil { @@ -1133,7 +1133,7 @@ func TestHandlerOutgoingOAuthConnectionHandlerCreate(t *testing.T) { th.AddPermissionToRole(t, model.PermissionManageOutgoingOAuthConnections.Id, model.SystemUserRoleId) body := &bytes.Buffer{} - body.Write([]byte(`{/}`)) + body.WriteString(`{/}`) req, err := http.NewRequest("POST", "/", body) if err != nil { diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index 3acd05496e7..535de6a5b78 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -4153,7 +4153,7 @@ func (a *App) setSidebarCategoriesForConvertedGroupMessage(rctx request.CTX, gmC channelsCategory := categories.Categories[0] _, appErr = a.UpdateSidebarCategories(rctx, user.Id, gmConversionRequest.TeamID, []*model.SidebarCategoryWithChannels{channelsCategory}) if appErr != nil { - rctx.Logger().Error("Failed to add converted GM to default sidebar category for user", mlog.String("user_id", user.Id), mlog.Err(err)) + rctx.Logger().Error("Failed to add converted GM to default sidebar category for user", mlog.String("user_id", user.Id), mlog.Err(appErr)) } } diff --git a/server/channels/app/platform/web_hub.go b/server/channels/app/platform/web_hub.go index 9ca6c998b8b..e5e1990c435 100644 --- a/server/channels/app/platform/web_hub.go +++ b/server/channels/app/platform/web_hub.go @@ -164,7 +164,7 @@ func (ps *PlatformService) GetHubForUserId(userID string) *Hub { // https://mattermost.atlassian.net/browse/MM-26629. var hash maphash.Hash hash.SetSeed(ps.hashSeed) - _, err := hash.Write([]byte(userID)) + _, err := hash.WriteString(userID) if err != nil { ps.logger.Error("Unable to write userID to hash", mlog.String("userID", userID), mlog.Err(err)) } diff --git a/server/channels/app/plugin_requests.go b/server/channels/app/plugin_requests.go index ef033f4888d..ff535237357 100644 --- a/server/channels/app/plugin_requests.go +++ b/server/channels/app/plugin_requests.go @@ -232,7 +232,7 @@ func (ch *Channels) servePluginRequest(w http.ResponseWriter, r *http.Request, h session, appErr := app.GetSession(token) if appErr != nil { if appErr.StatusCode == http.StatusInternalServerError { - handleInternalServerError(rctx, "Internal server error while loading session", err) + handleInternalServerError(rctx, "Internal server error while loading session", appErr) return } rctx.Logger().Debug("Token in plugin request is invalid. Treating request as unauthenticated", @@ -254,7 +254,7 @@ func (ch *Channels) servePluginRequest(w http.ResponseWriter, r *http.Request, h // If MFA is required and user has not activated it, treat it as unauthenticated if appErr := app.MFARequired(rctx); appErr != nil { if appErr.StatusCode == http.StatusInternalServerError { - handleInternalServerError(rctx, "Internal server error during MFA validation", err) + handleInternalServerError(rctx, "Internal server error during MFA validation", appErr) return } rctx.Logger().Warn("Treating session as unauthenticated since MFA required", diff --git a/server/channels/app/reaction.go b/server/channels/app/reaction.go index 7ece20157d8..126264e95c9 100644 --- a/server/channels/app/reaction.go +++ b/server/channels/app/reaction.go @@ -165,7 +165,7 @@ func (a *App) DeleteReactionForPost(rctx request.CTX, reaction *model.Reaction) restrictDM, appErr := a.CheckIfChannelIsRestrictedDM(rctx, channel) if appErr != nil { - return err + return appErr } if restrictDM { diff --git a/server/channels/app/report.go b/server/channels/app/report.go index 9866469b52f..89b5de0c306 100644 --- a/server/channels/app/report.go +++ b/server/channels/app/report.go @@ -75,7 +75,7 @@ func (a *App) compileCSVChunks(prefix string, numberOfChunks int, headers []stri } _, writeErr := compiledBuf.Write(chunk) if writeErr != nil { - return err + return model.NewAppError("compileCSVChunks", "app.compile_csv_chunks.write_error", nil, "", http.StatusInternalServerError).Wrap(writeErr) } } diff --git a/server/channels/app/slashcommands/command_custom_status.go b/server/channels/app/slashcommands/command_custom_status.go index 914484a5b00..2dd320fa16e 100644 --- a/server/channels/app/slashcommands/command_custom_status.go +++ b/server/channels/app/slashcommands/command_custom_status.go @@ -119,7 +119,7 @@ func GetCustomStatus(message string) *model.CustomStatus { func removeUnicodeSkinTone(unicodeString string) string { skinToneDetectorRegex := regexp.MustCompile("-(1f3fb|1f3fc|1f3fd|1f3fe|1f3ff)") - skinToneLocations := skinToneDetectorRegex.FindIndex([]byte(unicodeString)) + skinToneLocations := skinToneDetectorRegex.FindStringIndex(unicodeString) if len(skinToneLocations) == 0 { return unicodeString diff --git a/server/channels/app/slashcommands/command_invite_people.go b/server/channels/app/slashcommands/command_invite_people.go index 601f0cec10b..011947299b7 100644 --- a/server/channels/app/slashcommands/command_invite_people.go +++ b/server/channels/app/slashcommands/command_invite_people.go @@ -41,6 +41,17 @@ func (*InvitePeopleProvider) GetCommand(a *app.App, T i18n.TranslateFunc) *model } } +func parseEmailList(message string) []string { + var emails []string + for token := range strings.FieldsSeq(message) { + token = strings.Trim(token, ",") + if strings.Contains(token, "@") { + emails = append(emails, token) + } + } + return emails +} + func (*InvitePeopleProvider) DoCommand(a *app.App, rctx request.CTX, args *model.CommandArgs, message string) *model.CommandResponse { if !a.HasPermissionToTeam(rctx, args.UserId, args.TeamId, model.PermissionInviteUser) { return &model.CommandResponse{Text: args.T("api.command_invite_people.permission.app_error"), ResponseType: model.CommandResponseTypeEphemeral} @@ -62,14 +73,7 @@ func (*InvitePeopleProvider) DoCommand(a *app.App, rctx request.CTX, args *model return &model.CommandResponse{ResponseType: model.CommandResponseTypeEphemeral, Text: args.T("api.command.invite_people.email_invitations_off")} } - emailList := strings.Fields(message) - - for i := len(emailList) - 1; i >= 0; i-- { - emailList[i] = strings.Trim(emailList[i], ",") - if !strings.Contains(emailList[i], "@") { - emailList = append(emailList[:i], emailList[i+1:]...) - } - } + emailList := parseEmailList(message) if len(emailList) == 0 { return &model.CommandResponse{ResponseType: model.CommandResponseTypeEphemeral, Text: args.T("api.command.invite_people.no_email")} diff --git a/server/channels/app/slashcommands/command_invite_people_test.go b/server/channels/app/slashcommands/command_invite_people_test.go index 7bba88f465f..33edc3a7d69 100644 --- a/server/channels/app/slashcommands/command_invite_people_test.go +++ b/server/channels/app/slashcommands/command_invite_people_test.go @@ -7,10 +7,62 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" ) +func TestParseEmailList(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "single valid email", + input: "user@example.com", + expected: []string{"user@example.com"}, + }, + { + name: "multiple valid emails", + input: "a@example.com b@example.com", + expected: []string{"a@example.com", "b@example.com"}, + }, + { + name: "trailing commas stripped", + input: "a@example.com, b@example.com,", + expected: []string{"a@example.com", "b@example.com"}, + }, + { + name: "non-email tokens filtered out", + input: "notanemail a@example.com alsoinvalid", + expected: []string{"a@example.com"}, + }, + { + name: "comma immediately after email treated as one token", + input: "a@example.com,b@example.com", + expected: []string{"a@example.com,b@example.com"}, + }, + { + name: "empty input", + input: "", + expected: nil, + }, + { + name: "all tokens invalid", + input: "notanemail alsoinvalid", + expected: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := parseEmailList(tc.input) + require.Equal(t, tc.expected, result) + }) + } +} + func TestInvitePeopleProvider(t *testing.T) { th := setup(t).initBasic(t) diff --git a/server/channels/app/user.go b/server/channels/app/user.go index 360fff44ca6..932b039dd0c 100644 --- a/server/channels/app/user.go +++ b/server/channels/app/user.go @@ -1832,7 +1832,7 @@ func (a *App) CreatePasswordRecoveryToken(rctx request.CTX, userID, email string // remove any previously created tokens for user appErr := a.InvalidatePasswordRecoveryTokensForUser(userID) if appErr != nil { - rctx.Logger().Warn("Error while deleting additional user tokens.", mlog.Err(err)) + rctx.Logger().Warn("Error while deleting additional user tokens.", mlog.Err(appErr)) } token := model.NewToken(model.TokenTypePasswordRecovery, string(jsonData)) diff --git a/server/channels/jobs/batch_report_worker.go b/server/channels/jobs/batch_report_worker.go index b0553003467..c8686421313 100644 --- a/server/channels/jobs/batch_report_worker.go +++ b/server/channels/jobs/batch_report_worker.go @@ -106,7 +106,7 @@ func (worker *BatchReportWorker) processChunk(job *model.Job, reportData []model appErr := worker.app.SaveReportChunk(worker.reportFormat, job.Id, fileCount, reportData) if appErr != nil { - return err + return appErr } fileCount++ diff --git a/server/channels/jobs/export_users_to_csv/export_users_to_csv.go b/server/channels/jobs/export_users_to_csv/export_users_to_csv.go index 0c89a929a88..d9655a6cd2a 100644 --- a/server/channels/jobs/export_users_to_csv/export_users_to_csv.go +++ b/server/channels/jobs/export_users_to_csv/export_users_to_csv.go @@ -114,7 +114,7 @@ func getData(app ExportUsersToCSVAppIFace) func(jobData model.StringMap) ([]mode users, appErr := app.GetUsersForReporting(filter) if appErr != nil { - return nil, nil, false, errors.Wrapf(err, "failed to get the next batch (column_value=%v, user_id=%v)", filter.FromColumnValue, filter.FromId) + return nil, nil, false, errors.Wrapf(appErr, "failed to get the next batch (column_value=%v, user_id=%v)", filter.FromColumnValue, filter.FromId) } if len(users) == 0 { diff --git a/server/channels/store/sqlstore/job_store.go b/server/channels/store/sqlstore/job_store.go index cb39c58d13d..8e4471990e6 100644 --- a/server/channels/store/sqlstore/job_store.go +++ b/server/channels/store/sqlstore/job_store.go @@ -447,7 +447,7 @@ func (jss SqlJobStore) Cleanup(expiryTime int64, batchSize int) error { var rowErr error rowsAffected, rowErr = sqlResult.RowsAffected() if rowErr != nil { - return errors.Wrap(err, "unable to delete jobs") + return errors.Wrap(rowErr, "unable to delete jobs") } time.Sleep(jobsCleanupDelay) diff --git a/server/channels/store/sqlstore/schema_dump.go b/server/channels/store/sqlstore/schema_dump.go index 385abeecfa7..312584bc61f 100644 --- a/server/channels/store/sqlstore/schema_dump.go +++ b/server/channels/store/sqlstore/schema_dump.go @@ -177,6 +177,9 @@ func (ss *SqlStore) getTableOptions() (map[string]map[string]string, error) { // Add option to the table tableOptions[tableName][key] = value } + if err := optionsRows.Err(); err != nil { + rErr = multierror.Append(rErr, errors.Wrap(err, "error iterating table options rows")) + } return tableOptions, rErr.ErrorOrNil() } @@ -253,6 +256,9 @@ func (ss *SqlStore) getTableSchemaInformation() (map[string]*model.DatabaseTable }) } } + if err := rows.Err(); err != nil { + rErr = multierror.Append(rErr, errors.Wrap(err, "error iterating schema rows")) + } return tablesMap, tableCollations, rErr.ErrorOrNil() } @@ -298,6 +304,9 @@ func (ss *SqlStore) getTableIndexes() (map[string][]model.DatabaseIndex, error) tableIndexes[tableName] = append(tableIndexes[tableName], index) } + if err := rows.Err(); err != nil { + rErr = multierror.Append(rErr, errors.Wrap(err, "error iterating index rows")) + } return tableIndexes, rErr.ErrorOrNil() } diff --git a/server/channels/store/sqlstore/session_store.go b/server/channels/store/sqlstore/session_store.go index 98e0e280d39..141d5a2acfc 100644 --- a/server/channels/store/sqlstore/session_store.go +++ b/server/channels/store/sqlstore/session_store.go @@ -381,7 +381,7 @@ func (me SqlSessionStore) Cleanup(expiryTime int64, batchSize int64) error { var rowErr error rowsAffected, rowErr = sqlResult.RowsAffected() if rowErr != nil { - return errors.Wrap(err, "unable to delete sessions") + return errors.Wrap(rowErr, "unable to delete sessions") } time.Sleep(sessionsCleanupDelay) diff --git a/server/channels/store/sqlstore/sqlx_wrapper_test.go b/server/channels/store/sqlstore/sqlx_wrapper_test.go index 01449d833c1..cdfec44f029 100644 --- a/server/channels/store/sqlstore/sqlx_wrapper_test.go +++ b/server/channels/store/sqlstore/sqlx_wrapper_test.go @@ -46,7 +46,10 @@ func TestSqlX(t *testing.T) { query := `SELECT pg_sleep(:timeout);` arg := struct{ Timeout int }{Timeout: 2} - _, err = tx.NamedQuery(query, arg) + rows, err := tx.NamedQuery(query, arg) + if rows != nil { + defer rows.Close() + } require.Equal(t, context.DeadlineExceeded, err) require.NoError(t, tx.Commit()) } diff --git a/server/channels/utils/license.go b/server/channels/utils/license.go index d4f97410bad..74930e4895a 100644 --- a/server/channels/utils/license.go +++ b/server/channels/utils/license.go @@ -123,7 +123,7 @@ func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte, var license model.License if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { - return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err) + return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, jsonErr) } return &license, licenseBytes, nil diff --git a/server/cmd/mmctl/commands/config_e2e_test.go b/server/cmd/mmctl/commands/config_e2e_test.go index 4fcc9a1aa8f..e1ab559f995 100644 --- a/server/cmd/mmctl/commands/config_e2e_test.go +++ b/server/cmd/mmctl/commands/config_e2e_test.go @@ -51,7 +51,7 @@ func (s *MmctlE2ETestSuite) TestConfigPatchCmd() { invalidFile, err := os.CreateTemp(os.TempDir(), "invalid_config_*.json") s.Require().Nil(err) - _, err = tmpFile.Write([]byte(configFilePayload)) + _, err = tmpFile.WriteString(configFilePayload) s.Require().Nil(err) defer func() { @@ -212,7 +212,7 @@ rm $1'old'` defer func() { os.Remove(file.Name()) }() - _, err = file.Write([]byte(content)) + _, err = file.WriteString(content) s.Require().Nil(err) s.Require().Nil(file.Close()) s.Require().Nil(os.Chmod(file.Name(), 0700)) diff --git a/server/cmd/mmctl/commands/config_test.go b/server/cmd/mmctl/commands/config_test.go index f20a7cb44a2..b42dba1340b 100644 --- a/server/cmd/mmctl/commands/config_test.go +++ b/server/cmd/mmctl/commands/config_test.go @@ -600,10 +600,10 @@ func (s *MmctlUnitTestSuite) TestConfigPatchCmd() { pluginFile, err := os.CreateTemp(os.TempDir(), "plugin_config_*.json") s.Require().NoError(err) - _, err = tmpFile.Write([]byte(configFilePayload)) + _, err = tmpFile.WriteString(configFilePayload) s.Require().NoError(err) - _, err = pluginFile.Write([]byte(configFilePluginPayload)) + _, err = pluginFile.WriteString(configFilePluginPayload) s.Require().NoError(err) defer func() { diff --git a/server/enterprise/elasticsearch/opensearch/opensearch.go b/server/enterprise/elasticsearch/opensearch/opensearch.go index 30812f572e9..8932244657e 100644 --- a/server/enterprise/elasticsearch/opensearch/opensearch.go +++ b/server/enterprise/elasticsearch/opensearch/opensearch.go @@ -2308,7 +2308,7 @@ func checkMaxVersion(ctx context.Context, client *opensearchapi.Client) (string, major, _, _, esErr := common.GetVersionComponents(resp.Version.Number) if esErr != nil { - return "", 0, model.NewAppError("Opensearch.checkMaxVersion", "ent.elasticsearch.start.parse_server_version.app_error", map[string]any{"Backend": model.ElasticsearchSettingsOSBackend}, "", http.StatusInternalServerError).Wrap(err) + return "", 0, model.NewAppError("Opensearch.checkMaxVersion", "ent.elasticsearch.start.parse_server_version.app_error", map[string]any{"Backend": model.ElasticsearchSettingsOSBackend}, "", http.StatusInternalServerError).Wrap(esErr) } if major > opensearchMaxVersion { diff --git a/server/i18n/en.json b/server/i18n/en.json index 09f24cf195a..15a5aa99763 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -5846,6 +5846,10 @@ "id": "app.compile_csv_chunks.header_error", "translation": "Failed to write CSV headers." }, + { + "id": "app.compile_csv_chunks.write_error", + "translation": "Failed to write CSV data." + }, { "id": "app.compile_report_chunks.unsupported_format", "translation": "Unsupported report format." diff --git a/server/platform/services/sharedchannel/sync_send.go b/server/platform/services/sharedchannel/sync_send.go index fcb77c70ac6..c0f900146fb 100644 --- a/server/platform/services/sharedchannel/sync_send.go +++ b/server/platform/services/sharedchannel/sync_send.go @@ -6,6 +6,7 @@ package sharedchannel import ( "context" "fmt" + "slices" "time" "github.com/mattermost/mattermost/server/public/model" @@ -529,8 +530,7 @@ func (scs *Service) notifyRemoteOffline(posts []*model.Post, rc *model.RemoteClu // range the slice in reverse so the newest posts are visited first; this ensures an ephemeral // get added where it is mostly likely to be seen. - for i := len(posts) - 1; i >= 0; i-- { - post := posts[i] + for _, post := range slices.Backward(posts) { if didNotify := notified[post.UserId]; didNotify { continue } diff --git a/server/platform/shared/mail/inbucket.go b/server/platform/shared/mail/inbucket.go index 851d1dd9a62..ccbdc87e68d 100644 --- a/server/platform/shared/mail/inbucket.go +++ b/server/platform/shared/mail/inbucket.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "os" "strings" @@ -178,5 +179,5 @@ func getInbucketHost() (host string) { if inbucket_port == "" { inbucket_port = "9001" } - return fmt.Sprintf("http://%s:%s", inbucket_host, inbucket_port) + return "http://" + net.JoinHostPort(inbucket_host, inbucket_port) } diff --git a/server/platform/shared/mail/mail_test.go b/server/platform/shared/mail/mail_test.go index 514a56d8a54..33555d4ee5a 100644 --- a/server/platform/shared/mail/mail_test.go +++ b/server/platform/shared/mail/mail_test.go @@ -254,13 +254,13 @@ func TestSendMailUsingConfigAdvanced(t *testing.T) { file1, err := os.CreateTemp("", "*") require.NoError(t, err) defer os.Remove(file1.Name()) - file1.Write([]byte("hello world")) + file1.WriteString("hello world") file1.Close() file2, err := os.CreateTemp("", "*") require.NoError(t, err) defer os.Remove(file2.Name()) - file2.Write([]byte("foo bar")) + file2.WriteString("foo bar") file2.Close() embeddedFiles := map[string]io.Reader{ diff --git a/server/public/model/builtin.go b/server/public/model/builtin.go index 5d52c72e180..1a0d3d55209 100644 --- a/server/public/model/builtin.go +++ b/server/public/model/builtin.go @@ -4,8 +4,6 @@ package model // NewPointer returns a pointer to the object passed. -// -//go:fix inline func NewPointer[T any](t T) *T { return new(t) } // SafeDereference returns the zero value of T if t is nil. diff --git a/server/public/model/config.go b/server/public/model/config.go index 543d751875d..e04fe346529 100644 --- a/server/public/model/config.go +++ b/server/public/model/config.go @@ -5397,7 +5397,7 @@ func structToMapFilteredByTag(t any, typeOfTag, filterTag string) map[string]any switch field.Kind() { case reflect.Struct: value = structToMapFilteredByTag(field.Interface(), typeOfTag, filterTag) - case reflect.Ptr: + case reflect.Pointer: indirectType := field.Elem() if indirectType.Kind() == reflect.Struct { value = structToMapFilteredByTag(indirectType.Interface(), typeOfTag, filterTag) diff --git a/server/public/model/config_test.go b/server/public/model/config_test.go index 19676351aca..41cf045f181 100644 --- a/server/public/model/config_test.go +++ b/server/public/model/config_test.go @@ -32,7 +32,7 @@ func TestConfigDefaults(t *testing.T) { t.Run("nowhere nil when partially initialized", func(t *testing.T) { var recursivelyUninitialize func(*Config, string, reflect.Value) recursivelyUninitialize = func(config *Config, name string, v reflect.Value) { - if v.Type().Kind() == reflect.Ptr { + if v.Type().Kind() == reflect.Pointer { // Ignoring these 2 settings. // TODO: remove them completely in v8.0. if name == "config.ElasticsearchSettings.BulkIndexingTimeWindowSeconds" || @@ -2937,10 +2937,10 @@ func TestConfigAccessTagsMapToValidPermissions(t *testing.T) { fieldPath := path + "." + field.Name elemType := field.Type - if elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Slice { + if elemType.Kind() == reflect.Pointer || elemType.Kind() == reflect.Slice { elemType = elemType.Elem() } - if elemType.Kind() == reflect.Ptr { + if elemType.Kind() == reflect.Pointer { elemType = elemType.Elem() } if elemType.Kind() == reflect.Struct { diff --git a/server/public/model/utils_test.go b/server/public/model/utils_test.go index 67a39f1ccb4..4b04b26f92c 100644 --- a/server/public/model/utils_test.go +++ b/server/public/model/utils_test.go @@ -1087,7 +1087,7 @@ func checkNowhereNil(t *testing.T, name string, value any) bool { v := reflect.ValueOf(value) switch v.Type().Kind() { - case reflect.Ptr: + case reflect.Pointer: // Ignoring these 2 settings. // TODO: remove them completely in v8.0. if name == "config.ElasticsearchSettings.BulkIndexingTimeWindowSeconds" || diff --git a/server/public/shared/markdown/inspect.go b/server/public/shared/markdown/inspect.go index 151b9590244..b3eb2b6d5ee 100644 --- a/server/public/shared/markdown/inspect.go +++ b/server/public/shared/markdown/inspect.go @@ -3,6 +3,8 @@ package markdown +import "slices" + const ( // Assuming 64k maxSize of a post which can be stored in DB. // Allow scanning upto twice(arbitrary value) the post size. @@ -58,20 +60,20 @@ func InspectBlock(block Block, f func(Block) bool) { switch v := block.(type) { case *Document: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } case *List: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } case *ListItem: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } case *BlockQuote: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } } } @@ -103,20 +105,20 @@ func InspectInline(inline Inline, f func(Inline) bool) { switch v := inline.(type) { case *InlineImage: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } case *InlineLink: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } case *ReferenceImage: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } case *ReferenceLink: - for i := len(v.Children) - 1; i >= 0; i-- { - stack = append(stack, v.Children[i]) + for _, v0 := range slices.Backward(v.Children) { + stack = append(stack, v0) } } } diff --git a/server/public/shared/markdown/paragraph.go b/server/public/shared/markdown/paragraph.go index aef01b5e151..5a2012d56de 100644 --- a/server/public/shared/markdown/paragraph.go +++ b/server/public/shared/markdown/paragraph.go @@ -4,6 +4,7 @@ package markdown import ( + "slices" "strings" ) @@ -51,8 +52,8 @@ func (b *Paragraph) Close() { b.Text = remaining } - for i := len(b.Text) - 1; i >= 0; i-- { - b.Text[i] = trimRightSpace(b.markdown, b.Text[i]) + for i, v := range slices.Backward(b.Text) { + b.Text[i] = trimRightSpace(b.markdown, v) if b.Text[i].Position < b.Text[i].End { break } diff --git a/tools/mattermost-govet/Makefile b/tools/mattermost-govet/Makefile index 05254a180c6..c05d7fd23a9 100644 --- a/tools/mattermost-govet/Makefile +++ b/tools/mattermost-govet/Makefile @@ -12,7 +12,7 @@ clean: rm -rf dist golangci-lint: - $(GO) install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.4 + $(GO) install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2 $(GOBIN)/golangci-lint run ./... check-style: golangci-lint From 51fd952ae6fe02555b7ed1b7b05d9e97f1e90c09 Mon Sep 17 00:00:00 2001 From: Vishal Date: Fri, 15 May 2026 16:33:54 +0530 Subject: [PATCH 08/80] MM-67771: Update Report a Problem to email flow (#35900) * MM-67771 Update Report a Problem to email flow for licensed servers Change the default "Report a Problem" behavior for licensed servers to open a mailto link to reportaproblem@mattermost.com with pre-filled metadata instead of redirecting to the support portal. Unlicensed servers continue to redirect to the troubleshooting forums. Admin console help text is now license-aware with separate descriptions for each plan type. * Add isFreeEdition check for Report a Problem flow Treat both unlicensed servers and licensed servers with entry SKU as free edition. This affects the Report a Problem default behavior (forum redirect vs mailto) and the admin console help text shown. - Add isFreeEdition to general.ts selectors and admin_definition_helpers - Add SKUEntry constant to general constants - Reuse isFreeEdition in product_menu.tsx - Add entry SKU test case for report_a_problem * Add the link to forums for free edition * Add permission and restricted-mode guards to ReportAProblemType dropdown The ReportAProblemType dropdown was missing the write-permission check and RestrictSystemAdmin guard that all other fields in the section have. --- .../link_customization_cloud_spec.js | 2 +- .../customization_not_cloud_spec.js | 10 +- e2e-tests/cypress/tests/utils/constants.js | 2 - .../admin_console/admin_definition.tsx | 80 +++++------ .../admin_definition_helpers.tsx | 1 + .../product_menu/product_menu.tsx | 7 +- webapp/channels/src/i18n/en.json | 13 +- .../mattermost-redux/src/constants/general.ts | 1 + .../src/selectors/entities/general.ts | 5 + .../entities/report_a_problem.test.ts | 128 +++++++++++++++++- .../selectors/entities/report_a_problem.ts | 49 ++++++- .../src/utils/browser_info.ts | 11 ++ 12 files changed, 237 insertions(+), 72 deletions(-) diff --git a/e2e-tests/cypress/tests/integration/channels/system_console/site_configuration/link_customization_cloud_spec.js b/e2e-tests/cypress/tests/integration/channels/system_console/site_configuration/link_customization_cloud_spec.js index 47012ca7261..88732976cd2 100644 --- a/e2e-tests/cypress/tests/integration/channels/system_console/site_configuration/link_customization_cloud_spec.js +++ b/e2e-tests/cypress/tests/integration/channels/system_console/site_configuration/link_customization_cloud_spec.js @@ -28,7 +28,7 @@ describe('SupportSettings', () => { [ {text: 'Ask the community', link: SupportSettings.ASK_COMMUNITY_LINK}, {text: 'Mattermost user guide', link: SupportSettings.MATTERMOST_USER_GUIDE}, - {text: 'Report a problem', link: SupportSettings.REPORT_A_PROBLEM_LINK}, + {text: 'Report a problem', link: 'mailto:reportaproblem@mattermost.com'}, {text: 'Keyboard shortcuts'}, ].forEach(({text, link}) => { if (link) { diff --git a/e2e-tests/cypress/tests/integration/channels/system_console/ui_and_api/customization_not_cloud_spec.js b/e2e-tests/cypress/tests/integration/channels/system_console/ui_and_api/customization_not_cloud_spec.js index 60e0d1ad6b2..07cad0354b4 100644 --- a/e2e-tests/cypress/tests/integration/channels/system_console/ui_and_api/customization_not_cloud_spec.js +++ b/e2e-tests/cypress/tests/integration/channels/system_console/ui_and_api/customization_not_cloud_spec.js @@ -28,14 +28,14 @@ describe('Customization', () => { }); it('MM-T1214 - Can change Report a Problem Link setting', () => { - // * Verify Report a Problem link label is visible and matches the text - cy.findByTestId('SupportSettings.ReportAProblemLinklabel').scrollIntoView().should('be.visible').and('have.text', 'Report a Problem Link:'); + // # Select 'Custom link' from the Report a Problem dropdown + cy.findByTestId('SupportSettings.ReportAProblemTypedropdown').scrollIntoView().select('Custom link'); - // * Verify Report a Problem link input box has default value. The default value depends on the setup before running the test. - cy.findByTestId('SupportSettings.ReportAProblemLinkinput').should('have.value', origConfig.SupportSettings.ReportAProblemLink); + // * Verify Report a Problem link label is visible and matches the text + cy.findByTestId('SupportSettings.ReportAProblemLinklabel').scrollIntoView().should('be.visible').and('have.text', 'Custom Report a Problem Link:'); // * Verify Report a Problem link help text is visible and matches the text - cy.findByTestId('SupportSettings.ReportAProblemLinkhelp-text').find('span').should('be.visible').and('have.text', 'The URL for the Report a Problem link in the Help Menu. If this field is empty, the link is removed from the Help Menu.'); + cy.findByTestId('SupportSettings.ReportAProblemLinkhelp-text').find('span').should('be.visible').and('have.text', 'Enter the URL that users will be directed to when they choose "Report a Problem".'); // # Enter a problem link const reportAProblemLink = 'https://mattermost.com/pl/report-a-bug'; diff --git a/e2e-tests/cypress/tests/utils/constants.js b/e2e-tests/cypress/tests/utils/constants.js index 321e798b74b..19518fec585 100644 --- a/e2e-tests/cypress/tests/utils/constants.js +++ b/e2e-tests/cypress/tests/utils/constants.js @@ -6,7 +6,6 @@ export const ABOUT_LINK = 'https://mattermost.com/pl/about-mattermost'; export const ASK_COMMUNITY_LINK = 'https://mattermost.com/pl/default-ask-mattermost-community/'; export const HELP_LINK = 'https://mattermost.com/pl/help/'; export const PRIVACY_POLICY_LINK = 'https://mattermost.com/pl/privacy-policy/'; -export const REPORT_A_PROBLEM_LINK = 'https://mattermost.com/pl/report-a-bug'; export const TERMS_OF_SERVICE_LINK = 'https://mattermost.com/pl/terms-of-use/'; export const MATTERMOST_USER_GUIDE = 'https://docs.mattermost.com/guides/use-mattermost.html'; @@ -24,7 +23,6 @@ export const SupportSettings = { ASK_COMMUNITY_LINK, HELP_LINK, PRIVACY_POLICY_LINK, - REPORT_A_PROBLEM_LINK, TERMS_OF_SERVICE_LINK, MATTERMOST_USER_GUIDE, }; diff --git a/webapp/channels/src/components/admin_console/admin_definition.tsx b/webapp/channels/src/components/admin_console/admin_definition.tsx index 29b930deec1..a62ea3cd26c 100644 --- a/webapp/channels/src/components/admin_console/admin_definition.tsx +++ b/webapp/channels/src/components/admin_console/admin_definition.tsx @@ -219,6 +219,25 @@ const SAML_SETTINGS_CANONICAL_ALGORITHM_C14N11 = 'Canonical1.1'; // - remove_action: An store action to remove the file. // - fileType: A list of extensions separated by ",". E.g. ".jpg,.png,.gif". +const reportAProblemTypeOptions = [ + { + display_name: defineMessage({id: 'admin.support.problemType.defaultLink', defaultMessage: 'Default'}), + value: 'default', + }, + { + display_name: defineMessage({id: 'admin.support.problemType.email', defaultMessage: 'Email address'}), + value: 'email', + }, + { + display_name: defineMessage({id: 'admin.support.problemType.customLink', defaultMessage: 'Custom link'}), + value: 'link', + }, + { + display_name: defineMessage({id: 'admin.support.problemType.hide', defaultMessage: 'Hide link'}), + value: 'hidden', + }, +]; + const adminDefinitionMessages = defineMessages({ data_retention_title: {id: 'admin.data_retention.title', defaultMessage: 'Data Retention Policy'}, ip_filtering_title: {id: 'admin.sidebar.ip_filtering', defaultMessage: 'IP Filtering'}, @@ -2492,57 +2511,32 @@ const AdminDefinition: AdminDefinitionType = { type: 'dropdown', key: 'SupportSettings.ReportAProblemType', label: defineMessage({id: 'admin.support.reportAProblemTypeTitle', defaultMessage: 'Report a Problem:'}), - help_text: defineMessage({id: 'admin.support.reportAProblemTypeDescription', defaultMessage: 'Select how the ‘Report a Problem’ option behaves. Choosing ‘Custom link’ or ‘Email address’ allows you to provide a URL or address in the next field. ‘Hide link’ removes the ‘Report a Problem’ option from the app.'}), - options: [ - { - display_name: defineMessage({id: 'admin.support.problemType.defaultLink', defaultMessage: 'Default link'}), - value: 'default', - }, - { - display_name: defineMessage({id: 'admin.support.problemType.email', defaultMessage: 'Email address'}), - value: 'email', - }, - { - display_name: defineMessage({id: 'admin.support.problemType.customLink', defaultMessage: 'Custom link'}), - value: 'link', - }, - { - display_name: defineMessage({id: 'admin.support.problemType.hide', defaultMessage: 'Hide link'}), - value: 'hidden', - }, - ], + help_text: defineMessage({id: 'admin.support.reportAProblemTypeDescriptionLicensed', defaultMessage: 'By default, selecting "Report a Problem" from the help menu opens a pre-filled email draft to the Mattermost technical support team. You may provide a custom URL or email address for end user support by choosing "Custom link" or "Email address". "Hide link" removes the "Report a Problem" option from the app.'}), + isDisabled: it.not(it.userHasWritePermissionOnResource(RESOURCE_KEYS.SITE.CUSTOMIZATION)), + isHidden: it.any( + it.isFreeEdition, + it.configIsTrue('ExperimentalSettings', 'RestrictSystemAdmin'), + ), + options: reportAProblemTypeOptions, }, { - type: 'text', - key: 'defaultLicensedReportAProblemLink', - label: defineMessage({id: 'admin.support.reportAProblemDefaultLinkTitle', defaultMessage: 'Default Report a Problem Link:'}), - help_text: defineMessage({id: 'admin.support.reportAProblemDefaultLinkDescription', defaultMessage: 'Users will be directed to this link when they choose ‘Report a Problem’.'}), - default: 'https://mattermost.com/pl/report_a_problem_licensed', - isDisabled: it.all(), + type: 'dropdown', + key: 'SupportSettings.ReportAProblemType', + label: defineMessage({id: 'admin.support.reportAProblemTypeTitle', defaultMessage: 'Report a Problem:'}), + help_text: defineMessage({id: 'admin.support.reportAProblemTypeDescriptionUnlicensed', defaultMessage: 'By default, selecting "Report a Problem" from the help menu opens the [Mattermost troubleshooting forums](https://mattermost.com/pl/report_a_problem_unlicensed). You may provide a custom URL or email address for end user support by choosing "Custom link" or "Email address". "Hide link" removes the "Report a Problem" option from the app.'}), + help_text_markdown: true, + isDisabled: it.not(it.userHasWritePermissionOnResource(RESOURCE_KEYS.SITE.CUSTOMIZATION)), isHidden: it.any( + it.not(it.isFreeEdition), it.configIsTrue('ExperimentalSettings', 'RestrictSystemAdmin'), - it.not(it.stateMatches('SupportSettings.ReportAProblemType', /default/)), - it.not(it.licensed), - ), - }, - { - type: 'text', - key: 'defaultUnlicensedReportAProblemLink', - label: defineMessage({id: 'admin.support.reportAProblemDefaultLinkTitle', defaultMessage: 'Default Report a Problem Link:'}), - help_text: defineMessage({id: 'admin.support.reportAProblemDefaultLinkDescription', defaultMessage: 'Users will be directed to this link when they choose ‘Report a Problem’.'}), - default: 'https://mattermost.com/pl/report_a_problem_unlicensed', - isDisabled: it.all(), - isHidden: it.any( - it.configIsTrue('ExperimentalSettings', 'RestrictSystemAdmin'), - it.not(it.stateMatches('SupportSettings.ReportAProblemType', /default/)), - it.licensed, ), + options: reportAProblemTypeOptions, }, { type: 'text', key: 'SupportSettings.ReportAProblemLink', label: defineMessage({id: 'admin.support.reportAProblemLinkTitle', defaultMessage: 'Custom Report a Problem Link:'}), - help_text: defineMessage({id: 'admin.support.reportAProblemLinkDescription', defaultMessage: 'Enter the URL that users will be directed to when they choose ‘Report a Problem’.'}), + help_text: defineMessage({id: 'admin.support.reportAProblemLinkDescription', defaultMessage: 'Enter the URL that users will be directed to when they choose "Report a Problem".'}), isDisabled: it.any( it.not(it.userHasWritePermissionOnResource(RESOURCE_KEYS.SITE.CUSTOMIZATION)), ), @@ -2561,7 +2555,7 @@ const AdminDefinition: AdminDefinitionType = { type: 'text', key: 'SupportSettings.ReportAProblemMail', label: defineMessage({id: 'admin.support.reportAProblemEmailTitle', defaultMessage: 'Report a Problem Email Address:'}), - help_text: defineMessage({id: 'admin.support.reportAProblemEmailDescription', defaultMessage: 'Enter the email address that users will be prompted to send a message to when they choose ‘Report a Problem’.'}), + help_text: defineMessage({id: 'admin.support.reportAProblemEmailDescription', defaultMessage: 'Enter the email address that users will be prompted to send a message to when they choose "Report a Problem".'}), isDisabled: (it.not(it.userHasWritePermissionOnResource(RESOURCE_KEYS.SITE.CUSTOMIZATION))), isHidden: it.any( it.configIsTrue('ExperimentalSettings', 'RestrictSystemAdmin'), @@ -2578,7 +2572,7 @@ const AdminDefinition: AdminDefinitionType = { type: 'bool', key: 'SupportSettings.AllowDownloadLogs', label: defineMessage({id: 'admin.support.problemAllowDownloadTitle', defaultMessage: 'Allow Mobile App Log Downloads:'}), - help_text: defineMessage({id: 'admin.support.problemAllowDownloadDescription', defaultMessage: 'When enabled, users can download app logs for troubleshooting. If a ‘Report a Problem’ link is shown, logs can be downloaded as part of that flow; if the ‘Report a Problem’ link is hidden, logs remain accessible as a separate option.'}), + help_text: defineMessage({id: 'admin.support.problemAllowDownloadDescription', defaultMessage: 'When enabled, users can download app logs for troubleshooting. If a "Report a Problem" link is shown, logs can be downloaded as part of that flow; if the "Report a Problem" link is hidden, logs remain accessible as a separate option.'}), isDisabled: it.not(it.userHasWritePermissionOnResource(RESOURCE_KEYS.SITE.CUSTOMIZATION)), }, { diff --git a/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx b/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx index 9617bcbc180..e4c15053bd5 100644 --- a/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx +++ b/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx @@ -48,6 +48,7 @@ export const it = { configContains: (group: keyof Partial, setting: string, word: string) => (config: Partial) => Boolean((config[group] as any)?.[setting]?.includes(word)), enterpriseReady: (config: Partial, state: any, license?: ClientLicense, enterpriseReady?: boolean) => Boolean(enterpriseReady), licensed: (config: Partial, state: any, license?: ClientLicense) => license?.IsLicensed === 'true', + isFreeEdition: (config: Partial, state: any, license?: ClientLicense) => license?.IsLicensed !== 'true' || license?.SkuShortName === LicenseSkus.Entry, cloudLicensed: (config: Partial, state: any, license?: ClientLicense) => Boolean(license?.IsLicensed && isCloudLicense(license)), licensedForFeature: (feature: string) => (config: Partial, state: any, license?: ClientLicense) => Boolean(license?.IsLicensed && license[feature] === 'true'), licensedForSku: (skuName: string) => (config: Partial, state: any, license?: ClientLicense) => Boolean(license?.IsLicensed && license.SkuShortName === skuName), diff --git a/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu.tsx b/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu.tsx index ddda1781a4b..a01733fee7c 100644 --- a/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu.tsx +++ b/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu.tsx @@ -10,7 +10,7 @@ import { ProductsIcon, } from '@mattermost/compass-icons/components'; -import {getLicense} from 'mattermost-redux/selectors/entities/general'; +import {isFreeEdition as isFreeEditionSelector} from 'mattermost-redux/selectors/entities/general'; import {setProductMenuSwitcherOpen} from 'actions/views/product_menu'; import {isSwitcherOpen} from 'selectors/views/product_menu'; @@ -24,7 +24,6 @@ import { import Menu from 'components/widgets/menu/menu'; import MenuWrapper from 'components/widgets/menu/menu_wrapper'; -import {LicenseSkus} from 'utils/constants'; import {useCurrentProductId, useProducts, isChannels} from 'utils/products'; import ProductBranding from './product_branding'; @@ -77,7 +76,7 @@ const ProductMenu = (): JSX.Element => { const switcherOpen = useSelector(isSwitcherOpen); const menuRef = useRef(null); const currentProductID = useCurrentProductId(); - const license = useSelector(getLicense); + const isFreeEdition = useSelector(isFreeEditionSelector); const handleClick = () => dispatch(setProductMenuSwitcherOpen(!switcherOpen)); @@ -114,8 +113,6 @@ const ProductMenu = (): JSX.Element => { ); }); - const isFreeEdition = license.IsLicensed === 'false' || license.SkuShortName === LicenseSkus.Entry; - return (
@@ -1623,7 +1624,7 @@ export class UserSettingsGeneralTab extends PureComponent { max = ( { key={'settingItem_' + attribute.id} active={active} areAllSectionsInactive={this.props.activeSection === ''} - title={attribute.name} + title={getUserPropertyFieldLabel(attribute)} describe={describe} section={sectionName} updateSection={this.updateSection} diff --git a/webapp/channels/src/i18n/en.json b/webapp/channels/src/i18n/en.json index 12ab37a113a..12d161efd30 100644 --- a/webapp/channels/src/i18n/en.json +++ b/webapp/channels/src/i18n/en.json @@ -3129,7 +3129,6 @@ "admin.system_properties.user_properties.dotmenu.ad_ldap.modal.helpText": "The attribute in the AD/LDAP server used to sync as a custom attribute in user's profile in Mattermost.", "admin.system_properties.user_properties.dotmenu.delete.label": "Delete attribute", "admin.system_properties.user_properties.dotmenu.duplicate.label": "Duplicate attribute", - "admin.system_properties.user_properties.dotmenu.duplicate.name_copy": "{fieldName} (copy)", "admin.system_properties.user_properties.dotmenu.editable_by_users.label": "Editable by users", "admin.system_properties.user_properties.dotmenu.saml.edit_link.label": "Edit SAML link", "admin.system_properties.user_properties.dotmenu.saml.link_property.label": "Link attribute to SAML", @@ -3141,8 +3140,11 @@ "admin.system_properties.user_properties.dotmenu.visibility.when_set.label": "Hide when empty", "admin.system_properties.user_properties.subtitle": "Attributes will be shown in user profile and can be used in access control policies.", "admin.system_properties.user_properties.table.actions": "Actions", + "admin.system_properties.user_properties.table.display_name_header": "Display Name", + "admin.system_properties.user_properties.table.display_name.input.label": "Display Name", "admin.system_properties.user_properties.table.filter_type": "Attribute type", - "admin.system_properties.user_properties.table.property": "Attribute", + "admin.system_properties.user_properties.table.identifier.tooltip": "Common Expression Language (CEL) identifier used in policies. Only letters, digits, and underscores allowed. Must start with a letter or underscore. Reserved CEL words are not allowed.", + "admin.system_properties.user_properties.table.name": "Name", "admin.system_properties.user_properties.table.property_name.input.name": "Attribute Name", "admin.system_properties.user_properties.table.select_type.email": "Email", "admin.system_properties.user_properties.table.select_type.multi_select": "Multi-select", @@ -3151,6 +3153,7 @@ "admin.system_properties.user_properties.table.select_type.text": "Text", "admin.system_properties.user_properties.table.select_type.url": "URL", "admin.system_properties.user_properties.table.type": "Type", + "admin.system_properties.user_properties.table.validation.name_invalid_cel": "Identifier must start with a letter or underscore and contain only letters, numbers, and underscores. Reserved CEL words are not allowed.", "admin.system_properties.user_properties.table.validation.name_required": "Please enter an attribute name.", "admin.system_properties.user_properties.table.validation.name_taken": "Attribute name already taken.", "admin.system_properties.user_properties.table.validation.name_unique": "Attribute names must be unique.", diff --git a/webapp/channels/src/utils/properties.test.ts b/webapp/channels/src/utils/properties.test.ts new file mode 100644 index 00000000000..32786f5ff89 --- /dev/null +++ b/webapp/channels/src/utils/properties.test.ts @@ -0,0 +1,276 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import { + CPA_FIELD_NAME_PATTERN, + CPA_FIELD_NAME_RESERVED_WORDS, + filterCELIdentifier, + getUserPropertyFieldLabel, + slugifyForCEL, + validateCPAFieldName, +} from './properties'; + +describe('utils/properties', () => { + describe('getUserPropertyFieldLabel', () => { + const base = {name: 'dept_head'}; + + test('returns display_name when set and non-empty', () => { + expect(getUserPropertyFieldLabel({ + ...base, + attrs: {sort_order: 0, visibility: 'always', value_type: '', display_name: 'Department Head'}, + })).toBe('Department Head'); + }); + + test('returns display_name trimmed when surrounded by whitespace', () => { + expect(getUserPropertyFieldLabel({ + ...base, + attrs: {sort_order: 0, visibility: 'always', value_type: '', display_name: ' Department Head '}, + })).toBe('Department Head'); + }); + + test('falls back to name when display_name is undefined', () => { + expect(getUserPropertyFieldLabel({ + ...base, + attrs: {sort_order: 0, visibility: 'always', value_type: ''}, + })).toBe('dept_head'); + }); + + test('falls back to name when display_name is the empty string', () => { + expect(getUserPropertyFieldLabel({ + ...base, + attrs: {sort_order: 0, visibility: 'always', value_type: '', display_name: ''}, + })).toBe('dept_head'); + }); + + test('falls back to name when display_name is whitespace-only', () => { + expect(getUserPropertyFieldLabel({ + ...base, + attrs: {sort_order: 0, visibility: 'always', value_type: '', display_name: ' '}, + })).toBe('dept_head'); + }); + + test('falls back to name when attrs is undefined (defensive)', () => { + // Type cast needed since UserPropertyField.attrs is normally required; + // this covers runtime API responses that may omit the field. + expect(getUserPropertyFieldLabel({ + name: 'dept_head', + attrs: undefined as any, + })).toBe('dept_head'); + }); + + test('returns non-ASCII display_name verbatim', () => { + expect(getUserPropertyFieldLabel({ + name: 'employee_number', + attrs: {sort_order: 0, visibility: 'always', value_type: '', display_name: '员工编号'}, + })).toBe('员工编号'); + expect(getUserPropertyFieldLabel({ + name: 'preferences', + attrs: {sort_order: 0, visibility: 'always', value_type: '', display_name: 'Préférences'}, + })).toBe('Préférences'); + }); + }); +}); + +describe('CPA field name constants - cross-stack drift guard', () => { + it('CPA_FIELD_NAME_PATTERN string matches the Go source exactly', () => { + if (CPA_FIELD_NAME_PATTERN.source !== '^[A-Za-z_][A-Za-z0-9_]*$') { + throw new Error('Update the TS constant to match the Go source at server/public/model/custom_profile_attributes.go'); + } + + expect(CPA_FIELD_NAME_PATTERN.source).toBe('^[A-Za-z_][A-Za-z0-9_]*$'); + }); + + it('CPA_FIELD_NAME_RESERVED_WORDS contains exactly 21 words matching the Go source', () => { + const expected = new Set([ + 'true', 'false', 'null', + 'in', 'as', + 'break', 'const', 'continue', 'else', + 'for', 'function', 'if', 'import', + 'let', 'loop', 'package', 'namespace', + 'return', 'var', 'void', 'while', + ]); + + if (CPA_FIELD_NAME_RESERVED_WORDS.size !== 21) { + throw new Error('Update the TS constant to match the Go source at server/public/model/custom_profile_attributes.go'); + } + + expect(CPA_FIELD_NAME_RESERVED_WORDS.size).toBe(21); + for (const word of expected) { + if (!CPA_FIELD_NAME_RESERVED_WORDS.has(word)) { + throw new Error('Update the TS constant to match the Go source at server/public/model/custom_profile_attributes.go'); + } + + expect(CPA_FIELD_NAME_RESERVED_WORDS.has(word)).toBe(true); + } + for (const word of CPA_FIELD_NAME_RESERVED_WORDS) { + if (!expected.has(word)) { + throw new Error('Update the TS constant to match the Go source at server/public/model/custom_profile_attributes.go'); + } + + expect(expected.has(word)).toBe(true); + } + }); +}); + +describe('validateCPAFieldName', () => { + const validCases = [ + ['simple lowercase', 'department'], + ['leading underscore', '_private'], + ['uppercase start', 'Department'], + ['single uppercase', 'A1'], + ['underscore separator', 'a_b_c'], + ['all uppercase', 'DEPT'], + ['case-sensitive: IN is not reserved', 'IN'], + ['case-sensitive: In is not reserved', 'In'], + ['single lowercase letter', 'a'], + ['single underscore', '_'], + ['single uppercase letter', 'A'], + ['reserved word as prefix', 'trueish'], + ['reserved word as suffix', 'my_null'], + ['255-rune name at exactly max length', 'a'.repeat(255)], + ] as const; + + test.each(validCases)('%s: %s -> null', (_label, input) => { + expect(validateCPAFieldName(input)).toBeNull(); + }); + + const invalidCharsetCases = [ + ['space in name', 'My Field'], + ['leading digit', '7department'], + ['hyphen', 'foo-bar'], + ['emoji', '🎯'], + ['empty string', ''], + ['trailing space', 'name '], + ['non-ASCII letter', 'départment'], + ['whitespace only', ' '], + ['dot separator', 'foo.bar'], + ['slash', 'foo/bar'], + ] as const; + + test.each(invalidCharsetCases)('%s: %s -> invalid_charset', (_label, input) => { + expect(validateCPAFieldName(input)).toEqual({kind: 'invalid_charset'}); + }); + + const reservedWords = [ + 'true', 'false', 'null', + 'in', 'as', + 'break', 'const', 'continue', 'else', + 'for', 'function', 'if', 'import', + 'let', 'loop', 'package', 'namespace', + 'return', 'var', 'void', 'while', + ] as const; + + test.each(reservedWords)('reserved word: %s -> reserved_word', (word) => { + expect(validateCPAFieldName(word)).toEqual({kind: 'reserved_word', word}); + }); + + it('254-rune name -> null', () => { + expect(validateCPAFieldName('a'.repeat(254))).toBeNull(); + }); + + it('255-rune name -> null (exactly at cap)', () => { + expect(validateCPAFieldName('a'.repeat(255))).toBeNull(); + }); + + it('256-rune name -> too_long', () => { + expect(validateCPAFieldName('a'.repeat(256))).toEqual({kind: 'too_long', max: 255}); + }); +}); + +describe('slugifyForCEL', () => { + it('already snake_case identifier passes through unchanged', () => { + expect(slugifyForCEL('dept_head')).toBe('dept_head'); + }); + + it('spaces are replaced with underscores and lowercased', () => { + expect(slugifyForCEL('My Field')).toBe('my_field'); + }); + + it('hyphens are replaced with underscores and lowercased', () => { + expect(slugifyForCEL('foo-Bar')).toBe('foo_bar'); + }); + + it('camelCase is converted to snake_case', () => { + expect(slugifyForCEL('myFieldName')).toBe('my_field_name'); + }); + + it('PascalCase is converted to snake_case', () => { + expect(slugifyForCEL('MyField')).toBe('my_field'); + }); + + it('consecutive uppercase acronyms split before final word', () => { + expect(slugifyForCEL('XMLParser')).toBe('xml_parser'); + expect(slugifyForCEL('HTTPServerError')).toBe('http_server_error'); + }); + + it('all-uppercase token is lowercased without inserting separators', () => { + expect(slugifyForCEL('DEPT')).toBe('dept'); + }); + + it('digit-letter boundaries do not insert separators', () => { + expect(slugifyForCEL('field2name')).toBe('field2name'); + expect(slugifyForCEL('Field2Name')).toBe('field2_name'); + }); + + it('leading digit gets underscore prefix', () => { + expect(slugifyForCEL('7department')).toBe('_7department'); + }); + + it('leading underscore is preserved', () => { + expect(slugifyForCEL('_Private')).toBe('_private'); + }); + + it('empty string returns _copy', () => { + expect(slugifyForCEL('')).toBe('_copy'); + }); + + it('all-punctuation string returns _copy', () => { + expect(slugifyForCEL('---')).toBe('_copy'); + }); + + it('non-ASCII letters are replaced with underscores', () => { + expect(slugifyForCEL('Préférences')).toBe('pr_f_rences'); + }); + + it('result always matches CPA_FIELD_NAME_PATTERN', () => { + const inputs = [ + 'My Field', + 'foo-bar', + '7dept', + '', + '---', + 'valid_name', + 'DEPT', + 'MyField', + 'XMLParser', + 'myFieldName', + '_Private', + 'Préférences', + ]; + + for (const input of inputs) { + const result = slugifyForCEL(input); + expect(CPA_FIELD_NAME_PATTERN.test(result)).toBe(true); + } + }); +}); + +describe('filterCELIdentifier', () => { + const cases = [ + ['strips spaces', 'my field', 'myfield'], + ['strips dashes', 'my-field', 'myfield'], + ['prefixes leading digit', '7department', '_7department'], + ['passes valid identifier through', 'my_field_2', 'my_field_2'], + ['preserves case', 'MyField', 'MyField'], + ['strips emoji', 'field🎉', 'field'], + ['empty string stays empty', '', ''], + ['all digits prefixed', '123', '_123'], + ['all punctuation becomes empty', '!@#', ''], + ['preserves multiple underscores', 'my__field', 'my__field'], + ['leading underscore preserved', '_private', '_private'], + ] as const; + + test.each(cases)('%s: %s → %s', (_label, input, expected) => { + expect(filterCELIdentifier(input)).toBe(expected); + }); +}); diff --git a/webapp/channels/src/utils/properties.ts b/webapp/channels/src/utils/properties.ts new file mode 100644 index 00000000000..37729d7e608 --- /dev/null +++ b/webapp/channels/src/utils/properties.ts @@ -0,0 +1,129 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import type {UserPropertyField} from '@mattermost/types/properties'; + +/** + * Returns the user-facing label for a CPA field. + * Prefers attrs.display_name (trimmed); falls back to name for legacy + * fields that have not been backfilled yet. + * + * Use for ALL human-readable label rendering (visible text, aria-label, + * title, section headings, etc.). + * + * Do NOT use for: + * - CEL expression construction → use field.name + * - React keys → use field.id or field.name + * - HTML element ids → use field.name or field.id + * - Comparison with currentAttribute in ABAC selectors → use field.name + */ +export function getUserPropertyFieldLabel( + field: Pick, +): string { + const displayName = field.attrs?.display_name?.trim(); + return displayName || field.name; +} + +// SOURCE OF TRUTH: server/public/model/custom_profile_attributes.go lines 81-97 +// CPAFieldNamePattern and CPAFieldNameReservedWords are Go->TS transcriptions. +// If the Go source changes, update BOTH the regex and the Set below, then update +// the hard-coded assertions in properties.test.ts (describe 'CPA field name constants'). +// DO NOT change these constants without a corresponding server-side change. +// +// Grandfather contract: +// CPA name validation only fires when `name` changes from its initial server-persisted value. +// After a successful rename, `current.data[field.id].name` is refreshed by the table's +// readIO.setData(newData) call, which moves the field into the strictly-validated regime +// for all subsequent edits. The regression guard for this contract is in +// user_properties_table.test.tsx - search for 'rename of legacy field clears grandfather'. + +/** + * Mirrors server CPAFieldNamePattern (^[A-Za-z_][A-Za-z0-9_]*$). + * Source: server/public/model/custom_profile_attributes.go:81 + */ +export const CPA_FIELD_NAME_PATTERN = /^[A-Za-z_][A-Za-z0-9_]*$/; + +/** + * Mirrors server CPAFieldNameReservedWords. + * Source: server/public/model/custom_profile_attributes.go:90-97 + * 21 CEL keywords. Case-sensitive: only lowercase forms are reserved. + */ +export const CPA_FIELD_NAME_RESERVED_WORDS = new Set([ + 'true', 'false', 'null', + 'in', 'as', + 'break', 'const', 'continue', 'else', + 'for', 'function', 'if', 'import', + 'let', 'loop', 'package', 'namespace', + 'return', 'var', 'void', 'while', +]); + +/** Max runes for a CPA field name. Mirrors server PropertyFieldNameMaxRunes. */ +export const CPA_FIELD_NAME_MAX_RUNES = 255; + +/** + * Strips characters that are not valid in a CEL identifier and + * prefixes a leading digit with underscore. Unlike slugifyForCEL, + * this does NOT collapse/trim underscores — it is designed for + * live keystroke filtering where the user controls spacing. + */ +export function filterCELIdentifier(input: string): string { + let stripped = input.replace(/[^A-Za-z0-9_]/g, ''); + if (stripped.length > 0 && (/^[0-9]/).test(stripped)) { + stripped = '_' + stripped; + } + return stripped; +} + +export type CPAFieldNameValidationError = + | {kind: 'invalid_charset'} + | {kind: 'reserved_word'; word: string} + | {kind: 'too_long'; max: number}; + +/** + * Client-side mirror of server ValidateCPAFieldName. + * Returns null when the name is valid; returns an error descriptor otherwise. + * + * Length is checked here (against CPA_FIELD_NAME_MAX_RUNES = 255) even though + * the server's ValidateCPAFieldName does not - this provides an early guard + * matching the server's total rejection behavior. + * + * Lenient grandfather: callers must only invoke this when field.name has + * changed from its server-persisted value (mirrors App.PatchCPAField behavior). + */ +export function validateCPAFieldName(name: string): CPAFieldNameValidationError | null { + if ([...name].length > CPA_FIELD_NAME_MAX_RUNES) { + return {kind: 'too_long', max: CPA_FIELD_NAME_MAX_RUNES}; + } + if (!CPA_FIELD_NAME_PATTERN.test(name)) { + return {kind: 'invalid_charset'}; + } + if (CPA_FIELD_NAME_RESERVED_WORDS.has(name)) { + return {kind: 'reserved_word', word: name}; + } + return null; +} + +/** + * Converts an arbitrary string into a snake_case CEL-safe identifier for + * use as a duplicate-field base name. Camel/PascalCase boundaries are + * converted to underscore separators (e.g. 'MyField' -> 'my_field', + * 'XMLParser' -> 'xml_parser'), the result is lowercased, and any + * remaining non-identifier characters are replaced with underscores. + * A leading digit is prefixed with underscore. Consecutive underscores + * collapse to one and trailing underscores are trimmed (a leading + * underscore is preserved). Result is guaranteed to match + * CPA_FIELD_NAME_PATTERN if the input is non-empty; returns '_copy' if + * the entire input normalizes to empty. + */ +export function slugifyForCEL(name: string): string { + let slug = name. + replace(/([a-z0-9])([A-Z])/g, '$1_$2'). + replace(/([A-Z]+)([A-Z][a-z])/g, '$1_$2'). + toLowerCase(). + replace(/[^a-z0-9_]/g, '_'); + if ((/^[0-9]/).test(slug)) { + slug = '_' + slug; + } + slug = slug.replace(/_+/g, '_').replace(/_+$/, ''); + return slug || '_copy'; +} diff --git a/webapp/platform/types/src/properties.ts b/webapp/platform/types/src/properties.ts index ae5e6b55ca0..88c04b3d5ef 100644 --- a/webapp/platform/types/src/properties.ts +++ b/webapp/platform/types/src/properties.ts @@ -84,6 +84,7 @@ export type UserPropertyField = PropertyField & { protected?: boolean; source_plugin_id?: string; access_mode?: '' | 'source_only' | 'shared_only'; + display_name?: string; }; }; From 3f3d8408b2877e73401b8706363eaf1f45a5814c Mon Sep 17 00:00:00 2001 From: Jesse Hallam Date: Fri, 15 May 2026 15:40:25 -0300 Subject: [PATCH 13/80] Return descriptive errors from Role.IsValid and Role.IsValidWithoutId (#36582) * Return descriptive errors from Role.IsValid and Role.IsValidWithoutId Previously both methods returned bool, leaving callers with no context about which validation check failed. Now both return error with a message identifying the specific constraint that was violated. * Add tests for Role.IsValid and Role.IsValidWithoutId * Log migration key on doPermissionsMigration failure --------- Co-authored-by: Mattermost Build --- server/channels/app/permissions_migrations.go | 2 + server/channels/store/sqlstore/role_store.go | 8 +- server/public/model/role.go | 24 +++-- server/public/model/role_test.go | 101 ++++++++++++++++++ 4 files changed, 120 insertions(+), 15 deletions(-) diff --git a/server/channels/app/permissions_migrations.go b/server/channels/app/permissions_migrations.go index 0a7e788b5c2..434fe1d8165 100644 --- a/server/channels/app/permissions_migrations.go +++ b/server/channels/app/permissions_migrations.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/channels/store" "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" ) @@ -1400,6 +1401,7 @@ func (s *Server) doPermissionsMigrations() error { return err } if err := s.doPermissionsMigration(migration.Key, migMap, roles); err != nil { + mlog.Error("Failed to run permissions migration", mlog.String("key", migration.Key), mlog.Err(err)) return err } } diff --git a/server/channels/store/sqlstore/role_store.go b/server/channels/store/sqlstore/role_store.go index c188cc1c8e1..c6ff9d8a061 100644 --- a/server/channels/store/sqlstore/role_store.go +++ b/server/channels/store/sqlstore/role_store.go @@ -101,8 +101,8 @@ func newSqlRoleStore(sqlStore *SqlStore) store.RoleStore { func (s *SqlRoleStore) Save(role *model.Role) (_ *model.Role, err error) { // Check the role is valid before proceeding. - if !role.IsValidWithoutId() { - return nil, store.NewErrInvalidInput("Role", "", fmt.Sprintf("%v", role)) + if err = role.IsValidWithoutId(); err != nil { + return nil, store.NewErrInvalidInput("Role", "", err.Error()) } if role.Id == "" { @@ -148,8 +148,8 @@ func (s *SqlRoleStore) Save(role *model.Role) (_ *model.Role, err error) { func (s *SqlRoleStore) createRole(role *model.Role, transaction *sqlxTxWrapper) (*model.Role, error) { // Check the role is valid before proceeding. - if !role.IsValidWithoutId() { - return nil, store.NewErrInvalidInput("Role", "", fmt.Sprintf("%v", role)) + if err := role.IsValidWithoutId(); err != nil { + return nil, store.NewErrInvalidInput("Role", "", err.Error()) } dbRole := NewRoleFromModel(role) diff --git a/server/public/model/role.go b/server/public/model/role.go index e7b84a5d829..0513f565c49 100644 --- a/server/public/model/role.go +++ b/server/public/model/role.go @@ -778,25 +778,28 @@ func (r *Role) RolePatchFromChannelModerationsPatch(channelModerationsPatch []*C return &RolePatch{Permissions: &patchPermissions} } -func (r *Role) IsValid() bool { +func (r *Role) IsValid() error { if !IsValidId(r.Id) { - return false + return fmt.Errorf("invalid role id %q", r.Id) } return r.IsValidWithoutId() } -func (r *Role) IsValidWithoutId() bool { +func (r *Role) IsValidWithoutId() error { if !IsValidRoleName(r.Name) { - return false + return fmt.Errorf("invalid role name %q", r.Name) } - if r.DisplayName == "" || len(r.DisplayName) > RoleDisplayNameMaxLength { - return false + if r.DisplayName == "" { + return fmt.Errorf("role display name must not be empty") + } + if len(r.DisplayName) > RoleDisplayNameMaxLength { + return fmt.Errorf("role display name %q exceeds maximum length of %d", r.DisplayName, RoleDisplayNameMaxLength) } if len(r.Description) > RoleDescriptionMaxLength { - return false + return fmt.Errorf("role description exceeds maximum length of %d", RoleDescriptionMaxLength) } check := func(perms []*Permission, permission string) bool { @@ -808,13 +811,12 @@ func (r *Role) IsValidWithoutId() bool { return false } for _, permission := range r.Permissions { - permissionValidated := check(AllPermissions, permission) || check(DeprecatedPermissions, permission) - if !permissionValidated { - return false + if !check(AllPermissions, permission) && !check(DeprecatedPermissions, permission) { + return fmt.Errorf("unknown permission %q", permission) } } - return true + return nil } func CleanRoleNames(roleNames []string) ([]string, bool) { diff --git a/server/public/model/role_test.go b/server/public/model/role_test.go index 9abf2c1c81e..0550509cbba 100644 --- a/server/public/model/role_test.go +++ b/server/public/model/role_test.go @@ -5,6 +5,7 @@ package model import ( "slices" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -362,6 +363,106 @@ func TestManageAgentPermissionsDefinition(t *testing.T) { }), "manage_others_agent should be in AllPermissions") } +func TestRoleIsValidWithoutId(t *testing.T) { + validRole := func() *Role { + return &Role{ + Name: "test_role", + DisplayName: "Test Role", + Description: "A test role.", + Permissions: []string{PermissionCreatePost.Id}, + } + } + + t.Run("valid role returns nil", func(t *testing.T) { + assert.NoError(t, validRole().IsValidWithoutId()) + }) + + t.Run("empty name", func(t *testing.T) { + r := validRole() + r.Name = "" + assert.ErrorContains(t, r.IsValidWithoutId(), "invalid role name") + }) + + t.Run("name too long", func(t *testing.T) { + r := validRole() + r.Name = strings.Repeat("a", RoleNameMaxLength+1) + assert.ErrorContains(t, r.IsValidWithoutId(), "invalid role name") + }) + + t.Run("name with invalid characters", func(t *testing.T) { + r := validRole() + r.Name = "invalid-name" + assert.ErrorContains(t, r.IsValidWithoutId(), "invalid role name") + }) + + t.Run("empty display name", func(t *testing.T) { + r := validRole() + r.DisplayName = "" + assert.ErrorContains(t, r.IsValidWithoutId(), "display name must not be empty") + }) + + t.Run("display name too long", func(t *testing.T) { + r := validRole() + r.DisplayName = strings.Repeat("a", RoleDisplayNameMaxLength+1) + err := r.IsValidWithoutId() + assert.ErrorContains(t, err, "display name") + assert.ErrorContains(t, err, "exceeds maximum length") + }) + + t.Run("description too long", func(t *testing.T) { + r := validRole() + r.Description = strings.Repeat("a", RoleDescriptionMaxLength+1) + assert.ErrorContains(t, r.IsValidWithoutId(), "description exceeds maximum length") + }) + + t.Run("unknown permission", func(t *testing.T) { + r := validRole() + r.Permissions = []string{"not_a_real_permission"} + err := r.IsValidWithoutId() + require.ErrorContains(t, err, "unknown permission") + assert.ErrorContains(t, err, "not_a_real_permission") + }) + + t.Run("no permissions is valid", func(t *testing.T) { + r := validRole() + r.Permissions = nil + assert.NoError(t, r.IsValidWithoutId()) + }) +} + +func TestRoleIsValid(t *testing.T) { + validRole := func() *Role { + return &Role{ + Id: NewId(), + Name: "test_role", + DisplayName: "Test Role", + Permissions: []string{PermissionCreatePost.Id}, + } + } + + t.Run("valid role returns nil", func(t *testing.T) { + assert.NoError(t, validRole().IsValid()) + }) + + t.Run("empty id", func(t *testing.T) { + r := validRole() + r.Id = "" + assert.ErrorContains(t, r.IsValid(), "invalid role id") + }) + + t.Run("invalid id", func(t *testing.T) { + r := validRole() + r.Id = "not-a-valid-id!" + assert.ErrorContains(t, r.IsValid(), "invalid role id") + }) + + t.Run("propagates IsValidWithoutId error", func(t *testing.T) { + r := validRole() + r.DisplayName = "" + assert.ErrorContains(t, r.IsValid(), "display name must not be empty") + }) +} + func TestManageAgentPermissionsDefaultRoles(t *testing.T) { roles := MakeDefaultRoles() From deafd88fd5a8cf423ee2caa660fd8958673456ef Mon Sep 17 00:00:00 2001 From: Ibrahim Serdar Acikgoz Date: Fri, 15 May 2026 21:04:32 +0200 Subject: [PATCH 14/80] =?UTF-8?q?MM-68762:=20Discoverable=20Private=20Chan?= =?UTF-8?q?nels=20=E2=80=94=20Server=20data=20layer=20(#36539)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MM-68762: Add Postgres migrations for discoverable private channels Three online-safe migrations introduce the schema that supports the Discoverable Private Channels feature (PRs 2-5 of MM-68430 will land behind it): - 000175 adds Channels.Discoverable BOOLEAN NOT NULL DEFAULT FALSE. Metadata-only on Postgres >= 11; no table rewrite. - 000176 creates a partial index on (TeamId) WHERE Discoverable AND Type='P' AND DeleteAt=0 using CREATE INDEX CONCURRENTLY (-- morph:nontransactional) so the build never blocks writes on the populated Channels table. - 000177 creates the ChannelJoinRequests table with three indexes, the important one being the partial unique index on (ChannelId, UserId) WHERE Status = 'pending'. That keeps the full audit history intact while still enforcing at-most-one active pending request per (channel, user). Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Add FeatureFlagDiscoverableChannels (default false) Gates the per-channel Discoverable toggle and the channel-join-request flow. Default-OFF so all PRs in the MM-68430 series can land on master without exposing partial UX. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Add Discoverable + ChannelJoinRequest models - Channel gains a Discoverable bool, ChannelPatch a *bool, both serialized as 'discoverable'. Patch() applies it, Auditable() logs it, and IsValid() rejects Discoverable=true on any non-private channel so a misconfigured patch can never produce a public discoverable channel. - New ChannelJoinRequest type captures the per-row state of a non-member's request: pending -> approved | denied | withdrawn. Rows are append-only with reviewer and timestamps so the table is also the audit trail. IsValid() enforces: * recognized status, * Message and DenialReason rune limits, * DenialReason only on denied rows (no orphan reasons), * reviewer + reviewed_at present for any terminal review (approved / denied) but not for self-service withdrawal. - Two new WebSocket event constants -- channel_join_request_created and channel_join_request_updated -- that later PRs broadcast on the admin queue and the requester's My Pending Requests panel. Unit tests cover Patch(), the new IsValid() rule on Discoverable, the PreSave/PreUpdate timestamp behavior on ChannelJoinRequest, and every IsValid branch including the reviewer-required-on-review invariant. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Add discoverable-channel permissions Two new channel-scoped permissions, each independently rebindable from the System Console: - manage_private_channel_discoverability gates the per-channel toggle so admins can restrict who can flip discoverability without also handing out manage_private_channel_properties. - manage_channel_join_requests gates the queue list / approve / deny / count endpoints (added in PR 2). Both are added to the channel_admin role bootstrap so new deployments get them by default, and a new permissions migration (add_discoverable_channel_permissions) grants them to channel_admin, team_admin and system_admin scheme roles on existing deployments. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Add ChannelJoinRequestStore and wire Discoverable into channel store - channelSliceColumns / channelToSlice / updateChannelT now include the new Discoverable column so Save() and Update() round-trip the field. Existing select paths inherit the column automatically because every read goes through channelSliceColumns. - New ChannelJoinRequestStore interface and SQL implementation: Save / Get / GetPendingForChannelAndUser / GetForChannel / GetForUser / Update / CountPending. Save translates the idx_channeljoinrequests_pending_unique partial unique index violation into store.ErrConflict so the app layer (PR 2) can return 409 without re-parsing pq errors. - Storetest suite at storetest/channel_join_request_store.go is invoked from sqlstore via the existing StoreTest harness; covers insert / partial-unique conflict / re-insert after withdrawal / NotFound / status filtering / pagination with TotalCount / Update / CountPending. - Mocks and retrylayer / timerlayer are regenerated via make store-mocks and go generate ./channels/store -- no hand-written generator output. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Add TS types for Discoverable channels + join requests webapp/platform/types: - Channel.discoverable?: boolean alongside existing policy_enforced / policy_is_active so the web client sees the same wire shape the server emits. - ChannelJoinRequest, ChannelJoinRequestStatus, ChannelJoinRequestList, GetChannelJoinRequestsOptions for the API contract surfaced in PR 2. webapp/platform/client: - WebSocketEvents enum gains ChannelJoinRequestCreated and ChannelJoinRequestUpdated so PR 3 can hang WS handlers off them without redeclaring constants. These are model-only updates with no UI consumer yet; PR 3 introduces the toggle, request flow, and admin queue surfaces. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Split ChannelJoinRequests indexes into concurrent migrations The mattermost-govet concurrentIndex lint check enforces CREATE INDEX CONCURRENTLY on every CREATE INDEX statement, even on an empty freshly-created table where it would be a no-op. The original 000177 file inlined three CREATE INDEX statements; that failed check-style. Mirror the convention used by 000166_create_views + 000167_create_views_channel_id_delete_at_index: keep the CREATE TABLE in its own (transactional) file, and move each index into a separate nontransactional file that runs CREATE INDEX CONCURRENTLY. Verified locally against Postgres 15 that all four new migrations apply in order and the storetest suite (partial unique constraint + paged list + count) still passes. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Wire new permission migration into test fixtures Two CI test surfaces missed when the channel_admin role and the permission-migration list gained the new manage_private_channel_discoverability and manage_channel_join_requests entries: - testlib/store.go: the shared mocked SystemStore used by SetupWithStoreMock / SetupEnterpriseWithStoreMock needs an explicit GetByName expectation for every migration key (because the mock panics on unexpected calls). Add the new MigrationKeyAddDiscoverableChannelPermissions key so TestCreateOrUpdateAccessControlPolicy, the elasticsearch aggregation_job_test, and every other mock-store test stop panicking on server bootstrap. - cmd/mmctl/commands/permissions_test.go: TestResetPermissionsCmd hard-codes the channel_admin default permission list and expects PatchRole to be called with exactly that slice. Extend the expected slice with the two new permission ids so the mmctl reset path stays in sync with the role bootstrap. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Register new idx_channels_discoverable_team in TestGetSchemaDefinition The schema-dump test asserts an exact index count and definition map for the channels table. Migration 000176 added idx_channels_discoverable_team — a partial btree on (teamid) gated by discoverable=true AND type='P' AND deleteat=0. Bump the expected count from 12 to 13 and add the index's CREATE INDEX definition as produced by pg_indexes (note: type is cast to channel_type, the existing domain). Verified locally against Postgres 15. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Fix golangci-lint findings in ChannelJoinRequest store Two golangci-lint findings on the freshly-added files: - sqlstore/channel_join_request_store.go:133 (modernize): collapse the 'if page < 0 { page = 0 }' clamp into max(opts.Page, 0). - storetest/channel_join_request_store.go:243 (govet shadow): the inner Save loop redeclared err with :=, shadowing the outer err captured from the first CountPending call. Switch to plain assignment so the same err is reused. Verified locally with golangci-lint v2.11.4 across public/..., channels/app/..., channels/store/..., channels/testlib/... and cmd/mmctl/commands/... — 0 issues. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Sync channel_admin bootstrap with TestDoAdvancedPermissionsMigration app_test.go pins the exact list of permissions the channel_admin role is expected to hold after DoAdvancedPermissionsMigration completes. The role bootstrap in role.go grew two entries (manage_private_channel_discoverability and manage_channel_join_requests), so the test's expected slice needs the same two entries appended in the same order, otherwise assert.Equal fails on slice ordering. This is the same class of fix as the mmctl/permissions_test.go change in a previous commit -- two parallel test fixtures encode the channel_admin defaults and have to be updated in lockstep with the bootstrap. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Add English translations for new model error keys 12 keys were emitted by the new Discoverable + ChannelJoinRequest validation paths but had no en.json entry, which trips i18n-check on CI. Add the missing entries with one-line English copy that mirrors adjacent model errors (Invalid ., Create at must be a valid time., etc.). The new entries are: - model.channel.is_valid.discoverable.app_error - model.channel_join_request.is_valid.channel_id.app_error - model.channel_join_request.is_valid.create_at.app_error - model.channel_join_request.is_valid.denial_reason.app_error - model.channel_join_request.is_valid.denial_reason_status.app_error - model.channel_join_request.is_valid.id.app_error - model.channel_join_request.is_valid.message.app_error - model.channel_join_request.is_valid.reviewed_by.app_error - model.channel_join_request.is_valid.reviewer.app_error - model.channel_join_request.is_valid.status.app_error - model.channel_join_request.is_valid.update_at.app_error - model.channel_join_request.is_valid.user_id.app_error Generated through 'make i18n-extract'; verified clean with 'make i18n-check'. Per the workspace rule, only en.json was modified -- no other locale files. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Address CodeRabbit review: stable pagination + redact denial reason from audit log Two production-code findings from CodeRabbit on the freshly-added ChannelJoinRequest server code: - sqlstore/channel_join_request_store.go (GetForChannel / GetForUser): OrderBy("CreateAt DESC") alone is unstable when two rows share a millisecond (NewId is monotonic-ish but CreateAt is millisecond resolution), so offset paging could duplicate or skip rows between pages. Add Id DESC as a deterministic tie-breaker on both list queries. - model/channel_join_request.Auditable: the denial reason is admin-typed free text and could carry sensitive content. Mirror the existing has_message pattern by emitting has_denial_reason as a boolean presence flag instead of the raw value. Reviewer id, review timestamp, and status are still logged, so the audit trail keeps every piece needed for compliance review. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Tighten model tests per CodeRabbit review Two test-only findings from CodeRabbit: - TestChannelJoinRequestPreUpdateAdvancesUpdateAt previously asserted GreaterOrEqual(r.UpdateAt, originalCreate). Because validRequest initialises UpdateAt to GetMillis() (same call site as CreateAt), a no-op PreUpdate would still pass that check. Seed r.UpdateAt = 1 before calling PreUpdate() and assert Greater(r.UpdateAt, int64(1)) so any regression that drops the GetMillis assignment fails the test. - TestChannelIsValidDiscoverable did not cover ChannelTypeGroup. Add the case alongside ChannelTypeOpen and ChannelTypeDirect so the contract that 'only ChannelTypePrivate accepts Discoverable=true' is fully pinned across all four channel types. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68762: Mock ChannelJoinRequest accessor in retrylayer test retrylayer_test.go's genStore() helper mocks every Store() accessor because retrylayer.New() wraps the entire surface. The new ChannelJoinRequest() method I added on Store was missing from the mock, so TestRetry/on_regular_error_should_not_retry panicked with 'Unexpected Method Call ChannelJoinRequest()' on Postgres shard 0. Add the mock alongside the other accessors. No production code change. Co-authored-by: Ibrahim Serdar Acikgoz --------- Co-authored-by: Cursor Agent Co-authored-by: Ibrahim Serdar Acikgoz --- server/channels/app/app_test.go | 2 + server/channels/app/permissions_migrations.go | 17 ++ server/channels/db/migrations/migrations.list | 12 + ...0178_add_discoverable_to_channels.down.sql | 1 + ...000178_add_discoverable_to_channels.up.sql | 1 + ...9_add_channels_discoverable_index.down.sql | 2 + ...179_add_channels_discoverable_index.up.sql | 4 + ...0180_create_channel_join_requests.down.sql | 1 + ...000180_create_channel_join_requests.up.sql | 12 + ...oin_requests_pending_unique_index.down.sql | 2 + ..._join_requests_pending_unique_index.up.sql | 4 + ...oin_requests_channel_status_index.down.sql | 2 + ..._join_requests_channel_status_index.up.sql | 3 + ...l_join_requests_user_status_index.down.sql | 2 + ...nel_join_requests_user_status_index.up.sql | 3 + .../channels/store/retrylayer/retrylayer.go | 158 +++++++++++ .../store/retrylayer/retrylayer_test.go | 1 + .../sqlstore/channel_join_request_store.go | 244 ++++++++++++++++ .../channel_join_request_store_test.go | 14 + .../channels/store/sqlstore/channel_store.go | 5 +- .../store/sqlstore/schema_dump_test.go | 3 +- server/channels/store/sqlstore/store.go | 6 + server/channels/store/store.go | 14 + .../storetest/channel_join_request_store.go | 264 ++++++++++++++++++ .../mocks/ChannelJoinRequestStore.go | 251 +++++++++++++++++ .../channels/store/storetest/mocks/Store.go | 31 +- server/channels/store/storetest/store.go | 5 + .../channels/store/timerlayer/timerlayer.go | 155 ++++++++-- server/channels/testlib/store.go | 1 + server/cmd/mmctl/commands/permissions_test.go | 2 + server/i18n/en.json | 48 ++++ server/public/model/channel.go | 12 + server/public/model/channel_join_request.go | 165 +++++++++++ .../public/model/channel_join_request_test.go | 114 ++++++++ server/public/model/channel_test.go | 58 ++++ server/public/model/feature_flags.go | 7 + server/public/model/migration.go | 1 + server/public/model/permission.go | 16 ++ server/public/model/role.go | 2 + server/public/model/websocket_message.go | 2 + .../platform/client/src/websocket_events.ts | 2 + webapp/platform/types/src/channels.ts | 27 ++ 42 files changed, 1651 insertions(+), 25 deletions(-) create mode 100644 server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql create mode 100644 server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql create mode 100644 server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql create mode 100644 server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql create mode 100644 server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql create mode 100644 server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql create mode 100644 server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql create mode 100644 server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql create mode 100644 server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql create mode 100644 server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql create mode 100644 server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql create mode 100644 server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql create mode 100644 server/channels/store/sqlstore/channel_join_request_store.go create mode 100644 server/channels/store/sqlstore/channel_join_request_store_test.go create mode 100644 server/channels/store/storetest/channel_join_request_store.go create mode 100644 server/channels/store/storetest/mocks/ChannelJoinRequestStore.go create mode 100644 server/public/model/channel_join_request.go create mode 100644 server/public/model/channel_join_request_test.go diff --git a/server/channels/app/app_test.go b/server/channels/app/app_test.go index 0ebfc035610..03792e4acf9 100644 --- a/server/channels/app/app_test.go +++ b/server/channels/app/app_test.go @@ -151,6 +151,8 @@ func TestDoAdvancedPermissionsMigration(t *testing.T) { model.PermissionManageChannelAccessRules.Id, model.PermissionManagePublicChannelAutoTranslation.Id, model.PermissionManagePrivateChannelAutoTranslation.Id, + model.PermissionManagePrivateChannelDiscoverability.Id, + model.PermissionManageChannelJoinRequests.Id, }, "team_user": { model.PermissionListTeamChannels.Id, diff --git a/server/channels/app/permissions_migrations.go b/server/channels/app/permissions_migrations.go index 434fe1d8165..83444fe4c6a 100644 --- a/server/channels/app/permissions_migrations.go +++ b/server/channels/app/permissions_migrations.go @@ -1326,6 +1326,22 @@ func (a *App) getAddEditFileAttachmentPermissionMigration() (permissionsMap, err }, nil } +func (a *App) getAddDiscoverableChannelPermissionsMigration() (permissionsMap, error) { + return permissionsMap{ + permissionTransformation{ + On: permissionOr( + isRole(model.ChannelAdminRoleId), + isRole(model.TeamAdminRoleId), + isRole(model.SystemAdminRoleId), + ), + Add: []string{ + model.PermissionManagePrivateChannelDiscoverability.Id, + model.PermissionManageChannelJoinRequests.Id, + }, + }, + }, nil +} + // DoPermissionsMigrations execute all the permissions migrations need by the current version. func (a *App) DoPermissionsMigrations() error { return a.Srv().doPermissionsMigrations() @@ -1388,6 +1404,7 @@ func (s *Server) doPermissionsMigrations() error { {Key: model.MigrationKeyRestoreManageOAuthPermission, Migration: a.getRestoreManageOAuthPermissionMigration}, {Key: model.MigrationKeyAddManageAgentPermissions, Migration: a.getAddManageAgentPermissionsMigration}, {Key: model.MigrationKeyAddEditFileAttachmentPermission, Migration: a.getAddEditFileAttachmentPermissionMigration}, + {Key: model.MigrationKeyAddDiscoverableChannelPermissions, Migration: a.getAddDiscoverableChannelPermissionsMigration}, } roles, err := s.Store().Role().GetAll() diff --git a/server/channels/db/migrations/migrations.list b/server/channels/db/migrations/migrations.list index 72f6bccb373..04fc52052f5 100644 --- a/server/channels/db/migrations/migrations.list +++ b/server/channels/db/migrations/migrations.list @@ -351,3 +351,15 @@ channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.down.sql channels/db/migrations/postgres/000176_migrate_cpa_to_access_control.up.sql channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.down.sql channels/db/migrations/postgres/000177_filter_attribute_view_by_object_type.up.sql +channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql +channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql +channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql +channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql +channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql +channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql +channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql +channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql +channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql +channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql +channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql +channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql diff --git a/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql new file mode 100644 index 00000000000..98788019071 --- /dev/null +++ b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.down.sql @@ -0,0 +1 @@ +ALTER TABLE Channels DROP COLUMN IF EXISTS Discoverable; diff --git a/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql new file mode 100644 index 00000000000..dce84a520df --- /dev/null +++ b/server/channels/db/migrations/postgres/000178_add_discoverable_to_channels.up.sql @@ -0,0 +1 @@ +ALTER TABLE Channels ADD COLUMN IF NOT EXISTS Discoverable BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql new file mode 100644 index 00000000000..d3d4d6b3545 --- /dev/null +++ b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channels_discoverable_team; diff --git a/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql new file mode 100644 index 00000000000..b838ee35a52 --- /dev/null +++ b/server/channels/db/migrations/postgres/000179_add_channels_discoverable_index.up.sql @@ -0,0 +1,4 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channels_discoverable_team + ON Channels (TeamId) + WHERE Discoverable = true AND Type = 'P' AND DeleteAt = 0; diff --git a/server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql new file mode 100644 index 00000000000..8c692c6b8e5 --- /dev/null +++ b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ChannelJoinRequests; diff --git a/server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql new file mode 100644 index 00000000000..1e8076dab23 --- /dev/null +++ b/server/channels/db/migrations/postgres/000180_create_channel_join_requests.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS ChannelJoinRequests ( + Id VARCHAR(26) PRIMARY KEY, + ChannelId VARCHAR(26) NOT NULL, + UserId VARCHAR(26) NOT NULL, + Message TEXT NOT NULL DEFAULT '', + Status VARCHAR(16) NOT NULL DEFAULT 'pending', + DenialReason TEXT NOT NULL DEFAULT '', + CreateAt BIGINT NOT NULL, + UpdateAt BIGINT NOT NULL, + ReviewedBy VARCHAR(26) NOT NULL DEFAULT '', + ReviewedAt BIGINT NOT NULL DEFAULT 0 +); diff --git a/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql new file mode 100644 index 00000000000..ca606fc8a74 --- /dev/null +++ b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channeljoinrequests_pending_unique; diff --git a/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql new file mode 100644 index 00000000000..d2317fecf77 --- /dev/null +++ b/server/channels/db/migrations/postgres/000181_create_channel_join_requests_pending_unique_index.up.sql @@ -0,0 +1,4 @@ +-- morph:nontransactional +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS idx_channeljoinrequests_pending_unique + ON ChannelJoinRequests (ChannelId, UserId) + WHERE Status = 'pending'; diff --git a/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql new file mode 100644 index 00000000000..f5a3ed0da5a --- /dev/null +++ b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channeljoinrequests_channel_status_createat; diff --git a/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql new file mode 100644 index 00000000000..dbaf927fbb5 --- /dev/null +++ b/server/channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql @@ -0,0 +1,3 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channeljoinrequests_channel_status_createat + ON ChannelJoinRequests (ChannelId, Status, CreateAt DESC); diff --git a/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql new file mode 100644 index 00000000000..134bcc459f5 --- /dev/null +++ b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channeljoinrequests_user_status_createat; diff --git a/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql new file mode 100644 index 00000000000..73271ffa61e --- /dev/null +++ b/server/channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql @@ -0,0 +1,3 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channeljoinrequests_user_status_createat + ON ChannelJoinRequests (UserId, Status, CreateAt DESC); diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index f1b8fe9950d..61970c998ee 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -27,6 +27,7 @@ type RetryLayer struct { BotStore store.BotStore ChannelStore store.ChannelStore ChannelBookmarkStore store.ChannelBookmarkStore + ChannelJoinRequestStore store.ChannelJoinRequestStore ChannelMemberHistoryStore store.ChannelMemberHistoryStore ClusterDiscoveryStore store.ClusterDiscoveryStore CommandStore store.CommandStore @@ -107,6 +108,10 @@ func (s *RetryLayer) ChannelBookmark() store.ChannelBookmarkStore { return s.ChannelBookmarkStore } +func (s *RetryLayer) ChannelJoinRequest() store.ChannelJoinRequestStore { + return s.ChannelJoinRequestStore +} + func (s *RetryLayer) ChannelMemberHistory() store.ChannelMemberHistoryStore { return s.ChannelMemberHistoryStore } @@ -342,6 +347,11 @@ type RetryLayerChannelBookmarkStore struct { Root *RetryLayer } +type RetryLayerChannelJoinRequestStore struct { + store.ChannelJoinRequestStore + Root *RetryLayer +} + type RetryLayerChannelMemberHistoryStore struct { store.ChannelMemberHistoryStore Root *RetryLayer @@ -3858,6 +3868,153 @@ func (s *RetryLayerChannelBookmarkStore) UpdateSortOrder(bookmarkID string, chan } +func (s *RetryLayerChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.CountPending(channelId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.Get(id) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + + tries := 0 + for { + result, resultVar1, err := s.ChannelJoinRequestStore.GetForChannel(channelId, opts) + if err == nil { + return result, resultVar1, nil + } + if !isRepeatableError(err) { + return result, resultVar1, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, resultVar1, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + + tries := 0 + for { + result, resultVar1, err := s.ChannelJoinRequestStore.GetForUser(userId, opts) + if err == nil { + return result, resultVar1, nil + } + if !isRepeatableError(err) { + return result, resultVar1, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, resultVar1, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) GetPendingForChannelAndUser(channelId string, userId string) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.GetPendingForChannelAndUser(channelId, userId) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.Save(req) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + + tries := 0 + for { + result, err := s.ChannelJoinRequestStore.Update(req) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerChannelMemberHistoryStore) DeleteOrphanedRows(limit int) (int64, error) { tries := 0 @@ -18404,6 +18561,7 @@ func New(childStore store.Store) *RetryLayer { newStore.BotStore = &RetryLayerBotStore{BotStore: childStore.Bot(), Root: &newStore} newStore.ChannelStore = &RetryLayerChannelStore{ChannelStore: childStore.Channel(), Root: &newStore} newStore.ChannelBookmarkStore = &RetryLayerChannelBookmarkStore{ChannelBookmarkStore: childStore.ChannelBookmark(), Root: &newStore} + newStore.ChannelJoinRequestStore = &RetryLayerChannelJoinRequestStore{ChannelJoinRequestStore: childStore.ChannelJoinRequest(), Root: &newStore} newStore.ChannelMemberHistoryStore = &RetryLayerChannelMemberHistoryStore{ChannelMemberHistoryStore: childStore.ChannelMemberHistory(), Root: &newStore} newStore.ClusterDiscoveryStore = &RetryLayerClusterDiscoveryStore{ClusterDiscoveryStore: childStore.ClusterDiscovery(), Root: &newStore} newStore.CommandStore = &RetryLayerCommandStore{CommandStore: childStore.Command(), Root: &newStore} diff --git a/server/channels/store/retrylayer/retrylayer_test.go b/server/channels/store/retrylayer/retrylayer_test.go index 9c1e08dcfd9..7cb965e53b3 100644 --- a/server/channels/store/retrylayer/retrylayer_test.go +++ b/server/channels/store/retrylayer/retrylayer_test.go @@ -74,6 +74,7 @@ func genStore() *mocks.Store { mock.On("Recap").Return(&mocks.RecapStore{}) mock.On("TemporaryPost").Return(&mocks.TemporaryPostStore{}) mock.On("View").Return(&mocks.ViewStore{}) + mock.On("ChannelJoinRequest").Return(&mocks.ChannelJoinRequestStore{}) return mock } diff --git a/server/channels/store/sqlstore/channel_join_request_store.go b/server/channels/store/sqlstore/channel_join_request_store.go new file mode 100644 index 00000000000..700639fa70a --- /dev/null +++ b/server/channels/store/sqlstore/channel_join_request_store.go @@ -0,0 +1,244 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "database/sql" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/store" + sq "github.com/mattermost/squirrel" + "github.com/pkg/errors" +) + +const channelJoinRequestsTable = "ChannelJoinRequests" + +var channelJoinRequestColumns = []string{ + "Id", + "ChannelId", + "UserId", + "Message", + "Status", + "DenialReason", + "CreateAt", + "UpdateAt", + "ReviewedBy", + "ReviewedAt", +} + +type SqlChannelJoinRequestStore struct { + *SqlStore + + selectQuery sq.SelectBuilder +} + +func newSqlChannelJoinRequestStore(sqlStore *SqlStore) store.ChannelJoinRequestStore { + s := &SqlChannelJoinRequestStore{SqlStore: sqlStore} + s.selectQuery = s.getQueryBuilder(). + Select(channelJoinRequestColumns...). + From(channelJoinRequestsTable) + return s +} + +func (s *SqlChannelJoinRequestStore) toMap(r *model.ChannelJoinRequest) map[string]any { + return map[string]any{ + "Id": r.Id, + "ChannelId": r.ChannelId, + "UserId": r.UserId, + "Message": r.Message, + "Status": r.Status, + "DenialReason": r.DenialReason, + "CreateAt": r.CreateAt, + "UpdateAt": r.UpdateAt, + "ReviewedBy": r.ReviewedBy, + "ReviewedAt": r.ReviewedAt, + } +} + +// Save inserts a new join request. The partial unique index in Postgres +// (channelid, userid) WHERE status = 'pending' enforces at-most-one pending +// row per (channel, user). On conflict we translate the unique-violation into +// a store.ErrConflict so the app layer can return 409. +func (s *SqlChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + req.PreSave() + + if err := req.IsValid(); err != nil { + return nil, err + } + + query := s.getQueryBuilder(). + Insert(channelJoinRequestsTable). + SetMap(s.toMap(req)) + + if _, err := s.GetMaster().ExecBuilder(query); err != nil { + if IsUniqueConstraintError(err, []string{"idx_channeljoinrequests_pending_unique"}) { + return nil, store.NewErrConflict("ChannelJoinRequest", err, "channel_id="+req.ChannelId+" user_id="+req.UserId) + } + return nil, errors.Wrap(err, "failed to save ChannelJoinRequest") + } + + return req, nil +} + +func (s *SqlChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + var req model.ChannelJoinRequest + query := s.selectQuery.Where(sq.Eq{"Id": id}) + + if err := s.GetReplica().GetBuilder(&req, query); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("ChannelJoinRequest", id) + } + return nil, errors.Wrapf(err, "failed to get ChannelJoinRequest with id=%s", id) + } + + return &req, nil +} + +func (s *SqlChannelJoinRequestStore) GetPendingForChannelAndUser(channelId, userId string) (*model.ChannelJoinRequest, error) { + var req model.ChannelJoinRequest + query := s.selectQuery.Where(sq.Eq{ + "ChannelId": channelId, + "UserId": userId, + "Status": model.ChannelJoinRequestStatusPending, + }) + + if err := s.GetReplica().GetBuilder(&req, query); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("ChannelJoinRequest", "channel_id="+channelId+" user_id="+userId) + } + return nil, errors.Wrapf(err, "failed to get pending ChannelJoinRequest for channel_id=%s user_id=%s", channelId, userId) + } + + return &req, nil +} + +// applyStatusFilter applies the opts.Status filter (defaulting to pending if empty) +// to both the select and count queries. Returning the two filtered builders keeps +// list and count perfectly in sync. +func applyJoinRequestStatusFilter(opts model.GetChannelJoinRequestsOpts) sq.Eq { + status := opts.Status + if status == "" { + status = model.ChannelJoinRequestStatusPending + } + return sq.Eq{"Status": status} +} + +func paginate(opts model.GetChannelJoinRequestsOpts) (limit, offset uint64) { + perPage := opts.PerPage + if perPage <= 0 { + perPage = 60 + } + page := max(opts.Page, 0) + return uint64(perPage), uint64(page) * uint64(perPage) +} + +func (s *SqlChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + where := sq.And{sq.Eq{"ChannelId": channelId}, applyJoinRequestStatusFilter(opts)} + + limit, offset := paginate(opts) + listQuery := s.selectQuery. + Where(where). + OrderBy("CreateAt DESC", "Id DESC"). + Limit(limit). + Offset(offset) + + var rows []*model.ChannelJoinRequest + if err := s.GetReplica().SelectBuilder(&rows, listQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to list ChannelJoinRequests for channel_id=%s", channelId) + } + + countQuery := s.getQueryBuilder(). + Select("COUNT(*)"). + From(channelJoinRequestsTable). + Where(where) + + var total int64 + if err := s.GetReplica().GetBuilder(&total, countQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to count ChannelJoinRequests for channel_id=%s", channelId) + } + + return rows, total, nil +} + +func (s *SqlChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + where := sq.And{sq.Eq{"UserId": userId}, applyJoinRequestStatusFilter(opts)} + + limit, offset := paginate(opts) + listQuery := s.selectQuery. + Where(where). + OrderBy("CreateAt DESC", "Id DESC"). + Limit(limit). + Offset(offset) + + var rows []*model.ChannelJoinRequest + if err := s.GetReplica().SelectBuilder(&rows, listQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to list ChannelJoinRequests for user_id=%s", userId) + } + + countQuery := s.getQueryBuilder(). + Select("COUNT(*)"). + From(channelJoinRequestsTable). + Where(where) + + var total int64 + if err := s.GetReplica().GetBuilder(&total, countQuery); err != nil { + return nil, 0, errors.Wrapf(err, "failed to count ChannelJoinRequests for user_id=%s", userId) + } + + return rows, total, nil +} + +// Update writes the mutable fields back. Id/ChannelId/UserId/CreateAt are +// immutable post-create — the partial-unique index relies on (ChannelId, UserId) +// being stable for the lifetime of a row. +func (s *SqlChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + req.PreUpdate() + + if err := req.IsValid(); err != nil { + return nil, err + } + + query := s.getQueryBuilder(). + Update(channelJoinRequestsTable). + SetMap(map[string]any{ + "Status": req.Status, + "Message": req.Message, + "DenialReason": req.DenialReason, + "UpdateAt": req.UpdateAt, + "ReviewedBy": req.ReviewedBy, + "ReviewedAt": req.ReviewedAt, + }). + Where(sq.Eq{"Id": req.Id}) + + res, err := s.GetMaster().ExecBuilder(query) + if err != nil { + return nil, errors.Wrapf(err, "failed to update ChannelJoinRequest with id=%s", req.Id) + } + + n, err := res.RowsAffected() + if err != nil { + return nil, errors.Wrap(err, "failed to read RowsAffected on ChannelJoinRequest update") + } + if n == 0 { + return nil, store.NewErrNotFound("ChannelJoinRequest", req.Id) + } + + return req, nil +} + +func (s *SqlChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + query := s.getQueryBuilder(). + Select("COUNT(*)"). + From(channelJoinRequestsTable). + Where(sq.Eq{ + "ChannelId": channelId, + "Status": model.ChannelJoinRequestStatusPending, + }) + + var count int64 + if err := s.GetReplica().GetBuilder(&count, query); err != nil { + return 0, errors.Wrapf(err, "failed to count pending ChannelJoinRequests for channel_id=%s", channelId) + } + return count, nil +} diff --git a/server/channels/store/sqlstore/channel_join_request_store_test.go b/server/channels/store/sqlstore/channel_join_request_store_test.go new file mode 100644 index 00000000000..bbbfdc8f52e --- /dev/null +++ b/server/channels/store/sqlstore/channel_join_request_store_test.go @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "testing" + + "github.com/mattermost/mattermost/server/v8/channels/store/storetest" +) + +func TestChannelJoinRequestStore(t *testing.T) { + StoreTest(t, storetest.TestChannelJoinRequestStore) +} diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index 4691f73df3f..d98f215dacb 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -158,6 +158,7 @@ func channelSliceColumns(isSelect bool, prefix ...string) []string { p + "LastRootPostAt", p + "BannerInfo", p + "DefaultCategoryName", + p + "Discoverable", } if isSelect { @@ -196,6 +197,7 @@ func channelToSlice(channel *model.Channel) []any { channel.LastRootPostAt, channel.BannerInfo, channel.DefaultCategoryName, + channel.Discoverable, } } @@ -872,7 +874,8 @@ func (s SqlChannelStore) updateChannelT(transaction *sqlxTxWrapper, channel *mod LastRootPostAt=:LastRootPostAt, BannerInfo=:BannerInfo, DefaultCategoryName=:DefaultCategoryName, - AutoTranslation=:AutoTranslation + AutoTranslation=:AutoTranslation, + Discoverable=:Discoverable WHERE Id=:Id`, channel) if err != nil { if IsUniqueConstraintError(err, []string{"Name", "channels_name_teamid_key"}) { diff --git a/server/channels/store/sqlstore/schema_dump_test.go b/server/channels/store/sqlstore/schema_dump_test.go index 4ca5ee3f200..a4b52ec2e62 100644 --- a/server/channels/store/sqlstore/schema_dump_test.go +++ b/server/channels/store/sqlstore/schema_dump_test.go @@ -72,7 +72,7 @@ func TestGetSchemaDefinition(t *testing.T) { if table.Name == "channels" { // Check that indexes are present assert.NotEmpty(t, table.Indexes, "channels table should have indexes") - assert.Equal(t, 12, len(table.Indexes), "channels table should have 12 indexes") + assert.Equal(t, 13, len(table.Indexes), "channels table should have 13 indexes") // Expected index definitions expectedIndexDefs := map[string]string{ @@ -88,6 +88,7 @@ func TestGetSchemaDefinition(t *testing.T) { "idx_channels_team_id_display_name": "CREATE INDEX idx_channels_team_id_display_name ON public.channels USING btree (teamid, displayname)", "idx_channels_team_id_type": "CREATE INDEX idx_channels_team_id_type ON public.channels USING btree (teamid, type)", "idx_channels_autotranslation_enabled": "CREATE INDEX idx_channels_autotranslation_enabled ON public.channels USING btree (id) WHERE (autotranslation = true)", + "idx_channels_discoverable_team": "CREATE INDEX idx_channels_discoverable_team ON public.channels USING btree (teamid) WHERE ((discoverable = true) AND (type = 'P'::channel_type) AND (deleteat = 0))", } // Verify all expected indexes are present with correct definitions diff --git a/server/channels/store/sqlstore/store.go b/server/channels/store/sqlstore/store.go index d909da831e6..0bf3b4a3200 100644 --- a/server/channels/store/sqlstore/store.go +++ b/server/channels/store/sqlstore/store.go @@ -117,6 +117,7 @@ type SqlStoreStores struct { recap store.RecapStore readReceipt store.ReadReceiptStore temporaryPost store.TemporaryPostStore + channelJoinRequest store.ChannelJoinRequestStore } type SqlStore struct { @@ -303,6 +304,7 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface store.stores.recap = newSqlRecapStore(store) store.stores.readReceipt = newSqlReadReceiptStore(store, metrics) store.stores.temporaryPost = newSqlTemporaryPostStore(store, metrics) + store.stores.channelJoinRequest = newSqlChannelJoinRequestStore(store) store.stores.preference.(*SqlPreferenceStore).deleteUnusedFeatures() @@ -926,6 +928,10 @@ func (ss *SqlStore) TemporaryPost() store.TemporaryPostStore { return ss.stores.temporaryPost } +func (ss *SqlStore) ChannelJoinRequest() store.ChannelJoinRequestStore { + return ss.stores.channelJoinRequest +} + func (ss *SqlStore) DropAllTables() { ss.masterX.Exec(`DO $func$ diff --git a/server/channels/store/store.go b/server/channels/store/store.go index d09b52d8895..331595c3dc1 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -102,6 +102,7 @@ type Store interface { Recap() RecapStore ReadReceipt() ReadReceiptStore TemporaryPost() TemporaryPostStore + ChannelJoinRequest() ChannelJoinRequestStore } type RetentionPolicyStore interface { @@ -1332,6 +1333,19 @@ type ThreadMembershipImportData struct { UnreadMentions int64 } +// ChannelJoinRequestStore persists user requests to join discoverable private +// channels. Rows are never deleted; status transitions are recorded with +// reviewer and timestamps so the table doubles as an audit trail. +type ChannelJoinRequestStore interface { + Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) + Get(id string) (*model.ChannelJoinRequest, error) + GetPendingForChannelAndUser(channelId, userId string) (*model.ChannelJoinRequest, error) + GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) + GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) + Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) + CountPending(channelId string) (int64, error) +} + type RecapStore interface { SaveRecap(recap *model.Recap) (*model.Recap, error) UpdateRecap(recap *model.Recap) (*model.Recap, error) diff --git a/server/channels/store/storetest/channel_join_request_store.go b/server/channels/store/storetest/channel_join_request_store.go new file mode 100644 index 00000000000..de716535742 --- /dev/null +++ b/server/channels/store/storetest/channel_join_request_store.go @@ -0,0 +1,264 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package storetest + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +func TestChannelJoinRequestStore(t *testing.T, _ request.CTX, ss store.Store) { + t.Run("Save inserts a pending row", testChannelJoinRequestSave(ss)) + t.Run("Save rejects duplicate pending row", testChannelJoinRequestSaveDuplicatePending(ss)) + t.Run("Save allows another pending row after withdrawal", testChannelJoinRequestSaveAfterWithdraw(ss)) + t.Run("Get returns NotFound for unknown id", testChannelJoinRequestGetNotFound(ss)) + t.Run("GetPendingForChannelAndUser only returns pending rows", testChannelJoinRequestGetPending(ss)) + t.Run("GetForChannel paginates and filters by status", testChannelJoinRequestGetForChannel(ss)) + t.Run("GetForUser paginates and filters by status", testChannelJoinRequestGetForUser(ss)) + t.Run("Update transitions status and stores reviewer", testChannelJoinRequestUpdate(ss)) + t.Run("CountPending only counts pending rows", testChannelJoinRequestCountPending(ss)) +} + +func newPendingRequest(channelId, userId string) *model.ChannelJoinRequest { + return &model.ChannelJoinRequest{ + ChannelId: channelId, + UserId: userId, + Message: "please let me in", + Status: model.ChannelJoinRequestStatusPending, + } +} + +func testChannelJoinRequestSave(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + req, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + require.NotEmpty(t, req.Id) + assert.Equal(t, channelId, req.ChannelId) + assert.Equal(t, userId, req.UserId) + assert.Equal(t, model.ChannelJoinRequestStatusPending, req.Status) + assert.NotZero(t, req.CreateAt) + assert.Equal(t, req.CreateAt, req.UpdateAt) + + fetched, err := ss.ChannelJoinRequest().Get(req.Id) + require.NoError(t, err) + assert.Equal(t, req.Id, fetched.Id) + assert.Equal(t, req.Message, fetched.Message) + assert.Equal(t, req.Status, fetched.Status) + } +} + +func testChannelJoinRequestSaveDuplicatePending(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + _, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + + _, err = ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.Error(t, err) + var conflict *store.ErrConflict + assert.ErrorAs(t, err, &conflict, "duplicate pending row must surface store.ErrConflict") + } +} + +func testChannelJoinRequestSaveAfterWithdraw(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + first, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + + first.Status = model.ChannelJoinRequestStatusWithdrawn + _, err = ss.ChannelJoinRequest().Update(first) + require.NoError(t, err) + + // Allow the millisecond-resolution UpdateAt to advance. + time.Sleep(2 * time.Millisecond) + + second, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err, "a new pending row must be insertable once the previous one is no longer pending") + assert.NotEqual(t, first.Id, second.Id) + } +} + +func testChannelJoinRequestGetNotFound(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + _, err := ss.ChannelJoinRequest().Get(model.NewId()) + require.Error(t, err) + var nf *store.ErrNotFound + assert.ErrorAs(t, err, &nf) + } +} + +func testChannelJoinRequestGetPending(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + _, err := ss.ChannelJoinRequest().GetPendingForChannelAndUser(channelId, userId) + require.Error(t, err, "must return NotFound when no row exists") + + _, err = ss.ChannelJoinRequest().Save(newPendingRequest(channelId, userId)) + require.NoError(t, err) + + got, err := ss.ChannelJoinRequest().GetPendingForChannelAndUser(channelId, userId) + require.NoError(t, err) + assert.Equal(t, channelId, got.ChannelId) + assert.Equal(t, userId, got.UserId) + assert.Equal(t, model.ChannelJoinRequestStatusPending, got.Status) + + got.Status = model.ChannelJoinRequestStatusWithdrawn + _, err = ss.ChannelJoinRequest().Update(got) + require.NoError(t, err) + + _, err = ss.ChannelJoinRequest().GetPendingForChannelAndUser(channelId, userId) + require.Error(t, err, "withdrawn row must not be considered pending") + } +} + +func testChannelJoinRequestGetForChannel(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + + // Three pending requests across distinct users + one denied row for the + // same channel so we can prove the status filter actually filters. + for range 3 { + _, err := ss.ChannelJoinRequest().Save(newPendingRequest(channelId, model.NewId())) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + } + + denied := newPendingRequest(channelId, model.NewId()) + saved, err := ss.ChannelJoinRequest().Save(denied) + require.NoError(t, err) + saved.Status = model.ChannelJoinRequestStatusDenied + saved.ReviewedBy = model.NewId() + saved.ReviewedAt = model.GetMillis() + saved.DenialReason = "policy mismatch" + _, err = ss.ChannelJoinRequest().Update(saved) + require.NoError(t, err) + + rows, total, err := ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{PerPage: 10}) + require.NoError(t, err) + assert.Len(t, rows, 3) + assert.Equal(t, int64(3), total) + for i := 1; i < len(rows); i++ { + assert.GreaterOrEqual(t, rows[i-1].CreateAt, rows[i].CreateAt, "list should be newest first") + } + + rows, total, err = ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{Status: model.ChannelJoinRequestStatusDenied, PerPage: 10}) + require.NoError(t, err) + assert.Equal(t, int64(1), total) + require.Len(t, rows, 1) + assert.Equal(t, "policy mismatch", rows[0].DenialReason) + + rows, total, err = ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{PerPage: 2, Page: 0}) + require.NoError(t, err) + assert.Len(t, rows, 2) + assert.Equal(t, int64(3), total, "TotalCount must not be truncated by paging") + } +} + +func testChannelJoinRequestGetForUser(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + userId := model.NewId() + + for range 2 { + _, err := ss.ChannelJoinRequest().Save(newPendingRequest(model.NewId(), userId)) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + } + + denied := newPendingRequest(model.NewId(), userId) + saved, err := ss.ChannelJoinRequest().Save(denied) + require.NoError(t, err) + saved.Status = model.ChannelJoinRequestStatusDenied + saved.ReviewedBy = model.NewId() + saved.ReviewedAt = model.GetMillis() + _, err = ss.ChannelJoinRequest().Update(saved) + require.NoError(t, err) + + rows, total, err := ss.ChannelJoinRequest().GetForUser(userId, model.GetChannelJoinRequestsOpts{PerPage: 10}) + require.NoError(t, err) + assert.Len(t, rows, 2) + assert.Equal(t, int64(2), total) + + rows, total, err = ss.ChannelJoinRequest().GetForUser(userId, model.GetChannelJoinRequestsOpts{Status: model.ChannelJoinRequestStatusDenied, PerPage: 10}) + require.NoError(t, err) + assert.Equal(t, int64(1), total) + assert.Len(t, rows, 1) + } +} + +func testChannelJoinRequestUpdate(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + req, err := ss.ChannelJoinRequest().Save(newPendingRequest(model.NewId(), model.NewId())) + require.NoError(t, err) + originalUpdateAt := req.UpdateAt + + reviewerId := model.NewId() + reviewedAt := model.GetMillis() + 1 + req.Status = model.ChannelJoinRequestStatusApproved + req.ReviewedBy = reviewerId + req.ReviewedAt = reviewedAt + + // Allow UpdateAt to advance. + time.Sleep(2 * time.Millisecond) + updated, err := ss.ChannelJoinRequest().Update(req) + require.NoError(t, err) + assert.Equal(t, model.ChannelJoinRequestStatusApproved, updated.Status) + assert.Equal(t, reviewerId, updated.ReviewedBy) + assert.Equal(t, reviewedAt, updated.ReviewedAt) + assert.Greater(t, updated.UpdateAt, originalUpdateAt) + + fetched, err := ss.ChannelJoinRequest().Get(req.Id) + require.NoError(t, err) + assert.Equal(t, model.ChannelJoinRequestStatusApproved, fetched.Status) + assert.Equal(t, reviewerId, fetched.ReviewedBy) + } +} + +func testChannelJoinRequestCountPending(ss store.Store) func(*testing.T) { + return func(t *testing.T) { + channelId := model.NewId() + + count, err := ss.ChannelJoinRequest().CountPending(channelId) + require.NoError(t, err) + assert.Equal(t, int64(0), count) + + for range 4 { + _, err = ss.ChannelJoinRequest().Save(newPendingRequest(channelId, model.NewId())) + require.NoError(t, err) + } + + count, err = ss.ChannelJoinRequest().CountPending(channelId) + require.NoError(t, err) + assert.Equal(t, int64(4), count) + + // Withdraw one — count should drop by 1. + reqs, _, err := ss.ChannelJoinRequest().GetForChannel(channelId, model.GetChannelJoinRequestsOpts{PerPage: 10}) + require.NoError(t, err) + require.NotEmpty(t, reqs) + first := reqs[0] + first.Status = model.ChannelJoinRequestStatusWithdrawn + _, err = ss.ChannelJoinRequest().Update(first) + require.NoError(t, err) + + count, err = ss.ChannelJoinRequest().CountPending(channelId) + require.NoError(t, err) + assert.Equal(t, int64(3), count) + } +} diff --git a/server/channels/store/storetest/mocks/ChannelJoinRequestStore.go b/server/channels/store/storetest/mocks/ChannelJoinRequestStore.go new file mode 100644 index 00000000000..23d71a295a3 --- /dev/null +++ b/server/channels/store/storetest/mocks/ChannelJoinRequestStore.go @@ -0,0 +1,251 @@ +// Code generated by mockery v2.53.4. DO NOT EDIT. + +// Regenerate this file using `make store-mocks`. + +package mocks + +import ( + model "github.com/mattermost/mattermost/server/public/model" + mock "github.com/stretchr/testify/mock" +) + +// ChannelJoinRequestStore is an autogenerated mock type for the ChannelJoinRequestStore type +type ChannelJoinRequestStore struct { + mock.Mock +} + +// CountPending provides a mock function with given fields: channelId +func (_m *ChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + ret := _m.Called(channelId) + + if len(ret) == 0 { + panic("no return value specified for CountPending") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(string) (int64, error)); ok { + return rf(channelId) + } + if rf, ok := ret.Get(0).(func(string) int64); ok { + r0 = rf(channelId) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: id +func (_m *ChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + ret := _m.Called(id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(string) (*model.ChannelJoinRequest, error)); ok { + return rf(id) + } + if rf, ok := ret.Get(0).(func(string) *model.ChannelJoinRequest); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetForChannel provides a mock function with given fields: channelId, opts +func (_m *ChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + ret := _m.Called(channelId, opts) + + if len(ret) == 0 { + panic("no return value specified for GetForChannel") + } + + var r0 []*model.ChannelJoinRequest + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error)); ok { + return rf(channelId, opts) + } + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) []*model.ChannelJoinRequest); ok { + r0 = rf(channelId, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string, model.GetChannelJoinRequestsOpts) int64); ok { + r1 = rf(channelId, opts) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(string, model.GetChannelJoinRequestsOpts) error); ok { + r2 = rf(channelId, opts) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// GetForUser provides a mock function with given fields: userId, opts +func (_m *ChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + ret := _m.Called(userId, opts) + + if len(ret) == 0 { + panic("no return value specified for GetForUser") + } + + var r0 []*model.ChannelJoinRequest + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error)); ok { + return rf(userId, opts) + } + if rf, ok := ret.Get(0).(func(string, model.GetChannelJoinRequestsOpts) []*model.ChannelJoinRequest); ok { + r0 = rf(userId, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string, model.GetChannelJoinRequestsOpts) int64); ok { + r1 = rf(userId, opts) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(string, model.GetChannelJoinRequestsOpts) error); ok { + r2 = rf(userId, opts) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// GetPendingForChannelAndUser provides a mock function with given fields: channelId, userId +func (_m *ChannelJoinRequestStore) GetPendingForChannelAndUser(channelId string, userId string) (*model.ChannelJoinRequest, error) { + ret := _m.Called(channelId, userId) + + if len(ret) == 0 { + panic("no return value specified for GetPendingForChannelAndUser") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (*model.ChannelJoinRequest, error)); ok { + return rf(channelId, userId) + } + if rf, ok := ret.Get(0).(func(string, string) *model.ChannelJoinRequest); ok { + r0 = rf(channelId, userId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(channelId, userId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: req +func (_m *ChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + ret := _m.Called(req) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) (*model.ChannelJoinRequest, error)); ok { + return rf(req) + } + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) *model.ChannelJoinRequest); ok { + r0 = rf(req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(*model.ChannelJoinRequest) error); ok { + r1 = rf(req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Update provides a mock function with given fields: req +func (_m *ChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + ret := _m.Called(req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *model.ChannelJoinRequest + var r1 error + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) (*model.ChannelJoinRequest, error)); ok { + return rf(req) + } + if rf, ok := ret.Get(0).(func(*model.ChannelJoinRequest) *model.ChannelJoinRequest); ok { + r0 = rf(req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelJoinRequest) + } + } + + if rf, ok := ret.Get(1).(func(*model.ChannelJoinRequest) error); ok { + r1 = rf(req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewChannelJoinRequestStore creates a new instance of ChannelJoinRequestStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewChannelJoinRequestStore(t interface { + mock.TestingT + Cleanup(func()) +}) *ChannelJoinRequestStore { + mock := &ChannelJoinRequestStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/server/channels/store/storetest/mocks/Store.go b/server/channels/store/storetest/mocks/Store.go index 30f21da0292..2a9565d9fbd 100644 --- a/server/channels/store/storetest/mocks/Store.go +++ b/server/channels/store/storetest/mocks/Store.go @@ -5,16 +5,13 @@ package mocks import ( - mlog "github.com/mattermost/mattermost/server/public/shared/mlog" - mock "github.com/stretchr/testify/mock" + sql "database/sql" + time "time" model "github.com/mattermost/mattermost/server/public/model" - - sql "database/sql" - + mlog "github.com/mattermost/mattermost/server/public/shared/mlog" store "github.com/mattermost/mattermost/server/v8/channels/store" - - time "time" + mock "github.com/stretchr/testify/mock" ) // Store is an autogenerated mock type for the Store type @@ -162,6 +159,26 @@ func (_m *Store) ChannelBookmark() store.ChannelBookmarkStore { return r0 } +// ChannelJoinRequest provides a mock function with no fields +func (_m *Store) ChannelJoinRequest() store.ChannelJoinRequestStore { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ChannelJoinRequest") + } + + var r0 store.ChannelJoinRequestStore + if rf, ok := ret.Get(0).(func() store.ChannelJoinRequestStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.ChannelJoinRequestStore) + } + } + + return r0 +} + // ChannelMemberHistory provides a mock function with no fields func (_m *Store) ChannelMemberHistory() store.ChannelMemberHistoryStore { ret := _m.Called() diff --git a/server/channels/store/storetest/store.go b/server/channels/store/storetest/store.go index be8bf0f1f14..5c54760f6d7 100644 --- a/server/channels/store/storetest/store.go +++ b/server/channels/store/storetest/store.go @@ -75,6 +75,7 @@ type Store struct { ReadReceiptStore mocks.ReadReceiptStore TemporaryPostStore mocks.TemporaryPostStore ViewStore mocks.ViewStore + ChannelJoinRequestStore mocks.ChannelJoinRequestStore } func (s *Store) Logger() mlog.LoggerIFace { return s.logger } @@ -180,6 +181,9 @@ func (s *Store) ReadReceipt() store.ReadReceiptStore { func (s *Store) TemporaryPost() store.TemporaryPostStore { return &s.TemporaryPostStore } +func (s *Store) ChannelJoinRequest() store.ChannelJoinRequestStore { + return &s.ChannelJoinRequestStore +} func (s *Store) View() store.ViewStore { return &s.ViewStore } @@ -239,5 +243,6 @@ func (s *Store) AssertExpectations(t mock.TestingT) bool { &s.ReadReceiptStore, &s.TemporaryPostStore, &s.ViewStore, + &s.ChannelJoinRequestStore, ) } diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 82f7f583b7a..ced6f298f9d 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -26,6 +26,7 @@ type TimerLayer struct { BotStore store.BotStore ChannelStore store.ChannelStore ChannelBookmarkStore store.ChannelBookmarkStore + ChannelJoinRequestStore store.ChannelJoinRequestStore ChannelMemberHistoryStore store.ChannelMemberHistoryStore ClusterDiscoveryStore store.ClusterDiscoveryStore CommandStore store.CommandStore @@ -106,6 +107,10 @@ func (s *TimerLayer) ChannelBookmark() store.ChannelBookmarkStore { return s.ChannelBookmarkStore } +func (s *TimerLayer) ChannelJoinRequest() store.ChannelJoinRequestStore { + return s.ChannelJoinRequestStore +} + func (s *TimerLayer) ChannelMemberHistory() store.ChannelMemberHistoryStore { return s.ChannelMemberHistoryStore } @@ -341,6 +346,11 @@ type TimerLayerChannelBookmarkStore struct { Root *TimerLayer } +type TimerLayerChannelJoinRequestStore struct { + store.ChannelJoinRequestStore + Root *TimerLayer +} + type TimerLayerChannelMemberHistoryStore struct { store.ChannelMemberHistoryStore Root *TimerLayer @@ -3218,6 +3228,118 @@ func (s *TimerLayerChannelBookmarkStore) UpdateSortOrder(bookmarkID string, chan return result, err } +func (s *TimerLayerChannelJoinRequestStore) CountPending(channelId string) (int64, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.CountPending(channelId) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.CountPending", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) Get(id string) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.Get(id) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.Get", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) GetForChannel(channelId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + start := time.Now() + + result, resultVar1, err := s.ChannelJoinRequestStore.GetForChannel(channelId, opts) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.GetForChannel", success, elapsed) + } + return result, resultVar1, err +} + +func (s *TimerLayerChannelJoinRequestStore) GetForUser(userId string, opts model.GetChannelJoinRequestsOpts) ([]*model.ChannelJoinRequest, int64, error) { + start := time.Now() + + result, resultVar1, err := s.ChannelJoinRequestStore.GetForUser(userId, opts) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.GetForUser", success, elapsed) + } + return result, resultVar1, err +} + +func (s *TimerLayerChannelJoinRequestStore) GetPendingForChannelAndUser(channelId string, userId string) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.GetPendingForChannelAndUser(channelId, userId) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.GetPendingForChannelAndUser", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) Save(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.Save(req) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.Save", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelJoinRequestStore) Update(req *model.ChannelJoinRequest) (*model.ChannelJoinRequest, error) { + start := time.Now() + + result, err := s.ChannelJoinRequestStore.Update(req) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelJoinRequestStore.Update", success, elapsed) + } + return result, err +} + func (s *TimerLayerChannelMemberHistoryStore) DeleteOrphanedRows(limit int) (int64, error) { start := time.Now() @@ -3250,22 +3372,6 @@ func (s *TimerLayerChannelMemberHistoryStore) GetChannelsLeftSince(userID string return result, err } -func (s *TimerLayerChannelMemberHistoryStore) GetEverMembersInChannel(channelID string, userIDs []string) ([]string, error) { - start := time.Now() - - result, err := s.ChannelMemberHistoryStore.GetEverMembersInChannel(channelID, userIDs) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetEverMembersInChannel", success, elapsed) - } - return result, err -} - func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(startTime int64, endTime int64) ([]string, error) { start := time.Now() @@ -3282,6 +3388,22 @@ func (s *TimerLayerChannelMemberHistoryStore) GetChannelsWithActivityDuring(star return result, err } +func (s *TimerLayerChannelMemberHistoryStore) GetEverMembersInChannel(channelID string, userIDs []string) ([]string, error) { + start := time.Now() + + result, err := s.ChannelMemberHistoryStore.GetEverMembersInChannel(channelID, userIDs) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelMemberHistoryStore.GetEverMembersInChannel", success, elapsed) + } + return result, err +} + func (s *TimerLayerChannelMemberHistoryStore) GetMembershipChanges(channelID string, since int64, limit int) ([]*model.ChannelMemberHistory, error) { start := time.Now() @@ -14558,6 +14680,7 @@ func New(childStore store.Store, metrics einterfaces.MetricsInterface) *TimerLay newStore.BotStore = &TimerLayerBotStore{BotStore: childStore.Bot(), Root: &newStore} newStore.ChannelStore = &TimerLayerChannelStore{ChannelStore: childStore.Channel(), Root: &newStore} newStore.ChannelBookmarkStore = &TimerLayerChannelBookmarkStore{ChannelBookmarkStore: childStore.ChannelBookmark(), Root: &newStore} + newStore.ChannelJoinRequestStore = &TimerLayerChannelJoinRequestStore{ChannelJoinRequestStore: childStore.ChannelJoinRequest(), Root: &newStore} newStore.ChannelMemberHistoryStore = &TimerLayerChannelMemberHistoryStore{ChannelMemberHistoryStore: childStore.ChannelMemberHistory(), Root: &newStore} newStore.ClusterDiscoveryStore = &TimerLayerClusterDiscoveryStore{ClusterDiscoveryStore: childStore.ClusterDiscovery(), Root: &newStore} newStore.CommandStore = &TimerLayerCommandStore{CommandStore: childStore.Command(), Root: &newStore} diff --git a/server/channels/testlib/store.go b/server/channels/testlib/store.go index 9aee6d57e96..5d01e25c6ab 100644 --- a/server/channels/testlib/store.go +++ b/server/channels/testlib/store.go @@ -102,6 +102,7 @@ func GetMockStoreForSetupFunctions() *mocks.Store { systemStore.On("GetByName", model.MigrationKeyAccessControlPolicyV0_3).Return(&model.System{Name: model.MigrationKeyAccessControlPolicyV0_3, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyAddManageAgentPermissions).Return(&model.System{Name: model.MigrationKeyAddManageAgentPermissions, Value: "true"}, nil) systemStore.On("GetByName", model.MigrationKeyAddEditFileAttachmentPermission).Return(&model.System{Name: model.MigrationKeyAddEditFileAttachmentPermission, Value: "true"}, nil) + systemStore.On("GetByName", model.MigrationKeyAddDiscoverableChannelPermissions).Return(&model.System{Name: model.MigrationKeyAddDiscoverableChannelPermissions, Value: "true"}, nil) systemStore.On("InsertIfExists", mock.AnythingOfType("*model.System")).Return(&model.System{}, nil).Once() systemStore.On("Save", mock.AnythingOfType("*model.System")).Return(nil) diff --git a/server/cmd/mmctl/commands/permissions_test.go b/server/cmd/mmctl/commands/permissions_test.go index 51cbd4f5332..671ba68b051 100644 --- a/server/cmd/mmctl/commands/permissions_test.go +++ b/server/cmd/mmctl/commands/permissions_test.go @@ -251,6 +251,8 @@ func (s *MmctlUnitTestSuite) TestResetPermissionsCmd() { "manage_channel_access_rules", "manage_public_channel_auto_translation", "manage_private_channel_auto_translation", + "manage_private_channel_discoverability", + "manage_channel_join_requests", } expectedPatch := &model.RolePatch{ Permissions: &expectedPermissions, diff --git a/server/i18n/en.json b/server/i18n/en.json index 15a5aa99763..b46a9eb05d3 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -10738,6 +10738,10 @@ "id": "model.channel.is_valid.creator_id.app_error", "translation": "Invalid creator id." }, + { + "id": "model.channel.is_valid.discoverable.app_error", + "translation": "Only private channels can be marked as discoverable." + }, { "id": "model.channel.is_valid.display_name.app_error", "translation": "Invalid display name." @@ -10830,6 +10834,50 @@ "id": "model.channel_bookmark.is_valid.update_at.app_error", "translation": "Update at must be a valid time." }, + { + "id": "model.channel_join_request.is_valid.channel_id.app_error", + "translation": "Invalid channel id." + }, + { + "id": "model.channel_join_request.is_valid.create_at.app_error", + "translation": "Create at must be a valid time." + }, + { + "id": "model.channel_join_request.is_valid.denial_reason.app_error", + "translation": "Denial reason is too long." + }, + { + "id": "model.channel_join_request.is_valid.denial_reason_status.app_error", + "translation": "Denial reason can only be set on a denied join request." + }, + { + "id": "model.channel_join_request.is_valid.id.app_error", + "translation": "Invalid Id." + }, + { + "id": "model.channel_join_request.is_valid.message.app_error", + "translation": "Join request message is too long." + }, + { + "id": "model.channel_join_request.is_valid.reviewed_by.app_error", + "translation": "Invalid reviewer id." + }, + { + "id": "model.channel_join_request.is_valid.reviewer.app_error", + "translation": "An approved or denied join request must record the reviewer and review time." + }, + { + "id": "model.channel_join_request.is_valid.status.app_error", + "translation": "Invalid join request status." + }, + { + "id": "model.channel_join_request.is_valid.update_at.app_error", + "translation": "Update at must be a valid time." + }, + { + "id": "model.channel_join_request.is_valid.user_id.app_error", + "translation": "Invalid user id." + }, { "id": "model.channel_member.is_valid.channel_auto_follow_threads_value.app_error", "translation": "Invalid channel-auto-follow-threads value." diff --git a/server/public/model/channel.go b/server/public/model/channel.go index 5226f160c97..ef73d651fe8 100644 --- a/server/public/model/channel.go +++ b/server/public/model/channel.go @@ -108,6 +108,7 @@ type Channel struct { PolicyIsActive bool `json:"policy_is_active"` DefaultCategoryName string `json:"default_category_name"` ManagedCategoryName string `json:"managed_category_name"` + Discoverable bool `json:"discoverable"` } func (o *Channel) Auditable() map[string]any { @@ -131,6 +132,7 @@ func (o *Channel) Auditable() map[string]any { "policy_enforced": o.PolicyEnforced, "autotranslation": o.AutoTranslation, "policy_is_active": o.PolicyIsActive, // this field is only for logging purposes + "discoverable": o.Discoverable, } } @@ -160,6 +162,7 @@ type ChannelPatch struct { AutoTranslation *bool `json:"autotranslation"` ManagedCategoryName *string `json:"managed_category_name"` DefaultCategoryName *string `json:"default_category_name"` + Discoverable *bool `json:"discoverable"` } func (c *ChannelPatch) Auditable() map[string]any { @@ -169,6 +172,7 @@ func (c *ChannelPatch) Auditable() map[string]any { "purpose": c.Purpose, "default_category_name": c.DefaultCategoryName, "managed_category_name": c.ManagedCategoryName, + "discoverable": c.Discoverable, } } @@ -339,6 +343,10 @@ func (o *Channel) IsValid() *AppError { } } + if o.Discoverable && o.Type != ChannelTypePrivate { + return NewAppError("Channel.IsValid", "model.channel.is_valid.discoverable.app_error", nil, "id="+o.Id, http.StatusBadRequest) + } + return nil } @@ -459,6 +467,10 @@ func (o *Channel) Patch(patch *ChannelPatch) { if patch.DefaultCategoryName != nil { o.DefaultCategoryName = strings.TrimSpace(*patch.DefaultCategoryName) } + + if patch.Discoverable != nil { + o.Discoverable = *patch.Discoverable + } } func (o *Channel) MakeNonNil() { diff --git a/server/public/model/channel_join_request.go b/server/public/model/channel_join_request.go new file mode 100644 index 00000000000..38c0e6248dc --- /dev/null +++ b/server/public/model/channel_join_request.go @@ -0,0 +1,165 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "net/http" + "unicode/utf8" +) + +const ( + ChannelJoinRequestStatusPending = "pending" + ChannelJoinRequestStatusApproved = "approved" + ChannelJoinRequestStatusDenied = "denied" + ChannelJoinRequestStatusWithdrawn = "withdrawn" + + ChannelJoinRequestMessageMaxRunes = 500 + ChannelJoinRequestDenialReasonMaxRunes = 500 +) + +// ChannelJoinRequest records a user's request to join a discoverable private channel. +// +// Rows are append-only / status-mutating: a request transitions through +// pending → approved | denied | withdrawn. Rows are never deleted so the full +// audit history is preserved. A partial unique index in Postgres enforces at +// most one active pending row per (ChannelId, UserId). +type ChannelJoinRequest struct { + Id string `json:"id"` + ChannelId string `json:"channel_id"` + UserId string `json:"user_id"` + Message string `json:"message"` + Status string `json:"status"` + DenialReason string `json:"denial_reason"` + CreateAt int64 `json:"create_at"` + UpdateAt int64 `json:"update_at"` + ReviewedBy string `json:"reviewed_by"` + ReviewedAt int64 `json:"reviewed_at"` +} + +// ChannelJoinRequestList is the paginated response shape returned by list endpoints. +type ChannelJoinRequestList struct { + Requests []*ChannelJoinRequest `json:"requests"` + TotalCount int64 `json:"total_count"` +} + +// ChannelJoinRequestPatch represents the admin review action: approve or deny, +// with an optional denial reason that is surfaced to the requester. +type ChannelJoinRequestPatch struct { + Status string `json:"status"` + DenialReason *string `json:"denial_reason,omitempty"` +} + +// GetChannelJoinRequestsOpts filters and paginates list queries on the store. +// An empty Status means "pending". +type GetChannelJoinRequestsOpts struct { + Status string + Page int + PerPage int +} + +// IsValidChannelJoinRequestStatus reports whether the given status string is a +// recognized lifecycle value for a ChannelJoinRequest. +func IsValidChannelJoinRequestStatus(s string) bool { + switch s { + case ChannelJoinRequestStatusPending, + ChannelJoinRequestStatusApproved, + ChannelJoinRequestStatusDenied, + ChannelJoinRequestStatusWithdrawn: + return true + } + return false +} + +func (r *ChannelJoinRequest) Auditable() map[string]any { + return map[string]any{ + "id": r.Id, + "channel_id": r.ChannelId, + "user_id": r.UserId, + "status": r.Status, + "create_at": r.CreateAt, + "update_at": r.UpdateAt, + "reviewed_by": r.ReviewedBy, + "reviewed_at": r.ReviewedAt, + "has_message": r.Message != "", + "has_denial_reason": r.DenialReason != "", + } +} + +func (r *ChannelJoinRequest) LogClone() any { + return r.Auditable() +} + +func (r *ChannelJoinRequest) IsValid() *AppError { + if !IsValidId(r.Id) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.id.app_error", nil, "", http.StatusBadRequest) + } + + if !IsValidId(r.ChannelId) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.channel_id.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if !IsValidId(r.UserId) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.user_id.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if r.CreateAt == 0 { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.create_at.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if r.UpdateAt == 0 { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.update_at.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if !IsValidChannelJoinRequestStatus(r.Status) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.status.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if utf8.RuneCountInString(r.Message) > ChannelJoinRequestMessageMaxRunes { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.message.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if utf8.RuneCountInString(r.DenialReason) > ChannelJoinRequestDenialReasonMaxRunes { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.denial_reason.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + // A denial reason is only meaningful on a denied request. + if r.DenialReason != "" && r.Status != ChannelJoinRequestStatusDenied { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.denial_reason_status.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + if r.ReviewedBy != "" && !IsValidId(r.ReviewedBy) { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.reviewed_by.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + + // Reviewer and reviewed-at must accompany a terminal review action. + switch r.Status { + case ChannelJoinRequestStatusApproved, ChannelJoinRequestStatusDenied: + if r.ReviewedBy == "" || r.ReviewedAt == 0 { + return NewAppError("ChannelJoinRequest.IsValid", "model.channel_join_request.is_valid.reviewer.app_error", nil, "id="+r.Id, http.StatusBadRequest) + } + } + + return nil +} + +func (r *ChannelJoinRequest) PreSave() { + if r.Id == "" { + r.Id = NewId() + } + if r.Status == "" { + r.Status = ChannelJoinRequestStatusPending + } + if r.CreateAt == 0 { + r.CreateAt = GetMillis() + } + r.UpdateAt = r.CreateAt + r.Message = SanitizeUnicode(r.Message) + r.DenialReason = SanitizeUnicode(r.DenialReason) +} + +func (r *ChannelJoinRequest) PreUpdate() { + r.UpdateAt = GetMillis() + r.Message = SanitizeUnicode(r.Message) + r.DenialReason = SanitizeUnicode(r.DenialReason) +} diff --git a/server/public/model/channel_join_request_test.go b/server/public/model/channel_join_request_test.go new file mode 100644 index 00000000000..78f354732c4 --- /dev/null +++ b/server/public/model/channel_join_request_test.go @@ -0,0 +1,114 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package model + +import ( + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func validRequest() *ChannelJoinRequest { + return &ChannelJoinRequest{ + Id: NewId(), + ChannelId: NewId(), + UserId: NewId(), + Status: ChannelJoinRequestStatusPending, + CreateAt: GetMillis(), + UpdateAt: GetMillis(), + } +} + +func TestChannelJoinRequestPreSaveDefaults(t *testing.T) { + r := &ChannelJoinRequest{ + ChannelId: NewId(), + UserId: NewId(), + } + r.PreSave() + + assert.NotEmpty(t, r.Id, "PreSave must assign an Id when missing") + assert.Equal(t, ChannelJoinRequestStatusPending, r.Status, "PreSave must default Status to pending") + assert.NotZero(t, r.CreateAt) + assert.Equal(t, r.CreateAt, r.UpdateAt, "PreSave must align UpdateAt with CreateAt") +} + +func TestChannelJoinRequestPreUpdateAdvancesUpdateAt(t *testing.T) { + r := validRequest() + originalCreate := r.CreateAt + // Seed UpdateAt to a known-old value so we can prove PreUpdate actually + // advanced it (the validRequest factory sets UpdateAt = GetMillis(), so + // a no-op PreUpdate could otherwise still pass a GreaterOrEqual check). + r.UpdateAt = 1 + r.PreUpdate() + + assert.Greater(t, r.UpdateAt, int64(1)) + assert.Equal(t, originalCreate, r.CreateAt, "PreUpdate must not mutate CreateAt") +} + +func TestChannelJoinRequestIsValid(t *testing.T) { + t.Run("happy path pending", func(t *testing.T) { + require.Nil(t, validRequest().IsValid()) + }) + + t.Run("invalid id", func(t *testing.T) { + r := validRequest() + r.Id = "not-an-id" + err := r.IsValid() + require.NotNil(t, err) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + }) + + t.Run("rejects unknown status", func(t *testing.T) { + r := validRequest() + r.Status = "weird" + require.NotNil(t, r.IsValid()) + }) + + t.Run("rejects message over rune limit", func(t *testing.T) { + r := validRequest() + r.Message = strings.Repeat("a", ChannelJoinRequestMessageMaxRunes+1) + require.NotNil(t, r.IsValid()) + }) + + t.Run("rejects denial reason on non-denied request", func(t *testing.T) { + r := validRequest() + r.Status = ChannelJoinRequestStatusApproved + r.ReviewedBy = NewId() + r.ReviewedAt = GetMillis() + r.DenialReason = "nope" + require.NotNil(t, r.IsValid(), "denial reason must only be set on denied rows") + }) + + t.Run("requires reviewer info for terminal review", func(t *testing.T) { + r := validRequest() + r.Status = ChannelJoinRequestStatusApproved + require.NotNil(t, r.IsValid(), "approved without reviewer must be invalid") + + r.ReviewedBy = NewId() + r.ReviewedAt = GetMillis() + require.Nil(t, r.IsValid()) + }) + + t.Run("withdrawn does not require reviewer", func(t *testing.T) { + r := validRequest() + r.Status = ChannelJoinRequestStatusWithdrawn + require.Nil(t, r.IsValid(), "withdrawn is a self-service action, not a review") + }) +} + +func TestIsValidChannelJoinRequestStatus(t *testing.T) { + for _, s := range []string{ + ChannelJoinRequestStatusPending, + ChannelJoinRequestStatusApproved, + ChannelJoinRequestStatusDenied, + ChannelJoinRequestStatusWithdrawn, + } { + assert.True(t, IsValidChannelJoinRequestStatus(s), "%q should be a valid status", s) + } + assert.False(t, IsValidChannelJoinRequestStatus("")) + assert.False(t, IsValidChannelJoinRequestStatus("approved ")) +} diff --git a/server/public/model/channel_test.go b/server/public/model/channel_test.go index 8d7a3b2ad09..21143f3b7f5 100644 --- a/server/public/model/channel_test.go +++ b/server/public/model/channel_test.go @@ -35,6 +35,64 @@ func TestChannelPatch(t *testing.T) { require.Equal(t, *p.GroupConstrained, *o.GroupConstrained) } +func TestChannelPatchDiscoverable(t *testing.T) { + t.Run("applies discoverable when set", func(t *testing.T) { + on := true + p := &ChannelPatch{Discoverable: &on} + o := Channel{Id: NewId(), Name: NewId(), Type: ChannelTypePrivate} + o.Patch(p) + require.True(t, o.Discoverable) + }) + + t.Run("clears discoverable when set to false", func(t *testing.T) { + off := false + p := &ChannelPatch{Discoverable: &off} + o := Channel{Id: NewId(), Name: NewId(), Type: ChannelTypePrivate, Discoverable: true} + o.Patch(p) + require.False(t, o.Discoverable) + }) + + t.Run("nil discoverable leaves channel untouched", func(t *testing.T) { + o := Channel{Id: NewId(), Name: NewId(), Type: ChannelTypePrivate, Discoverable: true} + o.Patch(&ChannelPatch{}) + require.True(t, o.Discoverable) + }) +} + +func TestChannelIsValidDiscoverable(t *testing.T) { + base := Channel{ + Id: NewId(), + CreateAt: GetMillis(), + UpdateAt: GetMillis(), + DisplayName: "x", + Name: "valid-name", + Header: "h", + Purpose: "p", + } + + t.Run("discoverable=false is valid on any type", func(t *testing.T) { + c := base + c.Type = ChannelTypeOpen + require.Nil(t, c.IsValid()) + }) + + t.Run("discoverable=true requires private channel", func(t *testing.T) { + c := base + c.Type = ChannelTypeOpen + c.Discoverable = true + require.NotNil(t, c.IsValid(), "discoverable=true on public channel must be rejected") + + c.Type = ChannelTypeDirect + require.NotNil(t, c.IsValid()) + + c.Type = ChannelTypeGroup + require.NotNil(t, c.IsValid()) + + c.Type = ChannelTypePrivate + require.Nil(t, c.IsValid()) + }) +} + func TestChannelIsValid(t *testing.T) { o := Channel{} diff --git a/server/public/model/feature_flags.go b/server/public/model/feature_flags.go index ac8d06a3092..68465f32025 100644 --- a/server/public/model/feature_flags.go +++ b/server/public/model/feature_flags.go @@ -120,6 +120,11 @@ type FeatureFlags struct { // ManagedChannelCategories enables server-side managed sidebar category enforcement (Enterprise). ManagedChannelCategories bool + + // FEATURE_FLAG_REMOVAL: DiscoverableChannels - Remove this when the feature is GA. + // Gates the per-channel Discoverable toggle and the channel-join-request flow that lets + // non-members find a private channel in Browse Channels and request to join it. + DiscoverableChannels bool } func (f *FeatureFlags) SetDefaults() { @@ -176,6 +181,8 @@ func (f *FeatureFlags) SetDefaults() { f.AggregatePluginMetrics = false f.ManagedChannelCategories = false + + f.DiscoverableChannels = false } // ToMap returns the feature flags as a map[string]string diff --git a/server/public/model/migration.go b/server/public/model/migration.go index f29087a19f5..7ece8d257ba 100644 --- a/server/public/model/migration.go +++ b/server/public/model/migration.go @@ -65,4 +65,5 @@ const ( MigrationKeyAccessControlPolicyV0_3 = "access_control_policy_v0_3_migration" MigrationKeyAddManageAgentPermissions = "add_manage_agent_permissions" MigrationKeyAddEditFileAttachmentPermission = "add_edit_file_attachment_permission" + MigrationKeyAddDiscoverableChannelPermissions = "add_discoverable_channel_permissions" ) diff --git a/server/public/model/permission.go b/server/public/model/permission.go index 63ab17e5369..921d46ff894 100644 --- a/server/public/model/permission.go +++ b/server/public/model/permission.go @@ -49,6 +49,8 @@ var PermissionManagePublicChannelProperties *Permission var PermissionManagePrivateChannelProperties *Permission var PermissionManagePublicChannelAutoTranslation *Permission var PermissionManagePrivateChannelAutoTranslation *Permission +var PermissionManagePrivateChannelDiscoverability *Permission +var PermissionManageChannelJoinRequests *Permission var PermissionListPublicTeams *Permission var PermissionJoinPublicTeams *Permission var PermissionListPrivateTeams *Permission @@ -568,6 +570,18 @@ func initializePermissions() { "authentication.permissions.manage_private_channel_auto_translation.description", PermissionScopeChannel, } + PermissionManagePrivateChannelDiscoverability = &Permission{ + "manage_private_channel_discoverability", + "authentication.permissions.manage_private_channel_discoverability.name", + "authentication.permissions.manage_private_channel_discoverability.description", + PermissionScopeChannel, + } + PermissionManageChannelJoinRequests = &Permission{ + "manage_channel_join_requests", + "authentication.permissions.manage_channel_join_requests.name", + "authentication.permissions.manage_channel_join_requests.description", + PermissionScopeChannel, + } PermissionListPublicTeams = &Permission{ "list_public_teams", "authentication.permissions.list_public_teams.name", @@ -2631,6 +2645,8 @@ func initializePermissions() { PermissionManagePrivateChannelBanner, PermissionManageChannelAccessRules, PermissionEditFileAttachment, + PermissionManagePrivateChannelDiscoverability, + PermissionManageChannelJoinRequests, } GroupScopedPermissions := []*Permission{ diff --git a/server/public/model/role.go b/server/public/model/role.go index 0513f565c49..17f2807f3d7 100644 --- a/server/public/model/role.go +++ b/server/public/model/role.go @@ -932,6 +932,8 @@ func MakeDefaultRoles() map[string]*Role { PermissionManageChannelAccessRules.Id, PermissionManagePublicChannelAutoTranslation.Id, PermissionManagePrivateChannelAutoTranslation.Id, + PermissionManagePrivateChannelDiscoverability.Id, + PermissionManageChannelJoinRequests.Id, }, SchemeManaged: true, BuiltIn: true, diff --git a/server/public/model/websocket_message.go b/server/public/model/websocket_message.go index c816d3234a7..87f9ead3544 100644 --- a/server/public/model/websocket_message.go +++ b/server/public/model/websocket_message.go @@ -117,6 +117,8 @@ const ( WebsocketEventFileDownloadRejected WebsocketEventType = "file_download_rejected" WebsocketEventShowToast WebsocketEventType = "show_toast" WebsocketEventSharedChannelRemoteUpdated WebsocketEventType = "shared_channel_remote_updated" + WebsocketEventChannelJoinRequestCreated WebsocketEventType = "channel_join_request_created" + WebsocketEventChannelJoinRequestUpdated WebsocketEventType = "channel_join_request_updated" WebSocketMsgTypeResponse = "response" WebSocketMsgTypeEvent = "event" diff --git a/webapp/platform/client/src/websocket_events.ts b/webapp/platform/client/src/websocket_events.ts index 8683efca91c..56bf0179e8a 100644 --- a/webapp/platform/client/src/websocket_events.ts +++ b/webapp/platform/client/src/websocket_events.ts @@ -99,4 +99,6 @@ export const enum WebSocketEvents { FileDownloadRejected = 'file_download_rejected', ShowToast = 'show_toast', SharedChannelRemoteUpdated = 'shared_channel_remote_updated', + ChannelJoinRequestCreated = 'channel_join_request_created', + ChannelJoinRequestUpdated = 'channel_join_request_updated', } diff --git a/webapp/platform/types/src/channels.ts b/webapp/platform/types/src/channels.ts index 7252122b15c..0f7253d5b37 100644 --- a/webapp/platform/types/src/channels.ts +++ b/webapp/platform/types/src/channels.ts @@ -73,6 +73,7 @@ export type Channel = { default_category_name?: string; managed_category_name?: string; autotranslation?: boolean; + discoverable?: boolean; }; export type ServerChannel = Channel & { @@ -112,6 +113,32 @@ export type ChannelsWithTotalCount = { total_count: number; }; +export type ChannelJoinRequestStatus = 'pending' | 'approved' | 'denied' | 'withdrawn'; + +export type ChannelJoinRequest = { + id: string; + channel_id: string; + user_id: string; + message: string; + status: ChannelJoinRequestStatus; + denial_reason: string; + create_at: number; + update_at: number; + reviewed_by: string; + reviewed_at: number; +}; + +export type ChannelJoinRequestList = { + requests: ChannelJoinRequest[]; + total_count: number; +}; + +export type GetChannelJoinRequestsOptions = { + status?: ChannelJoinRequestStatus; + page?: number; + per_page?: number; +}; + export type ChannelMembership = { channel_id: string; user_id: string; From 02023f0328e5bb9e04ca06f25c2284efbc1f1759 Mon Sep 17 00:00:00 2001 From: Ben Cooke Date: Fri, 15 May 2026 15:26:03 -0400 Subject: [PATCH 15/80] [MM-68463] New endpoint to GET user by auth_data (#36352) --- api/v4/source/users.yaml | 48 ++++++ server/channels/api4/user.go | 56 ++++++ server/channels/api4/user_local.go | 41 +++++ server/channels/api4/user_test.go | 163 ++++++++++++++++++ server/channels/app/user.go | 18 ++ server/channels/app/users/users.go | 4 + .../channels/store/retrylayer/retrylayer.go | 21 +++ server/channels/store/sqlstore/user_store.go | 21 +++ server/channels/store/store.go | 1 + .../store/storetest/mocks/UserStore.go | 30 ++++ server/channels/store/storetest/user_store.go | 120 ++++++++++++- .../channels/store/timerlayer/timerlayer.go | 16 ++ server/public/model/client4.go | 12 ++ .../apiAuditLogs/whitelist.go | 2 + 14 files changed, 547 insertions(+), 6 deletions(-) diff --git a/api/v4/source/users.yaml b/api/v4/source/users.yaml index c22441ec34d..e5c6b6de375 100644 --- a/api/v4/source/users.yaml +++ b/api/v4/source/users.yaml @@ -1534,6 +1534,54 @@ $ref: "#/components/responses/Unauthorized" "404": $ref: "#/components/responses/NotFound" + /api/v4/users/auth_data: + get: + tags: + - users + summary: Get a user by auth data + description: > + Get a user by their external auth data identifier. The `value` is + matched against what is stored in `Users.AuthData`, which for most + identity providers is the identifier as the provider issues it. + + + The exception is Active Directory `objectGUID`: under + `auth_service: ldap` it is stored as the LDAP filter hex-escape + form (e.g. `\61\14\e1\d1\c5\35\18\4a\b6\60\d6\78\50\fd\0d\5d`), + and under `auth_service: saml` it is stored as the standard + Base64 of the same bytes (e.g. `YRTh0cU1GEq2YNZ4UP0NXQ==`). Use + the form matching the user's current `AuthService`. + + + ##### Permissions + + Must be a system admin. + operationId: GetUserByAuthData + parameters: + - name: value + in: query + description: > + The user's AuthData as stored in `Users.AuthData`. Must be + URL-encoded; in particular, Base64 `+` characters must be sent + as `%2B` so they are not decoded as spaces. + required: true + schema: + type: string + responses: + "200": + description: User retrieval successful + content: + application/json: + schema: + $ref: "#/components/schemas/User" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + "404": + $ref: "#/components/responses/NotFound" /api/v4/users/password/reset: post: tags: diff --git a/server/channels/api4/user.go b/server/channels/api4/user.go index aa55372b3bd..7a5aec2426d 100644 --- a/server/channels/api4/user.go +++ b/server/channels/api4/user.go @@ -76,6 +76,7 @@ func (api *API) InitUser() { api.BaseRoutes.UserByUsername.Handle("", api.APISessionRequired(getUserByUsername)).Methods(http.MethodGet) api.BaseRoutes.UserByEmail.Handle("", api.APISessionRequired(getUserByEmail)).Methods(http.MethodGet) + api.BaseRoutes.Users.Handle("/auth_data", api.APISessionRequired(getUserByAuthData)).Methods(http.MethodGet) api.BaseRoutes.User.Handle("/sessions", api.APISessionRequired(getSessions)).Methods(http.MethodGet) api.BaseRoutes.User.Handle("/sessions/revoke", api.APISessionRequired(revokeSession)).Methods(http.MethodPost) @@ -461,6 +462,61 @@ func getUserByEmail(c *Context, w http.ResponseWriter, r *http.Request) { } } +func getUserByAuthData(c *Context, w http.ResponseWriter, r *http.Request) { + if !c.IsSystemAdmin() { + c.SetPermissionError(model.PermissionManageSystem) + return + } + authData := r.URL.Query().Get("value") + if authData == "" { + c.SetInvalidParam("value") + return + } + if len(authData) > model.UserAuthDataMaxLength { + c.SetInvalidParam("value") + return + } + user, err := c.App.GetUserByAuthData(&authData) + if err != nil { + c.Err = err + return + } + + canSee, err2 := c.App.UserCanSeeOtherUser(c.AppContext, c.AppContext.Session().UserId, user.Id) + if err2 != nil { + c.Err = err2 + return + } + + if !canSee { + c.SetPermissionError(model.PermissionViewMembers) + return + } + + userTermsOfService, err := c.App.GetUserTermsOfService(user.Id) + if err != nil && err.StatusCode != http.StatusNotFound { + c.Err = err + return + } + + if userTermsOfService != nil { + user.TermsOfServiceId = userTermsOfService.TermsOfServiceId + user.TermsOfServiceCreateAt = userTermsOfService.CreateAt + } + + etag := user.Etag(*c.App.Config().PrivacySettings.ShowFullName, *c.App.Config().PrivacySettings.ShowEmailAddress) + + if c.HandleEtag(etag, "Get User", w, r) { + return + } + + c.App.SanitizeProfile(user, c.IsSystemAdmin()) + w.Header().Set(model.HeaderEtagServer, etag) + if jerr := json.NewEncoder(w).Encode(user); jerr != nil { + c.Logger.Warn("Error while writing response", mlog.Err(jerr)) + } +} + func getDefaultProfileImage(c *Context, w http.ResponseWriter, r *http.Request) { c.RequireUserId() if c.Err != nil { diff --git a/server/channels/api4/user_local.go b/server/channels/api4/user_local.go index 7fba2cb46cc..2747a92d619 100644 --- a/server/channels/api4/user_local.go +++ b/server/channels/api4/user_local.go @@ -40,6 +40,7 @@ func (api *API) InitUserLocal() { api.BaseRoutes.UserByUsername.Handle("", api.APILocal(localGetUserByUsername)).Methods(http.MethodGet) api.BaseRoutes.UserByEmail.Handle("", api.APILocal(localGetUserByEmail)).Methods(http.MethodGet) + api.BaseRoutes.Users.Handle("/auth_data", api.APILocal(localGetUserByAuthData)).Methods(http.MethodGet) api.BaseRoutes.Users.Handle("/tokens/revoke", api.APILocal(revokeUserAccessToken)).Methods(http.MethodPost) api.BaseRoutes.User.Handle("/tokens", api.APILocal(getUserAccessTokensForUser)).Methods(http.MethodGet) @@ -427,6 +428,46 @@ func localGetUserByEmail(c *Context, w http.ResponseWriter, r *http.Request) { } } +func localGetUserByAuthData(c *Context, w http.ResponseWriter, r *http.Request) { + authData := r.URL.Query().Get("value") + if authData == "" { + c.SetInvalidParam("value") + return + } + if len(authData) > model.UserAuthDataMaxLength { + c.SetInvalidParam("value") + return + } + user, err := c.App.GetUserByAuthData(&authData) + if err != nil { + c.Err = err + return + } + + userTermsOfService, err := c.App.GetUserTermsOfService(user.Id) + if err != nil && err.StatusCode != http.StatusNotFound { + c.Err = err + return + } + + if userTermsOfService != nil { + user.TermsOfServiceId = userTermsOfService.TermsOfServiceId + user.TermsOfServiceCreateAt = userTermsOfService.CreateAt + } + + etag := user.Etag(*c.App.Config().PrivacySettings.ShowFullName, *c.App.Config().PrivacySettings.ShowEmailAddress) + + if c.HandleEtag(etag, "Get User", w, r) { + return + } + + c.App.SanitizeProfile(user, c.IsSystemAdmin()) + w.Header().Set(model.HeaderEtagServer, etag) + if jerr := json.NewEncoder(w).Encode(user); jerr != nil { + c.Logger.Warn("Error while writing response", mlog.Err(jerr)) + } +} + func localGetUploadsForUser(c *Context, w http.ResponseWriter, r *http.Request) { uss, appErr := c.App.GetUploadSessionsForUser(c.Params.UserId) if appErr != nil { diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index 82afeea8db1..127aa75a9ed 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -1487,6 +1487,169 @@ func TestGetUserByEmail(t *testing.T) { }) } +func TestGetUserByAuthData(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t) + + team := th.CreateTeamWithClient(t, th.SystemAdminClient) + regularUser := th.CreateUser(t) + th.LinkUserToTeam(t, regularUser, team) + user := th.CreateUser(t) + th.LinkUserToTeam(t, user, team) + _, err := th.App.Srv().Store().User().VerifyEmail(user.Id, user.Email) + require.NoError(t, err) + + authID := "extid-" + model.NewId() + userAuth := &model.UserAuth{ + AuthData: model.NewPointer(authID), + AuthService: model.UserAuthServiceSaml, + } + _, _, err = th.SystemAdminClient.UpdateUserAuth(context.Background(), user.Id, userAuth) + require.NoError(t, err) + + th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { + t.Run("returns user and auth fields for system admin and local", func(t *testing.T) { + ruser, resp, getErr := client.GetUserByAuthData(context.Background(), authID, "") + require.NoError(t, getErr) + require.Equal(t, user.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, authID, *ruser.AuthData) + require.Equal(t, model.UserAuthServiceSaml, ruser.AuthService) + ruser, resp, _ = client.GetUserByAuthData(context.Background(), authID, resp.Etag) + CheckEtag(t, ruser, resp) + }) + + t.Run("not found returns not found", func(t *testing.T) { + _, resp, notFoundErr := client.GetUserByAuthData(context.Background(), "nope-"+model.NewId(), "") + require.Error(t, notFoundErr) + CheckNotFoundStatus(t, resp) + }) + }) + + t.Run("returns accepted terms of service for system admin", func(t *testing.T) { + tos, appErr := th.App.CreateTermsOfService("Dummy TOS", user.Id) + require.Nil(t, appErr) + appErr = th.App.SaveUserTermsOfService(user.Id, tos.Id, true) + require.Nil(t, appErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), authID, "") + require.NoError(t, getErr) + require.Equal(t, tos.Id, ruser.TermsOfServiceId, "Terms of service ID should match") + require.NotZero(t, ruser.TermsOfServiceCreateAt, "Terms of service CreateAt should be populated") + }) + + t.Run("returns user when auth_data is an email-shaped value", func(t *testing.T) { + // ResetAuthDataToEmailForUsers sets AuthData = Email for whole batches of + // users, so email-shaped auth_data values are common in practice. Verify + // the route, Client4 path escaping (`@` -> `%40`), and server-side decoding + // all round-trip correctly. + emailUser := th.CreateUser(t) + th.LinkUserToTeam(t, emailUser, team) + emailAuth := "user-" + model.NewId() + "@example.com" + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), emailUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(emailAuth), + AuthService: model.UserAuthServiceSaml, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), emailAuth, "") + require.NoError(t, getErr) + require.Equal(t, emailUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, emailAuth, *ruser.AuthData) + }) + + t.Run("preserves case in auth_data", func(t *testing.T) { + // auth_data is opaque and case-sensitive (unlike email, which the email + // endpoint lowercases). Non-SAML IdPs commonly issue mixed-case identifiers, + // so guard against a regression where the handler normalizes the value. + mixedUser := th.CreateUser(t) + th.LinkUserToTeam(t, mixedUser, team) + mixedAuth := "MixedCase-" + model.NewId() + "@Example.COM" + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), mixedUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(mixedAuth), + AuthService: model.UserAuthServiceSaml, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), mixedAuth, "") + require.NoError(t, getErr) + require.Equal(t, mixedUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, mixedAuth, *ruser.AuthData) + + _, resp, lowerErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), strings.ToLower(mixedAuth), "") + require.Error(t, lowerErr) + CheckNotFoundStatus(t, resp) + }) + + t.Run("returns user when auth_data is an LDAP objectGUID hex-escape form", func(t *testing.T) { + // AD objectGUID stored under auth_service=ldap uses the LDAP filter + // hex-escape form (`\xx` per byte). Backslashes are special in URL paths + // (WHATWG rewrites them to `/`), which is why this endpoint uses a query + // parameter; this test guards the query-string round-trip for the exact + // shape the customer reported. + ldapUser := th.CreateUser(t) + th.LinkUserToTeam(t, ldapUser, team) + ldapAuth := `\61\14\e1\d1\c5\35\18\4a\b6\60\d6\78\50\fd\0d\5d` + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), ldapUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(ldapAuth), + AuthService: model.UserAuthServiceLdap, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), ldapAuth, "") + require.NoError(t, getErr) + require.Equal(t, ldapUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, ldapAuth, *ruser.AuthData) + }) + + t.Run("returns user when auth_data is SAML base64 with reserved chars", func(t *testing.T) { + // AD objectGUID stored under auth_service=saml uses standard Base64, + // which can contain `+`, `/`, and `=` padding -- all reserved in + // application/x-www-form-urlencoded. url.Values.Set escapes them + // correctly; this test guards against a future regression where someone + // rewrites the client to skip that escaping. + samlUser := th.CreateUser(t) + th.LinkUserToTeam(t, samlUser, team) + // Bytes chosen to produce all three reserved characters in the Base64 + // output: 0xfb,0xef,0xff,0x00 -> "++//AA==". + samlAuth := base64.StdEncoding.EncodeToString([]byte{0xfb, 0xef, 0xff, 0x00}) + require.Contains(t, samlAuth, "+") + require.Contains(t, samlAuth, "/") + require.Contains(t, samlAuth, "=") + _, _, updErr := th.SystemAdminClient.UpdateUserAuth(context.Background(), samlUser.Id, &model.UserAuth{ + AuthData: model.NewPointer(samlAuth), + AuthService: model.UserAuthServiceSaml, + }) + require.NoError(t, updErr) + + ruser, _, getErr := th.SystemAdminClient.GetUserByAuthData(context.Background(), samlAuth, "") + require.NoError(t, getErr) + require.Equal(t, samlUser.Id, ruser.Id) + require.NotNil(t, ruser.AuthData) + require.Equal(t, samlAuth, *ruser.AuthData) + }) + + t.Run("rejects non-system admin", func(t *testing.T) { + // `user` is converted to SAML below and can no longer use password login; use + // a separate team member to assert the endpoint requires a system admin. + _, _, err = th.Client.Login(context.Background(), regularUser.Email, regularUser.Password) + require.NoError(t, err) + _, resp, err := th.Client.GetUserByAuthData(context.Background(), authID, "") + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("rejects auth data over max length", func(t *testing.T) { + longData := strings.Repeat("x", model.UserAuthDataMaxLength+1) + _, resp, err := th.SystemAdminClient.GetUserByAuthData(context.Background(), longData, "") + require.Error(t, err) + CheckBadRequestStatus(t, resp) + }) +} + // This test can flake if two calls to model.NewId can return the same value. // Not much can be done about it. func TestSearchUsers(t *testing.T) { diff --git a/server/channels/app/user.go b/server/channels/app/user.go index 932b039dd0c..325e3bdd1a1 100644 --- a/server/channels/app/user.go +++ b/server/channels/app/user.go @@ -606,6 +606,24 @@ func (a *App) GetUserByAuth(authData *string, authService string) (*model.User, return user, nil } +func (a *App) GetUserByAuthData(authData *string) (*model.User, *model.AppError) { + user, err := a.ch.srv.userService.GetUserByAuthData(authData) + if err != nil { + var invErr *store.ErrInvalidInput + var nfErr *store.ErrNotFound + switch { + case errors.As(err, &invErr): + return nil, model.NewAppError("GetUserByAuthData", MissingAccountError, nil, "", http.StatusBadRequest).Wrap(err) + case errors.As(err, &nfErr): + return nil, model.NewAppError("GetUserByAuthData", MissingAccountError, nil, "", http.StatusNotFound).Wrap(err) + default: + return nil, model.NewAppError("GetUserByAuthData", MissingAccountError, nil, "", http.StatusInternalServerError).Wrap(err) + } + } + + return user, nil +} + func (a *App) GetUsersFromProfiles(options *model.UserGetOptions) ([]*model.User, *model.AppError) { users, err := a.ch.srv.userService.GetUsersFromProfiles(options) if err != nil { diff --git a/server/channels/app/users/users.go b/server/channels/app/users/users.go index 900184f781e..c81c0f8ccec 100644 --- a/server/channels/app/users/users.go +++ b/server/channels/app/users/users.go @@ -114,6 +114,10 @@ func (us *UserService) GetUserByAuth(authData *string, authService string) (*mod return us.store.GetByAuth(authData, authService) } +func (us *UserService) GetUserByAuthData(authData *string) (*model.User, error) { + return us.store.GetByAuthData(authData) +} + func (us *UserService) GetUsersFromProfiles(options *model.UserGetOptions) ([]*model.User, error) { return us.store.GetAllProfiles(options) } diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 61970c998ee..938cd91d1c7 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -16306,6 +16306,27 @@ func (s *RetryLayerUserStore) GetByAuth(authData *string, authService string) (* } +func (s *RetryLayerUserStore) GetByAuthData(authData *string) (*model.User, error) { + + tries := 0 + for { + result, err := s.UserStore.GetByAuthData(authData) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerUserStore) GetByEmail(email string) (*model.User, error) { tries := 0 diff --git a/server/channels/store/sqlstore/user_store.go b/server/channels/store/sqlstore/user_store.go index 7b11d1220e5..524bc917373 100644 --- a/server/channels/store/sqlstore/user_store.go +++ b/server/channels/store/sqlstore/user_store.go @@ -1281,6 +1281,27 @@ func (us SqlUserStore) GetByRemoteID(remoteID string) (*model.User, error) { return &user, nil } +func (us SqlUserStore) GetByAuthData(authData *string) (*model.User, error) { + if authData == nil || *authData == "" { + return nil, store.NewErrInvalidInput("User", "", "empty or nil") + } + + query := us.usersQuery.Where("Users.AuthData = ?", authData) + + queryString, args, err := query.ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_by_auth_data_tosql") + } + + user := model.User{} + if err := us.GetReplica().Get(&user, queryString, args...); err == sql.ErrNoRows { + return nil, store.NewErrNotFound("User", fmt.Sprintf("authData=%s", *authData)) + } else if err != nil { + return nil, errors.Wrapf(err, "failed to find User with authData=%s", *authData) + } + return &user, nil +} + func (us SqlUserStore) GetByAuth(authData *string, authService string) (*model.User, error) { if authData == nil || *authData == "" { return nil, store.NewErrInvalidInput("User", "", "empty or nil") diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 331595c3dc1..596af20bbeb 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -473,6 +473,7 @@ type UserStore interface { GetByEmail(email string) (*model.User, error) GetByRemoteID(remoteID string) (*model.User, error) GetByAuth(authData *string, authService string) (*model.User, error) + GetByAuthData(authData *string) (*model.User, error) GetAllUsingAuthService(authService string) ([]*model.User, error) GetAllNotInAuthService(authServices []string) ([]*model.User, error) GetByUsername(username string) (*model.User, error) diff --git a/server/channels/store/storetest/mocks/UserStore.go b/server/channels/store/storetest/mocks/UserStore.go index 0845e1e23ab..12026e7f4b6 100644 --- a/server/channels/store/storetest/mocks/UserStore.go +++ b/server/channels/store/storetest/mocks/UserStore.go @@ -655,6 +655,36 @@ func (_m *UserStore) GetByAuth(authData *string, authService string) (*model.Use return r0, r1 } +// GetByAuthData provides a mock function with given fields: authData +func (_m *UserStore) GetByAuthData(authData *string) (*model.User, error) { + ret := _m.Called(authData) + + if len(ret) == 0 { + panic("no return value specified for GetByAuthData") + } + + var r0 *model.User + var r1 error + if rf, ok := ret.Get(0).(func(*string) (*model.User, error)); ok { + return rf(authData) + } + if rf, ok := ret.Get(0).(func(*string) *model.User); ok { + r0 = rf(authData) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.User) + } + } + + if rf, ok := ret.Get(1).(func(*string) error); ok { + r1 = rf(authData) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetByEmail provides a mock function with given fields: email func (_m *UserStore) GetByEmail(email string) (*model.User, error) { ret := _m.Called(email) diff --git a/server/channels/store/storetest/user_store.go b/server/channels/store/storetest/user_store.go index ec303b8c6c5..da4ef47f1a2 100644 --- a/server/channels/store/storetest/user_store.go +++ b/server/channels/store/storetest/user_store.go @@ -69,6 +69,7 @@ func TestUserStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) { t.Run("GetProfilesByUsernames", func(t *testing.T) { testUserStoreGetProfilesByUsernames(t, rctx, ss) }) t.Run("GetSystemAdminProfiles", func(t *testing.T) { testUserStoreGetSystemAdminProfiles(t, rctx, ss) }) t.Run("GetByEmail", func(t *testing.T) { testUserStoreGetByEmail(t, rctx, ss) }) + t.Run("GetByAuth", func(t *testing.T) { testUserStoreGetByAuth(t, rctx, ss) }) t.Run("GetByAuthData", func(t *testing.T) { testUserStoreGetByAuthData(t, rctx, ss) }) t.Run("GetByUsername", func(t *testing.T) { testUserStoreGetByUsername(t, rctx, ss) }) t.Run("GetForLogin", func(t *testing.T) { testUserStoreGetForLogin(t, rctx, ss) }) @@ -2104,7 +2105,7 @@ func testUserStoreGetByEmail(t *testing.T, rctx request.CTX, ss store.Store) { }) } -func testUserStoreGetByAuthData(t *testing.T, rctx request.CTX, ss store.Store) { +func testUserStoreGetByAuth(t *testing.T, rctx request.CTX, ss store.Store) { teamID := model.NewId() auth1 := model.NewId() auth3 := model.NewId() @@ -2167,21 +2168,128 @@ func testUserStoreGetByAuthData(t *testing.T, rctx request.CTX, ss store.Store) require.True(t, errors.As(err, &nfErr)) }) - t.Run("get by unknown auth, u1 service", func(t *testing.T) { - unknownAuth := "" + t.Run("get by unknown non-empty auth, u1 service", func(t *testing.T) { + unknownAuth := model.NewId() _, err := ss.User().GetByAuth(&unknownAuth, u1.AuthService) require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) + }) + + t.Run("get by empty auth, u1 service", func(t *testing.T) { + emptyAuth := "" + _, err := ss.User().GetByAuth(&emptyAuth, u1.AuthService) + require.Error(t, err) var invErr *store.ErrInvalidInput require.True(t, errors.As(err, &invErr)) }) - t.Run("get by unknown auth, unknown service", func(t *testing.T) { - unknownAuth := "" - _, err := ss.User().GetByAuth(&unknownAuth, "unknown") + t.Run("get by nil auth, u1 service", func(t *testing.T) { + _, err := ss.User().GetByAuth(nil, u1.AuthService) require.Error(t, err) var invErr *store.ErrInvalidInput require.True(t, errors.As(err, &invErr)) }) + + t.Run("get by unknown non-empty auth, unknown service", func(t *testing.T) { + unknownAuth := model.NewId() + _, err := ss.User().GetByAuth(&unknownAuth, "unknown") + require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) + }) + + t.Run("get by empty auth, unknown service", func(t *testing.T) { + emptyAuth := "" + _, err := ss.User().GetByAuth(&emptyAuth, "unknown") + require.Error(t, err) + var invErr *store.ErrInvalidInput + require.True(t, errors.As(err, &invErr)) + }) +} + +func testUserStoreGetByAuthData(t *testing.T, rctx request.CTX, ss store.Store) { + teamID := model.NewId() + auth1 := model.NewId() + auth2 := model.NewId() + + u1, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: "u1" + model.NewId(), + AuthData: &auth1, + AuthService: "service", + }) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u1.Id)) }() + _, nErr := ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: teamID, UserId: u1.Id}, -1) + require.NoError(t, nErr) + + u2, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: "u2" + model.NewId(), + AuthData: &auth2, + AuthService: "service2", + }) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u2.Id)) }() + _, nErr = ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: teamID, UserId: u2.Id}, -1) + require.NoError(t, nErr) + + t.Run("returns full user when auth data matches", func(t *testing.T) { + u, err := ss.User().GetByAuthData(u1.AuthData) + require.NoError(t, err) + assert.Equal(t, u1, u) + }) + + t.Run("matches regardless of auth service", func(t *testing.T) { + u, err := ss.User().GetByAuthData(u2.AuthData) + require.NoError(t, err) + assert.Equal(t, u2.Id, u.Id) + assert.Equal(t, "service2", u.AuthService) + }) + + t.Run("returns ErrNotFound for unknown auth data", func(t *testing.T) { + unknownAuth := model.NewId() + _, err := ss.User().GetByAuthData(&unknownAuth) + require.Error(t, err) + var nfErr *store.ErrNotFound + require.True(t, errors.As(err, &nfErr)) + }) + + t.Run("returns ErrInvalidInput for nil auth data", func(t *testing.T) { + _, err := ss.User().GetByAuthData(nil) + require.Error(t, err) + var invErr *store.ErrInvalidInput + require.True(t, errors.As(err, &invErr)) + }) + + t.Run("returns ErrInvalidInput for empty auth data", func(t *testing.T) { + emptyAuth := "" + _, err := ss.User().GetByAuthData(&emptyAuth) + require.Error(t, err) + var invErr *store.ErrInvalidInput + require.True(t, errors.As(err, &invErr)) + }) + + t.Run("matches when auth data is an email-shaped value", func(t *testing.T) { + // ResetAuthDataToEmailForUsers sets AuthData = Email for whole batches of + // users, so email-shaped auth_data values are common in practice. + emailAuth := "u3-" + model.NewId() + "@example.com" + u3, err := ss.User().Save(rctx, &model.User{ + Email: MakeEmail(), + Username: "u3" + model.NewId(), + AuthData: &emailAuth, + AuthService: "service", + }) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u3.Id)) }() + + u, err := ss.User().GetByAuthData(&emailAuth) + require.NoError(t, err) + assert.Equal(t, u3.Id, u.Id) + require.NotNil(t, u.AuthData) + assert.Equal(t, emailAuth, *u.AuthData) + }) } func testUserStoreGetByUsername(t *testing.T, rctx request.CTX, ss store.Store) { diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index ced6f298f9d..b1b4e1ed635 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -12864,6 +12864,22 @@ func (s *TimerLayerUserStore) GetByAuth(authData *string, authService string) (* return result, err } +func (s *TimerLayerUserStore) GetByAuthData(authData *string) (*model.User, error) { + start := time.Now() + + result, err := s.UserStore.GetByAuthData(authData) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.GetByAuthData", success, elapsed) + } + return result, err +} + func (s *TimerLayerUserStore) GetByEmail(email string) (*model.User, error) { start := time.Now() diff --git a/server/public/model/client4.go b/server/public/model/client4.go index 179940aff96..c3eb8c76e43 100644 --- a/server/public/model/client4.go +++ b/server/public/model/client4.go @@ -1186,6 +1186,18 @@ func (c *Client4) GetUserByEmail(ctx context.Context, email, etag string) (*User return DecodeJSONFromResponse[*User](r) } +// GetUserByAuthData returns a user by auth_data (external AuthData). +func (c *Client4) GetUserByAuthData(ctx context.Context, authData, etag string) (*User, *Response, error) { + values := url.Values{} + values.Set("value", authData) + r, err := c.doAPIGetWithQuery(ctx, c.usersRoute().Join("auth_data"), values, etag) + if err != nil { + return nil, BuildResponse(r), err + } + defer closeBody(r) + return DecodeJSONFromResponse[*User](r) +} + // AutocompleteUsersInTeam returns the users on a team based on search term. func (c *Client4) AutocompleteUsersInTeam(ctx context.Context, teamId string, username string, limit int, etag string) (*UserAutocomplete, *Response, error) { values := url.Values{} diff --git a/tools/mattermost-govet/apiAuditLogs/whitelist.go b/tools/mattermost-govet/apiAuditLogs/whitelist.go index b79b3f6b7b5..9ec9502a30a 100644 --- a/tools/mattermost-govet/apiAuditLogs/whitelist.go +++ b/tools/mattermost-govet/apiAuditLogs/whitelist.go @@ -121,6 +121,7 @@ var whiteList = map[string]bool{ "getUserAccessToken": true, "getUserAccessTokens": true, "getUserAccessTokensForUser": true, + "getUserByAuthData": true, "getUserByEmail": true, "getUserByUsername": true, "getUsers": true, @@ -133,6 +134,7 @@ var whiteList = map[string]bool{ "getWebappPlugins": true, "listAutocompleteCommands": true, "listCommands": true, + "localGetUserByAuthData": true, "openDialog": true, "patchChannelModerations": true, "pinPost": true, From 8eb97fa6c3960aeb12cf17a9c526eb02e149d9b8 Mon Sep 17 00:00:00 2001 From: sabril <5334504+saturninoabril@users.noreply.github.com> Date: Sat, 16 May 2026 10:26:00 +0800 Subject: [PATCH 16/80] refactor: remove redundant status update jobs from E2E test workflows (#36579) * refactor: remove redundant status update jobs from E2E test workflows * refactor: rename context-name to commit-status-context in E2E test workflows --- .../e2e-tests-cypress-template-v2.yml | 73 ++----------------- .../e2e-tests-playwright-template-v2.yml | 73 ++----------------- 2 files changed, 12 insertions(+), 134 deletions(-) diff --git a/.github/workflows/e2e-tests-cypress-template-v2.yml b/.github/workflows/e2e-tests-cypress-template-v2.yml index ef82a6a09c8..3923c2c9e6a 100644 --- a/.github/workflows/e2e-tests-cypress-template-v2.yml +++ b/.github/workflows/e2e-tests-cypress-template-v2.yml @@ -130,35 +130,17 @@ permissions: contents: read statuses: write id-token: write - pull-requests: write env: SERVER_IMAGE: "${{ inputs.server_image_repo }}/${{ inputs.server_edition == 'fips' && 'mattermost-enterprise-fips-edition' || inputs.server_edition == 'team' && 'mattermost-team-edition' || 'mattermost-enterprise-edition' }}:${{ inputs.server_image_tag }}" jobs: - update-initial-status: - runs-on: ubuntu-24.04 - permissions: - contents: read - statuses: write - steps: - - name: ci/set-initial-status - uses: mattermost/actions/delivery/update-commit-status@f324ac89b05cc3511cb06e60642ac2fb829f0a63 - env: - GITHUB_TOKEN: ${{ github.token }} - with: - repository_full_name: ${{ github.repository }} - commit_sha: ${{ inputs.commit_sha }} - context: ${{ inputs.context_name }} - description: "tests running, image_tag:${{ inputs.server_image_tag }}${{ inputs.server_image_aliases && format(' ({0})', inputs.server_image_aliases) || '' }}" - status: pending - dispatch-begin: runs-on: ubuntu-24.04 permissions: contents: read id-token: write - pull-requests: write + statuses: write outputs: composite-identity-json: ${{ steps.composite-identity.outputs.composite-identity-json }} workers-matrix: ${{ steps.matrix.outputs.workers }} @@ -226,10 +208,10 @@ jobs: cypress-skip-on: ${{ inputs.cypress_skip_on }} cypress-sort-first: ${{ inputs.cypress_sort_first }} cypress-sort-last: ${{ inputs.cypress_sort_last }} - post-pr-comment: 'true' github-token: ${{ secrets.GITHUB_TOKEN }} - context-name: ${{ inputs.context_name }} - server-image: ${{ env.SERVER_IMAGE }} + commit-status-context: ${{ inputs.context_name }} + image-tag: ${{ inputs.server_image_tag }} + image-aliases: ${{ inputs.server_image_aliases }} workers: name: dispatch-run-${{ matrix.worker_index }} @@ -329,7 +311,7 @@ jobs: permissions: contents: read id-token: write - pull-requests: write + statuses: write outputs: commit_status_description: ${{ steps.summary.outputs.commit_status_description }} webhook_payload: ${{ steps.summary.outputs.webhook_payload }} @@ -350,8 +332,7 @@ jobs: server-image: ${{ env.SERVER_IMAGE }} pr-number: ${{ inputs.pr_number }} ref-branch: ${{ inputs.ref_branch }} - context-name: ${{ inputs.context_name }} - post-pr-comment: 'true' + commit-status-context: ${{ inputs.context_name }} github-token: ${{ secrets.GITHUB_TOKEN }} - name: ci/publish-webhook if: inputs.enable_reporting && env.REPORT_WEBHOOK_URL != '' @@ -365,45 +346,3 @@ jobs: SUMMARY_OUTCOME: ${{ steps.summary.outcome }} run: | [ "$SUMMARY_OUTCOME" = "success" ] - - update-success-status: - runs-on: ubuntu-24.04 - permissions: - contents: read - statuses: write - if: always() && needs.report.result == 'success' - needs: - - dispatch-begin - - report - steps: - - uses: mattermost/actions/delivery/update-commit-status@f324ac89b05cc3511cb06e60642ac2fb829f0a63 - env: - GITHUB_TOKEN: ${{ github.token }} - with: - repository_full_name: ${{ github.repository }} - commit_sha: ${{ inputs.commit_sha }} - context: ${{ inputs.context_name }} - description: ${{ needs.report.outputs.commit_status_description }} - status: success - target_url: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} - - update-failure-status: - runs-on: ubuntu-24.04 - permissions: - contents: read - statuses: write - if: always() && needs.report.result != 'success' - needs: - - dispatch-begin - - report - steps: - - uses: mattermost/actions/delivery/update-commit-status@f324ac89b05cc3511cb06e60642ac2fb829f0a63 - env: - GITHUB_TOKEN: ${{ github.token }} - with: - repository_full_name: ${{ github.repository }} - commit_sha: ${{ inputs.commit_sha }} - context: ${{ inputs.context_name }} - description: ${{ needs.report.outputs.commit_status_description }} - status: failure - target_url: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} diff --git a/.github/workflows/e2e-tests-playwright-template-v2.yml b/.github/workflows/e2e-tests-playwright-template-v2.yml index a008527473d..606332c36fc 100644 --- a/.github/workflows/e2e-tests-playwright-template-v2.yml +++ b/.github/workflows/e2e-tests-playwright-template-v2.yml @@ -96,35 +96,17 @@ permissions: contents: read statuses: write id-token: write - pull-requests: write env: SERVER_IMAGE: "${{ inputs.server_image_repo }}/${{ inputs.server_edition == 'fips' && 'mattermost-enterprise-fips-edition' || inputs.server_edition == 'team' && 'mattermost-team-edition' || 'mattermost-enterprise-edition' }}:${{ inputs.server_image_tag }}" jobs: - update-initial-status: - runs-on: ubuntu-24.04 - permissions: - contents: read - statuses: write - steps: - - name: ci/set-initial-status - uses: mattermost/actions/delivery/update-commit-status@f324ac89b05cc3511cb06e60642ac2fb829f0a63 - env: - GITHUB_TOKEN: ${{ github.token }} - with: - repository_full_name: ${{ github.repository }} - commit_sha: ${{ inputs.commit_sha }} - context: ${{ inputs.context_name }} - description: "tests running, image_tag:${{ inputs.server_image_tag }}${{ inputs.server_image_aliases && format(' ({0})', inputs.server_image_aliases) || '' }}" - status: pending - dispatch-begin: runs-on: ubuntu-24.04 permissions: contents: read id-token: write - pull-requests: write + statuses: write outputs: composite-identity-json: ${{ steps.composite-identity.outputs.composite-identity-json }} workers-matrix: ${{ steps.matrix.outputs.workers }} @@ -186,10 +168,10 @@ jobs: total-reports-expected: ${{ inputs.workers }} retest-on-fail: ${{ inputs.retest_on_fail }} playwright-project: ${{ inputs.playwright_project }} - post-pr-comment: 'true' github-token: ${{ secrets.GITHUB_TOKEN }} - context-name: ${{ inputs.context_name }} - server-image: ${{ env.SERVER_IMAGE }} + commit-status-context: ${{ inputs.context_name }} + image-tag: ${{ inputs.server_image_tag }} + image-aliases: ${{ inputs.server_image_aliases }} workers: name: dispatch-run-${{ matrix.worker_index }} @@ -286,7 +268,7 @@ jobs: permissions: contents: read id-token: write - pull-requests: write + statuses: write outputs: commit_status_description: ${{ steps.summary.outputs.commit_status_description }} webhook_payload: ${{ steps.summary.outputs.webhook_payload }} @@ -307,8 +289,7 @@ jobs: server-image: ${{ env.SERVER_IMAGE }} pr-number: ${{ inputs.pr_number }} ref-branch: ${{ inputs.ref_branch }} - context-name: ${{ inputs.context_name }} - post-pr-comment: 'true' + commit-status-context: ${{ inputs.context_name }} github-token: ${{ secrets.GITHUB_TOKEN }} - name: ci/publish-webhook if: inputs.enable_reporting && env.REPORT_WEBHOOK_URL != '' @@ -322,45 +303,3 @@ jobs: SUMMARY_OUTCOME: ${{ steps.summary.outcome }} run: | [ "$SUMMARY_OUTCOME" = "success" ] - - update-success-status: - runs-on: ubuntu-24.04 - permissions: - contents: read - statuses: write - if: always() && needs.report.result == 'success' - needs: - - dispatch-begin - - report - steps: - - uses: mattermost/actions/delivery/update-commit-status@f324ac89b05cc3511cb06e60642ac2fb829f0a63 - env: - GITHUB_TOKEN: ${{ github.token }} - with: - repository_full_name: ${{ github.repository }} - commit_sha: ${{ inputs.commit_sha }} - context: ${{ inputs.context_name }} - description: ${{ needs.report.outputs.commit_status_description }} - status: success - target_url: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} - - update-failure-status: - runs-on: ubuntu-24.04 - permissions: - contents: read - statuses: write - if: always() && needs.report.result != 'success' - needs: - - dispatch-begin - - report - steps: - - uses: mattermost/actions/delivery/update-commit-status@f324ac89b05cc3511cb06e60642ac2fb829f0a63 - env: - GITHUB_TOKEN: ${{ github.token }} - with: - repository_full_name: ${{ github.repository }} - commit_sha: ${{ inputs.commit_sha }} - context: ${{ inputs.context_name }} - description: ${{ needs.report.outputs.commit_status_description }} - status: failure - target_url: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} From 238867e24762ab1557f676c589820615d7293d5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Mon, 18 May 2026 12:16:32 +0200 Subject: [PATCH 17/80] MM-68732: Remove global mutex for login attempts in favour of database serialization (#36515) * Add atomic login-attempt counter primitives to UserStore Two new store methods back the upcoming switch from a global per-node mutex to per-user atomic slot claiming: TryIncrementFailedPasswordAttempts(userID, maxAttempts) (bool, error) UPDATE Users SET FailedAttempts = FailedAttempts + 1 WHERE Id = ? AND FailedAttempts < maxAttempts Returns true when a slot was claimed (rows affected == 1) and false when the cap was already reached. The conditional UPDATE serialises concurrent attempts on the same user via the row lock, so the cap is enforced without any application-level locking and without serialising attempts across users. DecrementFailedPasswordAttempts(userID) error UPDATE Users SET FailedAttempts = FailedAttempts - 1 WHERE Id = ? AND FailedAttempts > 0 Releases a slot previously claimed by TryIncrement when the in-flight authentication turns out not to be a credential failure. The conditional UPDATE means concurrent decrements cannot underflow. Storetest covers both primitives: claim-below-cap, reject-at-cap, reject-above-cap, no-op for unknown user, and a 50-goroutine concurrent test with a start barrier asserting exactly maxAttempts slots are ever claimed and that decrement clamps at zero under contention. The testify mock is regenerated here so the storetest package that returns *mocks.UserStore as a store.UserStore still satisfies the interface; the wrapper layers are regenerated in the next commit. ------ AI assisted commit * Regenerate store layers for the new primitives Pick up TryIncrementFailedPasswordAttempts and DecrementFailedPasswordAttempts in every generated wrapper: - retrylayer: retry on repeatable errors using the standard three-attempt loop. - timerlayer: record store-method duration metrics under UserStore.TryIncrementFailedPasswordAttempts and UserStore.DecrementFailedPasswordAttempts. - localcachelayer: invalidate the profile cache only after the underlying conditional UPDATE actually changes a row; an at-cap no-op return on TryIncrement no longer produces unnecessary cluster invalidation traffic. ------ AI assisted commit * Drop login-attempt mutex; use per-user slot claiming Replace the global per-node mutex that serialised every login attempt with the database-side atomic slot machine added on the Users row. Each of the three authentication entry points now pre-claims a slot via TryIncrementFailedPasswordAttempts before running the expensive password / LDAP / MFA check, and releases the slot when the failure path is not a real credential mismatch: - CheckPasswordAndAllCriteria (email/password): refunds the slot on backend errors during the password check (malformed stored hash, hasher misc failure, password-migration write failure) so a transient infra issue cannot ratchet FailedAttempts to a lockout for a user with valid credentials; refunds on the MFA pre-flight probe (empty mfaToken on an MFA-enabled user) so the probe is not counted as a real attempt. - DoubleCheckPassword: same backend-error refund predicate. - checkLdapUserPasswordAndAllCriteria: pre-claims only for existing users (first-time LDAP users have no local row to claim against); refunds non-credential DoLogin errors (server unreachable, transient) so an LDAP outage cannot lock out everyone; refunds the MFA pre-flight probe; for first-time users, explicitly bumps the counter via UpdateFailedPasswordAttempts on a real bad-password or bad-MFA attempt, matching the pre-refactor counting behaviour. If the refund itself fails the underlying authentication error is preserved and returned to the caller (the failure is logged); a leaked slot is annoying, but masking the real failure with a generic store 500 would be a clear observability regression. Cluster-wide behaviour also changes: the previous design honoured MaximumLoginAttempts per node, so an n-node cluster effectively permitted n * MaximumLoginAttempts attempts. The cap is now enforced globally. ------ AI assisted commit * Cover app-layer behaviors of the new login slot machine The store-layer tests already exercise TryIncrement and Decrement under concurrency and at the cap boundary. The new behavioural contracts at the app layer were not covered, so a regression that flipped a refund predicate, a probe condition, or a first-time LDAP path would have slipped through type checking and existing unit tests. Add tests around the three callers of the new path: - CheckPasswordAndAllCriteria: an MFA pre-flight probe (empty token) does not consume a slot; a real attempt with a wrong non-empty token does; a backend error during the password check (malformed stored hash) refunds the slot; the happy path also asserts FailedAttempts resets to zero. - DoubleCheckPassword: gets its first test coverage, covering the happy path, rate-limit rejection once max attempts is reached, and the backend-error refund path. - checkLdapUserPasswordAndAllCriteria: covers paths the table loop did not exercise, first-time LDAP user with a bad password (uses GetUserByAuth to reach the freshly created row), first-time LDAP user with a wrong MFA token, existing LDAP user with a non-credential DoLogin error (slot refunded), and the existing LDAP user MFA pre-flight probe (slot refunded). ------ AI assisted commit * Address coderabbit review ------ AI assisted commit * Fix race in first-time LDAP failed-attempt counter For first-time LDAP users we have no local row to pre-claim, so the bad-password and bad-MFA branches fell back to an absolute UpdateFailedPasswordAttempts(id, ldapUser.FailedAttempts+1) based on a snapshot from GetUserByAuth. Concurrent first-attempt requests for the same user could all read FailedAttempts == 0 and all write 1, losing increments. As a secondary issue the absolute set did not enforce MaximumLoginAttempts, so the counter could also drift past the cap. Switch both branches to TryIncrementFailedPasswordAttempts, the atomic conditional UPDATE already used on every other path. The row lock serialises concurrent increments and the predicate caps at MaximumLoginAttempts. A new concurrent storetest-style subtest runs 3 * maxFailedLoginAttempts goroutines through the first-time bad-password path against the same fresh LDAP row and asserts FailedAttempts lands at exactly maxFailedLoginAttempts. Against the previous absolute-set implementation the test fails (observed FailedAttempts = 4 with maxFailedLoginAttempts = 3, either a lost increment or a cap overshoot). The first-time bad-password branch also switches from a wrapped 500 return on store error to log-and-continue, matching the rest of the file's refund/probe error handling: the underlying LDAP authentication failure is the more useful error for the caller. ------ AI assisted commit * Address review comments ------ AI assisted commit --------- Co-authored-by: Mattermost Build --- server/channels/app/authentication.go | 139 +++++--- server/channels/app/authentication_test.go | 318 ++++++++++++++++-- server/channels/app/channels.go | 8 +- .../store/localcachelayer/user_layer.go | 19 ++ .../channels/store/retrylayer/retrylayer.go | 42 +++ server/channels/store/sqlstore/user_store.go | 39 +++ server/channels/store/store.go | 2 + .../store/storetest/mocks/UserStore.go | 46 +++ server/channels/store/storetest/user_store.go | 143 ++++++++ .../channels/store/timerlayer/timerlayer.go | 32 ++ 10 files changed, 715 insertions(+), 73 deletions(-) diff --git a/server/channels/app/authentication.go b/server/channels/app/authentication.go index edda80f5aac..73e7668f856 100644 --- a/server/channels/app/authentication.go +++ b/server/channels/app/authentication.go @@ -62,7 +62,7 @@ func (a *App) IsPasswordValid(rctx request.CTX, password string) *model.AppError return nil } -func (a *App) checkUserPassword(user *model.User, password string, invalidateCache bool) *model.AppError { +func (a *App) checkUserPassword(user *model.User, password string) *model.AppError { if user.Password == "" || password == "" { return model.NewAppError("checkUserPassword", "api.user.check_user_password.invalid.app_error", nil, "user_id="+user.Id, http.StatusUnauthorized) } @@ -76,16 +76,6 @@ func (a *App) checkUserPassword(user *model.User, password string, invalidateCac // Compare the password using the hasher that generated it err = hasher.CompareHashAndPassword(phc, password) if err != nil && errors.Is(err, hashers.ErrMismatchedHashAndPassword) { - // Increment the number of failed password attempts in case of - // mismatched hash and password - if passErr := a.Srv().Store().User().UpdateFailedPasswordAttempts(user.Id, user.FailedAttempts+1); passErr != nil { - return model.NewAppError("CheckPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(passErr) - } - - if invalidateCache { - a.InvalidateCacheForUser(user.Id) - } - return model.NewAppError("checkUserPassword", "api.user.check_user_password.invalid.app_error", nil, "user_id="+user.Id, http.StatusUnauthorized).Wrap(err) } else if err != nil { return model.NewAppError("checkUserPassword", "app.valid_password_generic.app_error", nil, "", http.StatusInternalServerError).Wrap(err) @@ -118,11 +108,6 @@ func (a *App) migratePassword(user *model.User, password string) *model.AppError } func (a *App) CheckPasswordAndAllCriteria(rctx request.CTX, userID string, password string, mfaToken string) *model.AppError { - // MM-37585 - // Use locks to avoid concurrently checking AND updating the failed login attempts. - a.ch.emailLoginAttemptsMut.Lock() - defer a.ch.emailLoginAttemptsMut.Unlock() - user, err := a.GetUser(userID) if err != nil { if err.Id != MissingAccountError { @@ -137,16 +122,36 @@ func (a *App) CheckPasswordAndAllCriteria(rctx request.CTX, userID string, passw return err } - if err := a.checkUserPassword(user, password, false); err != nil { + maxAttempts := *a.Config().ServiceSettings.MaximumLoginAttempts + claimed, claimErr := a.Srv().Store().User().TryIncrementFailedPasswordAttempts(user.Id, maxAttempts) + if claimErr != nil { + return model.NewAppError("CheckPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(claimErr) + } + if !claimed { + return model.NewAppError("checkUserLoginAttempts", "api.user.check_user_login_attempts.too_many.app_error", nil, "user_id="+user.Id, http.StatusUnauthorized) + } + + if err := a.checkUserPassword(user, password); err != nil { + // Only keep the claimed slot when the failure is an actual + // credential mismatch; backend errors (hasher failures, migration + // failures, malformed stored hash) must not consume a slot or a + // transient infra issue could lock out a user with valid creds. + if err.Id != "api.user.check_user_password.invalid.app_error" { + if passErr := a.Srv().Store().User().DecrementFailedPasswordAttempts(user.Id); passErr != nil { + rctx.Logger().Warn("failed to refund login attempt slot", mlog.String("user_id", user.Id), mlog.Err(passErr)) + } + } return err } if err := a.CheckUserMfa(rctx, user, mfaToken); err != nil { - // If the mfaToken is not set, we assume the client used this as a pre-flight request to query the server - // about the MFA state of the user in question - if mfaToken != "" { - if passErr := a.Srv().Store().User().UpdateFailedPasswordAttempts(user.Id, user.FailedAttempts+1); passErr != nil { - return model.NewAppError("CheckPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(passErr) + // The slot we claimed already counts this as a failed attempt; + // the only special case is when no mfaToken was provided, which + // is treated as a pre-flight MFA-state probe rather than a real + // attempt — refund the slot so the probe is not counted. + if mfaToken == "" { + if passErr := a.Srv().Store().User().DecrementFailedPasswordAttempts(user.Id); passErr != nil { + rctx.Logger().Warn("failed to refund MFA probe slot", mlog.String("user_id", user.Id), mlog.Err(passErr)) } } @@ -166,11 +171,21 @@ func (a *App) CheckPasswordAndAllCriteria(rctx request.CTX, userID string, passw // This to be used for places we check the users password when they are already logged in func (a *App) DoubleCheckPassword(rctx request.CTX, user *model.User, password string) *model.AppError { - if err := checkUserLoginAttempts(user, *a.Config().ServiceSettings.MaximumLoginAttempts); err != nil { - return err + maxAttempts := *a.Config().ServiceSettings.MaximumLoginAttempts + claimed, claimErr := a.Srv().Store().User().TryIncrementFailedPasswordAttempts(user.Id, maxAttempts) + if claimErr != nil { + return model.NewAppError("DoubleCheckPassword", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(claimErr) + } + if !claimed { + return model.NewAppError("checkUserLoginAttempts", "api.user.check_user_login_attempts.too_many.app_error", nil, "user_id="+user.Id, http.StatusUnauthorized) } - if err := a.checkUserPassword(user, password, true); err != nil { + if err := a.checkUserPassword(user, password); err != nil { + if err.Id != "api.user.check_user_password.invalid.app_error" { + if passErr := a.Srv().Store().User().DecrementFailedPasswordAttempts(user.Id); passErr != nil { + rctx.Logger().Warn("failed to refund login attempt slot", mlog.String("user_id", user.Id), mlog.Err(passErr)) + } + } return err } @@ -184,11 +199,7 @@ func (a *App) DoubleCheckPassword(rctx request.CTX, user *model.User, password s } func (a *App) checkLdapUserPasswordAndAllCriteria(rctx request.CTX, user *model.User, password, mfaToken string) (*model.User, *model.AppError) { - // MM-37585: Use locks to avoid concurrently checking AND updating the failed login attempts. - a.ch.ldapLoginAttemptsMut.Lock() - defer a.ch.ldapLoginAttemptsMut.Unlock() - - // We need to get the latest value of the user from the database after we acquire the lock. user is nil for first-time LDAP users. + // We need to get the latest value of the user from the database. user.Id is empty for first-time LDAP users. if user.Id != "" { var err *model.AppError user, err = a.GetUser(user.Id) @@ -209,10 +220,16 @@ func (a *App) checkLdapUserPasswordAndAllCriteria(rctx request.CTX, user *model. return nil, err } - // First time LDAP users will not have a userID + maxAttempts := *a.Config().LdapSettings.MaximumLoginAttempts + + // First-time LDAP users have no local row yet to pre-claim against. if user.Id != "" { - if err := checkUserLoginAttempts(user, *a.Config().LdapSettings.MaximumLoginAttempts); err != nil { - return nil, err + claimed, claimErr := a.Srv().Store().User().TryIncrementFailedPasswordAttempts(user.Id, maxAttempts) + if claimErr != nil { + return nil, model.NewAppError("checkLdapUserPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(claimErr) + } + if !claimed { + return nil, model.NewAppError("checkUserLoginAttempts", "api.user.check_user_login_attempts.too_many_ldap.app_error", nil, "user_id="+user.Id, http.StatusUnauthorized) } } @@ -233,8 +250,23 @@ func (a *App) checkLdapUserPasswordAndAllCriteria(rctx request.CTX, user *model. if err.Id == "ent.ldap.do_login.invalid_password.app_error" { rctx.Logger().LogM(mlog.MlvlLDAPInfo, "A user tried to sign in, which matched an LDAP account, but the password was incorrect.", mlog.String("ldap_id", *ldapID)) - if passErr := a.Srv().Store().User().UpdateFailedPasswordAttempts(ldapUser.Id, ldapUser.FailedAttempts+1); passErr != nil { - return nil, model.NewAppError("CheckPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(passErr) + // For existing users we already claimed the slot above, so the + // counter has already been bumped. For first-time users (the + // row was just created by DoLogin) we still need to count the + // failed attempt explicitly, using the atomic primitive so + // concurrent first-attempt requests cannot overwrite each + // other's increments. + if user.Id == "" { + if _, passErr := a.Srv().Store().User().TryIncrementFailedPasswordAttempts(ldapUser.Id, maxAttempts); passErr != nil { + rctx.Logger().Warn("failed to record failed attempt for first-time LDAP user", mlog.String("user_id", ldapUser.Id), mlog.Err(passErr)) + } + } + } else if user.Id != "" { + // Non-credential failure (LDAP unreachable, transient error, + // etc.) on an existing user must not consume the slot we + // pre-claimed, or an LDAP outage could lock out everyone. + if passErr := a.Srv().Store().User().DecrementFailedPasswordAttempts(user.Id); passErr != nil { + rctx.Logger().Warn("failed to refund LDAP login attempt slot", mlog.String("user_id", user.Id), mlog.Err(passErr)) } } @@ -243,24 +275,43 @@ func (a *App) checkLdapUserPasswordAndAllCriteria(rctx request.CTX, user *model. } if err = a.CheckUserMfa(rctx, ldapUser, mfaToken); err != nil { - // If the mfaToken is not set, we assume the client used this as a pre-flight request to query the server - // about the MFA state of the user in question - if mfaToken != "" && ldapUser.Id != "" { - if passErr := a.Srv().Store().User().UpdateFailedPasswordAttempts(ldapUser.Id, ldapUser.FailedAttempts+1); passErr != nil { - return nil, model.NewAppError("CheckPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(passErr) + // For existing LDAP users we pre-claimed a slot, so it already + // counts as a failed attempt. The only special case is when no + // mfaToken was provided, which is treated as a pre-flight + // MFA-state probe rather than a real attempt — refund the slot + // so the probe is not counted. + // + // For first-time LDAP users we did not pre-claim (no row to + // claim against), so a real MFA attempt with a non-empty token + // still needs to be counted explicitly against the freshly + // created row. + switch { + case user.Id == "" && mfaToken != "": + if _, passErr := a.Srv().Store().User().TryIncrementFailedPasswordAttempts(ldapUser.Id, maxAttempts); passErr != nil { + rctx.Logger().Warn("failed to record failed MFA attempt for first-time LDAP user", mlog.String("user_id", ldapUser.Id), mlog.Err(passErr)) + } + case user.Id != "" && mfaToken == "": + if passErr := a.Srv().Store().User().DecrementFailedPasswordAttempts(ldapUser.Id); passErr != nil { + rctx.Logger().Warn("failed to refund LDAP MFA probe slot", mlog.String("user_id", ldapUser.Id), mlog.Err(passErr)) } } return nil, err } if err = checkUserNotDisabled(ldapUser); err != nil { + // Existing LDAP users had a slot pre-claimed; a disabled-account + // rejection is not a credential failure, so refund the slot so a + // reactivated user is not immediately rate-limited. + if user.Id != "" { + if passErr := a.Srv().Store().User().DecrementFailedPasswordAttempts(ldapUser.Id); passErr != nil { + rctx.Logger().Warn("failed to refund disabled LDAP login attempt slot", mlog.String("user_id", ldapUser.Id), mlog.Err(passErr)) + } + } return nil, err } - if ldapUser.FailedAttempts > 0 { - if passErr := a.Srv().Store().User().UpdateFailedPasswordAttempts(ldapUser.Id, 0); passErr != nil { - return nil, model.NewAppError("CheckPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(passErr) - } + if passErr := a.Srv().Store().User().UpdateFailedPasswordAttempts(ldapUser.Id, 0); passErr != nil { + return nil, model.NewAppError("checkLdapUserPasswordAndAllCriteria", "app.user.update_failed_pwd_attempts.app_error", nil, "", http.StatusInternalServerError).Wrap(passErr) } // user successfully authenticated diff --git a/server/channels/app/authentication_test.go b/server/channels/app/authentication_test.go index 4e718b58d80..a581af0e4e9 100644 --- a/server/channels/app/authentication_test.go +++ b/server/channels/app/authentication_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" + "golang.org/x/sync/errgroup" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/v8/channels/app/password/hashers" @@ -96,6 +97,63 @@ func TestCheckPasswordAndAllCriteria(t *testing.T) { appErr = th.App.CheckPasswordAndAllCriteria(th.Context, th.BasicUser.Id, password, token) require.Nil(t, appErr) + + updatedUser, appErr := th.App.GetUser(th.BasicUser.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts, "successful login must reset FailedAttempts") + }) + + t.Run("MFA pre-flight probe does not consume a slot", func(t *testing.T) { + // An empty mfaToken on an MFA-enabled user is a pre-flight probe + // (the client is checking whether MFA is required) and must not + // count as a failed attempt. + err := th.App.Srv().Store().User().UpdateFailedPasswordAttempts(th.BasicUser.Id, 0) + require.NoError(t, err) + + appErr := th.App.CheckPasswordAndAllCriteria(th.Context, th.BasicUser.Id, password, "") + require.NotNil(t, appErr) + require.Equal(t, "mfa.validate_token.authenticate.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(th.BasicUser.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts, "MFA probe must not consume a slot") + }) + + t.Run("MFA real attempt with a wrong token consumes a slot", func(t *testing.T) { + // A non-empty bad mfaToken is a real attempt, not a probe; the + // slot the pre-claim consumed stays consumed. + err := th.App.Srv().Store().User().UpdateFailedPasswordAttempts(th.BasicUser.Id, 0) + require.NoError(t, err) + + appErr := th.App.CheckPasswordAndAllCriteria(th.Context, th.BasicUser.Id, password, "123456") + require.NotNil(t, appErr) + require.Equal(t, "api.user.check_user_mfa.bad_code.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(th.BasicUser.Id) + require.Nil(t, appErr) + require.Equal(t, 1, updatedUser.FailedAttempts, "real MFA failure must consume a slot") + }) + + t.Run("backend error refunds the slot", func(t *testing.T) { + // Backend errors during the password check (malformed stored hash, + // hasher misc failure, migration failure) must not consume a slot + // or a transient infra issue could lock out a user with valid + // credentials. We trigger this via an unparseable PHC string, + // which surfaces as invalid_hash.app_error. + badHashUser := th.CreateUser(t) + err := th.Server.Store().User().UpdatePassword(badHashUser.Id, "$pbkdf2$bogus") + require.NoError(t, err) + th.App.InvalidateCacheForUser(badHashUser.Id) + err = th.App.Srv().Store().User().UpdateFailedPasswordAttempts(badHashUser.Id, 0) + require.NoError(t, err) + + appErr := th.App.CheckPasswordAndAllCriteria(th.Context, badHashUser.Id, "any-password", "") + require.NotNil(t, appErr) + require.Equal(t, "api.user.check_user_password.invalid_hash.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(badHashUser.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts, "backend error must not consume a slot") }) t.Run("validate concurrent failed attempts to bypass checks", func(t *testing.T) { @@ -159,6 +217,66 @@ func TestCheckPasswordAndAllCriteria(t *testing.T) { }) } +func TestDoubleCheckPassword(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + const maxFailedLoginAttempts = 3 + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.MaximumLoginAttempts = maxFailedLoginAttempts + }) + + password := model.NewTestPassword() + appErr := th.App.UpdatePassword(th.Context, th.BasicUser, password) + require.Nil(t, appErr) + + // DoubleCheckPassword does not re-fetch the user; it inspects user.Password + // directly. Pull a fresh struct that reflects the hash we just wrote. + user, appErr := th.App.GetUser(th.BasicUser.Id) + require.Nil(t, appErr) + + t.Run("correct password succeeds and resets the counter", func(t *testing.T) { + err := th.App.Srv().Store().User().UpdateFailedPasswordAttempts(user.Id, maxFailedLoginAttempts-1) + require.NoError(t, err) + + appErr := th.App.DoubleCheckPassword(th.Context, user, password) + require.Nil(t, appErr) + + updatedUser, appErr := th.App.GetUser(user.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts) + }) + + t.Run("rate limit is enforced once max attempts is reached", func(t *testing.T) { + err := th.App.Srv().Store().User().UpdateFailedPasswordAttempts(user.Id, maxFailedLoginAttempts) + require.NoError(t, err) + + appErr := th.App.DoubleCheckPassword(th.Context, user, password) + require.NotNil(t, appErr) + require.Equal(t, "api.user.check_user_login_attempts.too_many.app_error", appErr.Id) + }) + + t.Run("backend error refunds the slot", func(t *testing.T) { + badHashUser := th.CreateUser(t) + err := th.Server.Store().User().UpdatePassword(badHashUser.Id, "$pbkdf2$bogus") + require.NoError(t, err) + th.App.InvalidateCacheForUser(badHashUser.Id) + err = th.App.Srv().Store().User().UpdateFailedPasswordAttempts(badHashUser.Id, 0) + require.NoError(t, err) + + user, appErr := th.App.GetUser(badHashUser.Id) + require.Nil(t, appErr) + + appErr = th.App.DoubleCheckPassword(th.Context, user, "any-password") + require.NotNil(t, appErr) + require.Equal(t, "api.user.check_user_password.invalid_hash.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(badHashUser.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts, "backend error must not consume a slot") + }) +} + func TestCheckLdapUserPasswordAndAllCriteria(t *testing.T) { th := SetupEnterprise(t).InitBasic(t) @@ -256,6 +374,172 @@ func TestCheckLdapUserPasswordAndAllCriteria(t *testing.T) { } }) } + + // The cases below cover paths the table loop above does not exercise: + // first-time LDAP users (user.Id == ""), LDAP backend errors that are + // not credential failures, and the MFA pre-flight probe refund. Each + // subtest builds its own mockLdap so expectations from previous + // subtests cannot match the wrong call. + + createLdapUserWithMFA := func(t *testing.T, emailLocal string) (*model.User, *string) { + t.Helper() + userAuthData := model.NewRandomString(32) + created, appErr := th.App.CreateUser(th.Context, &model.User{ + Email: emailLocal + "@mattermost-customer.com", + Username: emailLocal, + AuthService: model.UserAuthServiceLdap, + AuthData: &userAuthData, + EmailVerified: true, + }) + require.Nil(t, appErr) + secret, appErr := th.App.GenerateMfaSecret(created.Id) + require.Nil(t, appErr) + require.NoError(t, th.Server.Store().User().UpdateMfaActive(created.Id, true)) + require.NoError(t, th.Server.Store().User().UpdateMfaSecret(created.Id, secret.Secret)) + require.NoError(t, th.App.Srv().Store().User().UpdateFailedPasswordAttempts(created.Id, 0)) + created, appErr = th.App.GetUser(created.Id) + require.Nil(t, appErr) + created.AuthData = &userAuthData + return created, &userAuthData + } + + t.Run("first-time LDAP user with wrong password increments counter", func(t *testing.T) { + // DoLogin in production creates the row before reporting a bad + // password; we pre-create it here so GetUserByAuth can resolve it. + firstAuthData := model.NewRandomString(32) + preCreated, appErr := th.App.CreateUser(th.Context, &model.User{ + Email: "ldapuser-first-bad-pwd@mattermost-customer.com", + Username: "ldapuser-first-bad-pwd", + AuthService: model.UserAuthServiceLdap, + AuthData: &firstAuthData, + EmailVerified: true, + }) + require.Nil(t, appErr) + require.NoError(t, th.App.Srv().Store().User().UpdateFailedPasswordAttempts(preCreated.Id, 0)) + + freshMock := &mocks.LdapInterface{} + th.App.Channels().Ldap = freshMock + t.Cleanup(func() { th.App.Channels().Ldap = mockLdap }) + freshMock.Mock.On("DoLogin", th.Context, firstAuthData, wrongPassword).Return(nil, &model.AppError{Id: "ent.ldap.do_login.invalid_password.app_error"}) + + _, appErr = th.App.checkLdapUserPasswordAndAllCriteria(th.Context, &model.User{ + AuthService: model.UserAuthServiceLdap, + AuthData: &firstAuthData, + }, wrongPassword, "") + require.NotNil(t, appErr) + require.Equal(t, "ent.ldap.do_login.invalid_password.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(preCreated.Id) + require.Nil(t, appErr) + require.Equal(t, 1, updatedUser.FailedAttempts, "first-time LDAP wrong password must be counted") + }) + + t.Run("first-time LDAP user with wrong MFA token increments counter", func(t *testing.T) { + // DoLogin returns the freshly created user struct; the function + // then calls CheckUserMfa, which fails on a wrong non-empty token. + preCreated, authDataPtr := createLdapUserWithMFA(t, "ldapuser-first-bad-mfa") + + freshMock := &mocks.LdapInterface{} + th.App.Channels().Ldap = freshMock + t.Cleanup(func() { th.App.Channels().Ldap = mockLdap }) + freshMock.Mock.On("DoLogin", th.Context, *authDataPtr, validPassword).Return(preCreated, nil) + + _, appErr := th.App.checkLdapUserPasswordAndAllCriteria(th.Context, &model.User{ + AuthService: model.UserAuthServiceLdap, + AuthData: authDataPtr, + }, validPassword, "123456") + require.NotNil(t, appErr) + require.Equal(t, "api.user.check_user_mfa.bad_code.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(preCreated.Id) + require.Nil(t, appErr) + require.Equal(t, 1, updatedUser.FailedAttempts, "first-time LDAP wrong MFA must be counted") + }) + + t.Run("existing LDAP user with LDAP backend error refunds the slot", func(t *testing.T) { + // A non-credential LDAP error (server unreachable, transient + // failure) on an existing user must not consume the pre-claimed + // slot, or an LDAP outage could lock out everyone. + require.NoError(t, th.App.Srv().Store().User().UpdateFailedPasswordAttempts(user.Id, 0)) + + freshMock := &mocks.LdapInterface{} + th.App.Channels().Ldap = freshMock + t.Cleanup(func() { th.App.Channels().Ldap = mockLdap }) + freshMock.Mock.On("DoLogin", th.Context, authData, wrongPassword).Return(nil, &model.AppError{Id: "ent.ldap.do_login.unable_to_connect.app_error"}) + + _, appErr := th.App.checkLdapUserPasswordAndAllCriteria(th.Context, user, wrongPassword, "") + require.NotNil(t, appErr) + require.Equal(t, "ent.ldap.do_login.unable_to_connect.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(user.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts, "LDAP backend error must refund the slot") + }) + + t.Run("existing LDAP user MFA pre-flight probe refunds the slot", func(t *testing.T) { + // Empty mfaToken on an MFA-enabled LDAP user is a probe; the slot + // the pre-claim consumed is refunded. + preCreated, authDataPtr := createLdapUserWithMFA(t, "ldapuser-existing-mfa-probe") + + freshMock := &mocks.LdapInterface{} + th.App.Channels().Ldap = freshMock + t.Cleanup(func() { th.App.Channels().Ldap = mockLdap }) + freshMock.Mock.On("DoLogin", th.Context, *authDataPtr, validPassword).Return(preCreated, nil) + + _, appErr := th.App.checkLdapUserPasswordAndAllCriteria(th.Context, preCreated, validPassword, "") + require.NotNil(t, appErr) + require.Equal(t, "mfa.validate_token.authenticate.app_error", appErr.Id) + + updatedUser, appErr := th.App.GetUser(preCreated.Id) + require.Nil(t, appErr) + require.Equal(t, 0, updatedUser.FailedAttempts, "MFA probe on existing LDAP user must not consume a slot") + }) + + t.Run("concurrent first-time LDAP wrong password caps at maxAttempts", func(t *testing.T) { + // A first-time LDAP user has no local row yet, so the slot is + // not pre-claimed. The fallback counter bump must use the atomic + // TryIncrement primitive: a previous implementation used an + // absolute UPDATE Users SET FailedAttempts = ldapUser.FailedAttempts + 1 + // based on an in-memory snapshot, which lost increments when + // concurrent first-attempt requests all read FailedAttempts = 0 + // and all wrote 1. Under the atomic primitive the counter caps + // at maxFailedLoginAttempts regardless of contention. + concurrentAuthData := model.NewRandomString(32) + preCreated, appErr := th.App.CreateUser(th.Context, &model.User{ + Email: "ldapuser-first-bad-pwd-conc@mattermost-customer.com", + Username: "ldapuser-first-bad-pwd-conc", + AuthService: model.UserAuthServiceLdap, + AuthData: &concurrentAuthData, + EmailVerified: true, + }) + require.Nil(t, appErr) + require.NoError(t, th.App.Srv().Store().User().UpdateFailedPasswordAttempts(preCreated.Id, 0)) + + freshMock := &mocks.LdapInterface{} + th.App.Channels().Ldap = freshMock + t.Cleanup(func() { th.App.Channels().Ldap = mockLdap }) + freshMock.Mock.On("DoLogin", th.Context, concurrentAuthData, wrongPassword).Return(nil, &model.AppError{Id: "ent.ldap.do_login.invalid_password.app_error"}) + + const goroutines = maxFailedLoginAttempts * 3 + var g errgroup.Group + start := make(chan struct{}) + for range goroutines { + g.Go(func() error { + <-start + _, _ = th.App.checkLdapUserPasswordAndAllCriteria(th.Context, &model.User{ + AuthService: model.UserAuthServiceLdap, + AuthData: &concurrentAuthData, + }, wrongPassword, "") + return nil + }) + } + close(start) + require.NoError(t, g.Wait()) + + updatedUser, appErr := th.App.GetUser(preCreated.Id) + require.Nil(t, appErr) + require.Equal(t, maxFailedLoginAttempts, updatedUser.FailedAttempts, "concurrent first-time attempts must not lose increments and must cap at maxAttempts") + }) } func TestCheckLdapUserPasswordConcurrency(t *testing.T) { @@ -400,13 +684,7 @@ func TestCheckUserPassword(t *testing.T) { t.Run("valid password with current hashing", func(t *testing.T) { user := createUserWithHash(pwdPBKDF2) - err := th.App.checkUserPassword(user, pwd, false) - require.Nil(t, err) - }) - - t.Run("valid password with current hashing and cache invalidation", func(t *testing.T) { - user := createUserWithHash(pwdPBKDF2) - err := th.App.checkUserPassword(user, pwd, true) + err := th.App.checkUserPassword(user, pwd) require.Nil(t, err) }) @@ -415,13 +693,9 @@ func TestCheckUserPassword(t *testing.T) { t.Run("invalid password", func(t *testing.T) { user := createUserWithHash(pwdPBKDF2) - err := th.App.checkUserPassword(user, wrongPassword, false) + err := th.App.checkUserPassword(user, wrongPassword) require.NotNil(t, err) require.Equal(t, "api.user.check_user_password.invalid.app_error", err.Id) - - updatedUser, err := th.App.GetUser(user.Id) - require.Nil(t, err) - require.Equal(t, user.FailedAttempts+1, updatedUser.FailedAttempts) }) t.Run("password migration from outdated hash", func(t *testing.T) { @@ -429,7 +703,7 @@ func TestCheckUserPassword(t *testing.T) { require.Contains(t, user.Password, "$2a$10") require.NotContains(t, user.Password, "pbkdf2") - err := th.App.checkUserPassword(user, pwd, false) + err := th.App.checkUserPassword(user, pwd) require.Nil(t, err) updatedUser, err := th.App.GetUser(user.Id) @@ -438,20 +712,16 @@ func TestCheckUserPassword(t *testing.T) { require.Contains(t, updatedUser.Password, "$pbkdf2") // Re-check with updated password - err = th.App.checkUserPassword(user, pwd, false) + err = th.App.checkUserPassword(updatedUser, pwd) require.Nil(t, err) }) t.Run("password migration fails with invalid password", func(t *testing.T) { user := createUserWithHash(pwdBcrypt) - err := th.App.checkUserPassword(user, wrongPassword, false) + err := th.App.checkUserPassword(user, wrongPassword) require.NotNil(t, err) require.Equal(t, "api.user.check_user_password.invalid.app_error", err.Id) - - updatedUser, err := th.App.GetUser(user.Id) - require.Nil(t, err) - require.Equal(t, user.FailedAttempts+1, updatedUser.FailedAttempts) }) t.Run("empty password", func(t *testing.T) { @@ -460,7 +730,7 @@ func TestCheckUserPassword(t *testing.T) { user, err := th.App.GetUser(user.Id) require.Nil(t, err) - err = th.App.checkUserPassword(user, "", false) + err = th.App.checkUserPassword(user, "") require.NotNil(t, err) require.Equal(t, "api.user.check_user_password.invalid.app_error", err.Id) }) @@ -471,7 +741,7 @@ func TestCheckUserPassword(t *testing.T) { user, err := th.App.GetUser(user.Id) require.Nil(t, err) - err = th.App.checkUserPassword(user, pwd, false) + err = th.App.checkUserPassword(user, pwd) require.NotNil(t, err) require.Equal(t, "api.user.check_user_password.invalid.app_error", err.Id) }) @@ -489,7 +759,7 @@ func TestCheckUserPassword(t *testing.T) { // The user hash contains the old parameter require.Contains(t, user.Password, "w=10000") - appErr := th.App.checkUserPassword(user, pwd, false) + appErr := th.App.checkUserPassword(user, pwd) require.Nil(t, appErr) updatedUser, appErr := th.App.GetUser(user.Id) @@ -500,7 +770,7 @@ func TestCheckUserPassword(t *testing.T) { require.NotContains(t, updatedUser.Password, "w=10000") // Re-check with updated password - appErr = th.App.checkUserPassword(user, pwd, false) + appErr = th.App.checkUserPassword(updatedUser, pwd) require.Nil(t, appErr) }) } @@ -542,7 +812,7 @@ func TestMigratePassword(t *testing.T) { require.Contains(t, updatedUser.Password, "$pbkdf2") // Re-check with updated password - err = th.App.checkUserPassword(user, pwd, false) + err = th.App.checkUserPassword(updatedUser, pwd) require.Nil(t, err) }) } diff --git a/server/channels/app/channels.go b/server/channels/app/channels.go index df3e90c2e85..eaa21d4ccd3 100644 --- a/server/channels/app/channels.go +++ b/server/channels/app/channels.go @@ -92,11 +92,9 @@ type Channels struct { postReminderMut sync.Mutex postReminderTask *model.ScheduledTask - interruptQuitChan chan struct{} - scheduledPostMut sync.Mutex - scheduledPostTask *model.ScheduledTask - emailLoginAttemptsMut sync.Mutex - ldapLoginAttemptsMut sync.Mutex + interruptQuitChan chan struct{} + scheduledPostMut sync.Mutex + scheduledPostTask *model.ScheduledTask } func NewChannels(s *Server) (*Channels, error) { diff --git a/server/channels/store/localcachelayer/user_layer.go b/server/channels/store/localcachelayer/user_layer.go index f9c6c182a61..4132007705e 100644 --- a/server/channels/store/localcachelayer/user_layer.go +++ b/server/channels/store/localcachelayer/user_layer.go @@ -222,6 +222,25 @@ func (s *LocalCacheUserStore) UpdateFailedPasswordAttempts(userID string, attemp return s.UserStore.UpdateFailedPasswordAttempts(userID, attempts) } +func (s *LocalCacheUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { + claimed, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) + if err != nil { + return false, err + } + if claimed { + s.InvalidateProfileCacheForUser(userID) + } + return claimed, nil +} + +func (s *LocalCacheUserStore) DecrementFailedPasswordAttempts(userID string) error { + if err := s.UserStore.DecrementFailedPasswordAttempts(userID); err != nil { + return err + } + s.InvalidateProfileCacheForUser(userID) + return nil +} + // Get is a cache wrapper around the SqlStore method to get a user profile by id. // It checks if the user entry is present in the cache, returning the entry from cache // if it is present. Otherwise, it fetches the entry from the store and stores it in the diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 938cd91d1c7..18e3ca33dc4 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -17413,6 +17413,48 @@ func (s *RetryLayerUserStore) UpdateFailedPasswordAttempts(userID string, attemp } +func (s *RetryLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { + + tries := 0 + for { + result, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerUserStore) DecrementFailedPasswordAttempts(userID string) error { + + tries := 0 + for { + err := s.UserStore.DecrementFailedPasswordAttempts(userID) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerUserStore) UpdateLastLogin(userID string, lastLogin int64) error { tries := 0 diff --git a/server/channels/store/sqlstore/user_store.go b/server/channels/store/sqlstore/user_store.go index 524bc917373..cfdfb2c4d3a 100644 --- a/server/channels/store/sqlstore/user_store.go +++ b/server/channels/store/sqlstore/user_store.go @@ -426,6 +426,45 @@ func (us SqlUserStore) UpdateFailedPasswordAttempts(userId string, attempts int) return nil } +// TryIncrementFailedPasswordAttempts atomically increments FailedAttempts by one +// for the given user, only if FailedAttempts is strictly less than maxAttempts. +// Returns true if the row was updated (a slot was claimed), false if the cap had +// already been reached (or the user does not exist). The row lock taken by the +// UPDATE serializes concurrent attempts on the same user, so the cap predicate +// is enforced without any application-level locking. +func (us SqlUserStore) TryIncrementFailedPasswordAttempts(userId string, maxAttempts int) (bool, error) { + res, err := us.GetMaster().Exec( + "UPDATE Users SET FailedAttempts = FailedAttempts + 1 WHERE Id = ? AND FailedAttempts < ?", + userId, maxAttempts, + ) + if err != nil { + return false, errors.Wrapf(err, "failed to update User with userId=%s", userId) + } + + rows, err := res.RowsAffected() + if err != nil { + return false, errors.Wrapf(err, "failed to read rows affected for userId=%s", userId) + } + + return rows == 1, nil +} + +// DecrementFailedPasswordAttempts atomically decrements FailedAttempts by one +// for the given user, only if FailedAttempts is strictly greater than zero. It +// is used to refund a slot previously claimed by TryIncrementFailedPasswordAttempts +// when the in-flight authentication turns out not to be a credential-failure +// event (e.g. a backend error or an MFA pre-flight probe). +func (us SqlUserStore) DecrementFailedPasswordAttempts(userId string) error { + _, err := us.GetMaster().Exec( + "UPDATE Users SET FailedAttempts = FailedAttempts - 1 WHERE Id = ? AND FailedAttempts > 0", + userId, + ) + if err != nil { + return errors.Wrapf(err, "failed to update User with userId=%s", userId) + } + return nil +} + func (us SqlUserStore) UpdateAuthData(userId string, service string, authData *string, email string, resetMfa bool) (string, error) { updateAt := model.GetMillis() diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 596af20bbeb..c93eb8a2725 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -482,6 +482,8 @@ type UserStore interface { GetEtagForAllProfiles() string GetEtagForProfiles(teamID string) string UpdateFailedPasswordAttempts(userID string, attempts int) error + TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) + DecrementFailedPasswordAttempts(userID string) error GetSystemAdminProfiles() (map[string]*model.User, error) PermanentDelete(rctx request.CTX, userID string) error AnalyticsActiveCount(timestamp int64, options model.UserCountOptions) (int64, error) diff --git a/server/channels/store/storetest/mocks/UserStore.go b/server/channels/store/storetest/mocks/UserStore.go index 12026e7f4b6..dda8cf0571b 100644 --- a/server/channels/store/storetest/mocks/UserStore.go +++ b/server/channels/store/storetest/mocks/UserStore.go @@ -2156,6 +2156,52 @@ func (_m *UserStore) UpdateFailedPasswordAttempts(userID string, attempts int) e return r0 } +// TryIncrementFailedPasswordAttempts provides a mock function with given fields: userID, maxAttempts +func (_m *UserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { + ret := _m.Called(userID, maxAttempts) + + if len(ret) == 0 { + panic("no return value specified for TryIncrementFailedPasswordAttempts") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string, int) (bool, error)); ok { + return rf(userID, maxAttempts) + } + if rf, ok := ret.Get(0).(func(string, int) bool); ok { + r0 = rf(userID, maxAttempts) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string, int) error); ok { + r1 = rf(userID, maxAttempts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DecrementFailedPasswordAttempts provides a mock function with given fields: userID +func (_m *UserStore) DecrementFailedPasswordAttempts(userID string) error { + ret := _m.Called(userID) + + if len(ret) == 0 { + panic("no return value specified for DecrementFailedPasswordAttempts") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(userID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateLastLogin provides a mock function with given fields: userID, lastLogin func (_m *UserStore) UpdateLastLogin(userID string, lastLogin int64) error { ret := _m.Called(userID, lastLogin) diff --git a/server/channels/store/storetest/user_store.go b/server/channels/store/storetest/user_store.go index da4ef47f1a2..7520267b39f 100644 --- a/server/channels/store/storetest/user_store.go +++ b/server/channels/store/storetest/user_store.go @@ -9,11 +9,13 @@ import ( "fmt" "sort" "strings" + "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/request" @@ -54,6 +56,8 @@ func TestUserStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) { t.Run("Update", func(t *testing.T) { testUserStoreUpdate(t, rctx, ss) }) t.Run("UpdateUpdateAt", func(t *testing.T) { testUserStoreUpdateUpdateAt(t, rctx, ss) }) t.Run("UpdateFailedPasswordAttempts", func(t *testing.T) { testUserStoreUpdateFailedPasswordAttempts(t, rctx, ss) }) + t.Run("TryIncrementFailedPasswordAttempts", func(t *testing.T) { testUserStoreTryIncrementFailedPasswordAttempts(t, rctx, ss) }) + t.Run("DecrementFailedPasswordAttempts", func(t *testing.T) { testUserStoreDecrementFailedPasswordAttempts(t, rctx, ss) }) t.Run("Get", func(t *testing.T) { testUserStoreGet(t, rctx, ss) }) t.Run("GetAllUsingAuthService", func(t *testing.T) { testGetAllUsingAuthService(t, rctx, ss) }) t.Run("GetAllProfiles", func(t *testing.T) { testUserStoreGetAllProfiles(t, rctx, ss) }) @@ -350,6 +354,145 @@ func testUserStoreUpdateFailedPasswordAttempts(t *testing.T, rctx request.CTX, s require.Equal(t, 3, user.FailedAttempts, "FailedAttempts not updated correctly") } +func testUserStoreTryIncrementFailedPasswordAttempts(t *testing.T, rctx request.CTX, ss store.Store) { + u1 := &model.User{} + u1.Email = MakeEmail() + _, err := ss.User().Save(rctx, u1) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u1.Id)) }() + _, nErr := ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: model.NewId(), UserId: u1.Id}, -1) + require.NoError(t, nErr) + + const maxAttempts = 3 + + t.Run("claims a slot when below cap", func(t *testing.T) { + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, 0)) + + claimed, err := ss.User().TryIncrementFailedPasswordAttempts(u1.Id, maxAttempts) + require.NoError(t, err) + require.True(t, claimed) + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, 1, user.FailedAttempts) + }) + + t.Run("does not claim a slot when at cap", func(t *testing.T) { + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, maxAttempts)) + + claimed, err := ss.User().TryIncrementFailedPasswordAttempts(u1.Id, maxAttempts) + require.NoError(t, err) + require.False(t, claimed) + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, maxAttempts, user.FailedAttempts, "counter must not advance past the cap") + }) + + t.Run("does not claim a slot when above cap", func(t *testing.T) { + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, maxAttempts+5)) + + claimed, err := ss.User().TryIncrementFailedPasswordAttempts(u1.Id, maxAttempts) + require.NoError(t, err) + require.False(t, claimed) + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, maxAttempts+5, user.FailedAttempts) + }) + + t.Run("does not claim a slot for unknown user", func(t *testing.T) { + claimed, err := ss.User().TryIncrementFailedPasswordAttempts(model.NewId(), maxAttempts) + require.NoError(t, err) + require.False(t, claimed) + }) + + t.Run("concurrent attempts cap at maxAttempts", func(t *testing.T) { + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, 0)) + + const goroutines = 50 + var g errgroup.Group + var claimed atomic.Int64 + start := make(chan struct{}) + for range goroutines { + g.Go(func() error { + <-start + ok, err := ss.User().TryIncrementFailedPasswordAttempts(u1.Id, maxAttempts) + if err != nil { + return err + } + if ok { + claimed.Add(1) + } + return nil + }) + } + close(start) + require.NoError(t, g.Wait()) + + require.Equal(t, int64(maxAttempts), claimed.Load(), "exactly maxAttempts goroutines must have claimed a slot") + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, maxAttempts, user.FailedAttempts) + }) +} + +func testUserStoreDecrementFailedPasswordAttempts(t *testing.T, rctx request.CTX, ss store.Store) { + u1 := &model.User{} + u1.Email = MakeEmail() + _, err := ss.User().Save(rctx, u1) + require.NoError(t, err) + defer func() { require.NoError(t, ss.User().PermanentDelete(rctx, u1.Id)) }() + _, nErr := ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: model.NewId(), UserId: u1.Id}, -1) + require.NoError(t, nErr) + + t.Run("decrements when above zero", func(t *testing.T) { + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, 3)) + + require.NoError(t, ss.User().DecrementFailedPasswordAttempts(u1.Id)) + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, 2, user.FailedAttempts) + }) + + t.Run("does not go below zero", func(t *testing.T) { + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, 0)) + + require.NoError(t, ss.User().DecrementFailedPasswordAttempts(u1.Id)) + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, 0, user.FailedAttempts) + }) + + t.Run("no-op for unknown user", func(t *testing.T) { + require.NoError(t, ss.User().DecrementFailedPasswordAttempts(model.NewId())) + }) + + t.Run("concurrent decrements never go below zero", func(t *testing.T) { + const initial = 10 + const goroutines = 50 + require.NoError(t, ss.User().UpdateFailedPasswordAttempts(u1.Id, initial)) + + var g errgroup.Group + start := make(chan struct{}) + for range goroutines { + g.Go(func() error { + <-start + return ss.User().DecrementFailedPasswordAttempts(u1.Id) + }) + } + close(start) + require.NoError(t, g.Wait()) + + user, err := ss.User().Get(context.Background(), u1.Id) + require.NoError(t, err) + require.Equal(t, 0, user.FailedAttempts, "decrement must clamp at zero under contention") + }) +} + func testUserStoreGet(t *testing.T, rctx request.CTX, ss store.Store) { u1 := &model.User{ Email: MakeEmail(), diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index b1b4e1ed635..6fa6f706031 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -13773,6 +13773,38 @@ func (s *TimerLayerUserStore) UpdateFailedPasswordAttempts(userID string, attemp return err } +func (s *TimerLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { + start := time.Now() + + result, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.TryIncrementFailedPasswordAttempts", success, elapsed) + } + return result, err +} + +func (s *TimerLayerUserStore) DecrementFailedPasswordAttempts(userID string) error { + start := time.Now() + + err := s.UserStore.DecrementFailedPasswordAttempts(userID) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.DecrementFailedPasswordAttempts", success, elapsed) + } + return err +} + func (s *TimerLayerUserStore) UpdateLastLogin(userID string, lastLogin int64) error { start := time.Now() From 669eb104c60c82e3e3eed2b18d1e7c64877aa71e Mon Sep 17 00:00:00 2001 From: Miguel de la Cruz Date: Mon, 18 May 2026 12:25:05 +0200 Subject: [PATCH 18/80] Fix webhook list ordering instability when paginating (MM-65732) (#36470) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix webhook list ordering instability when paginating (MM-65732) The webhook list view reorders entries when navigating between pages. The first page initially shows webhooks in insertion order (from the server), but after loading additional pages the display settles into alphabetical order. Going back to page 1 then shows different items than were originally visible. Root causes: 1. Server: GetIncomingByTeamByUser, GetIncomingListByUser, GetOutgoingByTeamByUser, and GetOutgoingListByUser had no ORDER BY clause, so the database could return rows in any order. 2. Client (incoming webhooks): incomingWebhookCompare only resolved the channel-name fallback for the 'a' argument, not 'b', making the comparator asymmetric and producing an unstable sort. 3. Client: both installed_incoming_webhooks and installed_outgoing_webhooks called Array.prototype.sort() directly on the props array, mutating it. Fix: - Add ORDER BY DisplayName, Id to the four listing SQL queries so API pages always come back in alphabetical order. With a stable server order, the client sort over merged pages produces the same slice for each page number regardless of how many pages have been loaded. - Symmetrise incomingWebhookCompare by applying the same channel-name and 'Private Webhook' fallback to the 'b' argument. - Sort a copy ([...hooks].sort()) in both webhook list components so the original prop arrays are never mutated. Co-authored-by: Miguel de la Cruz * Fix lint: remove space before JSX closing tag in webhook test Co-authored-by: Miguel de la Cruz * Fold ordering tests into existing webhook store test functions Instead of four separate top-level test registrations (GetIncomingListByUserOrdering, etc.), each ordering assertion is now a t.Run sub-test inside its corresponding existing function: testWebhookStoreGetIncomingListByUser └─ "GetIncomingListByUser, ordered alphabetically by display name" TestWebhookStoreGetIncomingByTeamByUser └─ "GetIncomingByTeamByUser, ordered alphabetically by display name" testWebhookStoreGetOutgoingListByUser └─ "GetOutgoingListByUser, ordered alphabetically by display name" testWebhookStoreGetOutgoingByTeamByUser └─ "GetOutgoingByTeamByUser, ordered alphabetically by display name" Each sub-test creates fresh hooks (Charlie, Alpha, Bravo in insertion order) scoped to its own IDs so they do not interfere with the outer test's fixtures. Co-authored-by: Miguel de la Cruz * Fix govet shadow and gofmt issues in webhook store tests - Rename the outer err variable to errSave in three functions (testWebhookStoreGetIncomingListByUser, TestWebhookStoreGetIncomingByTeamByUser, testWebhookStoreGetOutgoingByTeamByUser) so that hooks, err := declarations in sub-test closures no longer shadow it. - Change hookC, err = to hookC, err := in each ordering sub-test to declare a local err instead of capturing the outer one. - Remove a trailing blank line at the end of the file (gofmt). Co-authored-by: Miguel de la Cruz * Remove jest.mock from installed_incoming_webhooks test The mock for delete_integration_link was copied from the outgoing webhooks list test but is not needed here: the real component renders fine within renderWithContext (as shown by installed_incoming_webhook.test.tsx which tests the individual item without any mocks). Since the ordering tests do not interact with delete functionality, drop the mock and align the action stub style to mockReturnValue(Promise.resolve()). Co-authored-by: Miguel de la Cruz * Add ORDER BY to GetOutgoingByChannelByUser; add ordering sub-test GetOutgoingByChannelByUser was the last paginated webhook listing function without an ORDER BY clause. Add OrderBy("DisplayName", "Id") consistent with all other listing functions. Add the corresponding ordering sub-test inside testWebhookStoreGetOutgoingByChannelByUser, following the same errSave pattern established for the other functions to avoid govet shadow warnings. Co-authored-by: Miguel de la Cruz --------- Co-authored-by: Cursor Agent Co-authored-by: Miguel de la Cruz --- .../channels/store/sqlstore/webhook_store.go | 9 +- .../channels/store/storetest/webhook_store.go | 139 ++++++++++++++-- .../installed_incoming_webhooks.test.tsx | 154 ++++++++++++++++++ .../installed_incoming_webhooks.tsx | 13 +- .../installed_outgoing_webhooks.tsx | 2 +- 5 files changed, 294 insertions(+), 23 deletions(-) create mode 100644 webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.test.tsx diff --git a/server/channels/store/sqlstore/webhook_store.go b/server/channels/store/sqlstore/webhook_store.go index d0fc10b122c..4a7857f878f 100644 --- a/server/channels/store/sqlstore/webhook_store.go +++ b/server/channels/store/sqlstore/webhook_store.go @@ -166,6 +166,7 @@ func (s SqlWebhookStore) GetIncomingListByUser(userId string, offset, limit int) query := s.incomingWebhookSelectQuery. Where(sq.Eq{"DeleteAt": 0}). + OrderBy("DisplayName", "Id"). Limit(uint64(limit)). Offset(uint64(offset)) @@ -188,6 +189,7 @@ func (s SqlWebhookStore) GetIncomingByTeamByUser(teamId string, userId string, o sq.Eq{"TeamId": teamId}, sq.Eq{"DeleteAt": 0}, }). + OrderBy("DisplayName", "Id"). Limit(uint64(limit)). Offset(uint64(offset)) @@ -271,6 +273,7 @@ func (s SqlWebhookStore) GetOutgoingListByUser(userId string, offset, limit int) Where(sq.And{ sq.Eq{"DeleteAt": 0}, }). + OrderBy("DisplayName", "Id"). Limit(uint64(limit)). Offset(uint64(offset)) @@ -296,7 +299,8 @@ func (s SqlWebhookStore) GetOutgoingByChannelByUser(channelId string, userId str Where(sq.And{ sq.Eq{"ChannelId": channelId}, sq.Eq{"DeleteAt": 0}, - }) + }). + OrderBy("DisplayName", "Id") if userId != "" { query = query.Where(sq.Eq{"CreatorId": userId}) @@ -323,7 +327,8 @@ func (s SqlWebhookStore) GetOutgoingByTeamByUser(teamId string, userId string, o Where(sq.And{ sq.Eq{"TeamId": teamId}, sq.Eq{"DeleteAt": 0}, - }) + }). + OrderBy("DisplayName", "Id") if userId != "" { query = query.Where(sq.Eq{"CreatorId": userId}) diff --git a/server/channels/store/storetest/webhook_store.go b/server/channels/store/storetest/webhook_store.go index 6de9ddbb1a8..a0884698bce 100644 --- a/server/channels/store/storetest/webhook_store.go +++ b/server/channels/store/storetest/webhook_store.go @@ -132,8 +132,8 @@ func testWebhookStoreGetIncomingListByUser(t *testing.T, rctx request.CTX, ss st o1.UserId = model.NewId() o1.TeamId = model.NewId() - o1, err := ss.Webhook().SaveIncoming(o1) - require.NoError(t, err) + o1, errSave := ss.Webhook().SaveIncoming(o1) + require.NoError(t, errSave) t.Run("GetIncomingListByUser, known user filtered", func(t *testing.T) { hooks, err := ss.Webhook().GetIncomingListByUser(o1.UserId, 0, 100) @@ -147,6 +147,27 @@ func testWebhookStoreGetIncomingListByUser(t *testing.T, rctx request.CTX, ss st require.NoError(t, err) require.Equal(t, 0, len(hooks)) }) + + t.Run("GetIncomingListByUser, ordered alphabetically by display name", func(t *testing.T) { + userId := model.NewId() + hookC := &model.IncomingWebhook{ChannelId: model.NewId(), UserId: userId, TeamId: model.NewId(), DisplayName: "Charlie"} + hookA := &model.IncomingWebhook{ChannelId: model.NewId(), UserId: userId, TeamId: model.NewId(), DisplayName: "Alpha"} + hookB := &model.IncomingWebhook{ChannelId: model.NewId(), UserId: userId, TeamId: model.NewId(), DisplayName: "Bravo"} + + hookC, err := ss.Webhook().SaveIncoming(hookC) + require.NoError(t, err) + hookA, err = ss.Webhook().SaveIncoming(hookA) + require.NoError(t, err) + hookB, err = ss.Webhook().SaveIncoming(hookB) + require.NoError(t, err) + + hooks, err := ss.Webhook().GetIncomingListByUser(userId, 0, 100) + require.NoError(t, err) + require.Len(t, hooks, 3) + require.Equal(t, hookA.Id, hooks[0].Id, "first result should be Alpha (alphabetical order)") + require.Equal(t, hookB.Id, hooks[1].Id, "second result should be Bravo (alphabetical order)") + require.Equal(t, hookC.Id, hooks[2].Id, "third result should be Charlie (alphabetical order)") + }) } func testWebhookStoreGetIncomingByTeam(t *testing.T, rctx request.CTX, ss store.Store) { @@ -166,16 +187,14 @@ func testWebhookStoreGetIncomingByTeam(t *testing.T, rctx request.CTX, ss store. } func TestWebhookStoreGetIncomingByTeamByUser(t *testing.T, rctx request.CTX, ss store.Store) { - var err error - o1 := buildIncomingWebhook() - o1, err = ss.Webhook().SaveIncoming(o1) - require.NoError(t, err) + o1, errSave := ss.Webhook().SaveIncoming(o1) + require.NoError(t, errSave) o2 := buildIncomingWebhook() o2.TeamId = o1.TeamId //Set both to the same team - o2, err = ss.Webhook().SaveIncoming(o2) - require.NoError(t, err) + o2, errSave = ss.Webhook().SaveIncoming(o2) + require.NoError(t, errSave) t.Run("GetIncomingByTeamByUser, no user filter", func(t *testing.T) { hooks, err := ss.Webhook().GetIncomingByTeam(o1.TeamId, 0, 100) @@ -195,6 +214,28 @@ func TestWebhookStoreGetIncomingByTeamByUser(t *testing.T, rctx request.CTX, ss require.NoError(t, err) require.Equal(t, len(hooks), 0) }) + + t.Run("GetIncomingByTeamByUser, ordered alphabetically by display name", func(t *testing.T) { + teamId := model.NewId() + userId := model.NewId() + hookC := &model.IncomingWebhook{ChannelId: model.NewId(), UserId: userId, TeamId: teamId, DisplayName: "Charlie"} + hookA := &model.IncomingWebhook{ChannelId: model.NewId(), UserId: userId, TeamId: teamId, DisplayName: "Alpha"} + hookB := &model.IncomingWebhook{ChannelId: model.NewId(), UserId: userId, TeamId: teamId, DisplayName: "Bravo"} + + hookC, err := ss.Webhook().SaveIncoming(hookC) + require.NoError(t, err) + hookA, err = ss.Webhook().SaveIncoming(hookA) + require.NoError(t, err) + hookB, err = ss.Webhook().SaveIncoming(hookB) + require.NoError(t, err) + + hooks, err := ss.Webhook().GetIncomingByTeamByUser(teamId, userId, 0, 100) + require.NoError(t, err) + require.Len(t, hooks, 3) + require.Equal(t, hookA.Id, hooks[0].Id, "first result should be Alpha (alphabetical order)") + require.Equal(t, hookB.Id, hooks[1].Id, "second result should be Bravo (alphabetical order)") + require.Equal(t, hookC.Id, hooks[2].Id, "third result should be Charlie (alphabetical order)") + }) } func testWebhookStoreGetIncomingByChannel(t *testing.T, rctx request.CTX, ss store.Store) { @@ -332,6 +373,27 @@ func testWebhookStoreGetOutgoingListByUser(t *testing.T, rctx request.CTX, ss st require.NoError(t, err) require.Equal(t, 0, len(hooks)) }) + + t.Run("GetOutgoingListByUser, ordered alphabetically by display name", func(t *testing.T) { + creatorId := model.NewId() + hookC := &model.OutgoingWebhook{ChannelId: model.NewId(), CreatorId: creatorId, TeamId: model.NewId(), CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Charlie"} + hookA := &model.OutgoingWebhook{ChannelId: model.NewId(), CreatorId: creatorId, TeamId: model.NewId(), CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Alpha"} + hookB := &model.OutgoingWebhook{ChannelId: model.NewId(), CreatorId: creatorId, TeamId: model.NewId(), CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Bravo"} + + hookC, err := ss.Webhook().SaveOutgoing(hookC) + require.NoError(t, err) + hookA, err = ss.Webhook().SaveOutgoing(hookA) + require.NoError(t, err) + hookB, err = ss.Webhook().SaveOutgoing(hookB) + require.NoError(t, err) + + hooks, err := ss.Webhook().GetOutgoingListByUser(creatorId, 0, 100) + require.NoError(t, err) + require.Len(t, hooks, 3) + require.Equal(t, hookA.Id, hooks[0].Id, "first result should be Alpha (alphabetical order)") + require.Equal(t, hookB.Id, hooks[1].Id, "second result should be Bravo (alphabetical order)") + require.Equal(t, hookC.Id, hooks[2].Id, "third result should be Charlie (alphabetical order)") + }) } func testWebhookStoreGetOutgoingList(t *testing.T, rctx request.CTX, ss store.Store) { @@ -400,8 +462,8 @@ func testWebhookStoreGetOutgoingByChannelByUser(t *testing.T, rctx request.CTX, o1.TeamId = model.NewId() o1.CallbackURLs = []string{"http://nowhere.com/"} - o1, err := ss.Webhook().SaveOutgoing(o1) - require.NoError(t, err) + o1, errSave := ss.Webhook().SaveOutgoing(o1) + require.NoError(t, errSave) o2 := &model.OutgoingWebhook{} o2.ChannelId = o1.ChannelId @@ -409,8 +471,8 @@ func testWebhookStoreGetOutgoingByChannelByUser(t *testing.T, rctx request.CTX, o2.TeamId = model.NewId() o2.CallbackURLs = []string{"http://nowhere.com/"} - _, err = ss.Webhook().SaveOutgoing(o2) - require.NoError(t, err) + _, errSave = ss.Webhook().SaveOutgoing(o2) + require.NoError(t, errSave) t.Run("GetOutgoingByChannelByUser, no user filter", func(t *testing.T) { hooks, err := ss.Webhook().GetOutgoingByChannel(o1.ChannelId, 0, 100) @@ -430,6 +492,27 @@ func testWebhookStoreGetOutgoingByChannelByUser(t *testing.T, rctx request.CTX, require.NoError(t, err) require.Equal(t, 0, len(hooks)) }) + + t.Run("GetOutgoingByChannelByUser, ordered alphabetically by display name", func(t *testing.T) { + channelId := model.NewId() + hookC := &model.OutgoingWebhook{ChannelId: channelId, CreatorId: model.NewId(), TeamId: model.NewId(), CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Charlie"} + hookA := &model.OutgoingWebhook{ChannelId: channelId, CreatorId: model.NewId(), TeamId: model.NewId(), CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Alpha"} + hookB := &model.OutgoingWebhook{ChannelId: channelId, CreatorId: model.NewId(), TeamId: model.NewId(), CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Bravo"} + + hookC, err := ss.Webhook().SaveOutgoing(hookC) + require.NoError(t, err) + hookA, err = ss.Webhook().SaveOutgoing(hookA) + require.NoError(t, err) + hookB, err = ss.Webhook().SaveOutgoing(hookB) + require.NoError(t, err) + + hooks, err := ss.Webhook().GetOutgoingByChannel(channelId, 0, 100) + require.NoError(t, err) + require.Len(t, hooks, 3) + require.Equal(t, hookA.Id, hooks[0].Id, "first result should be Alpha (alphabetical order)") + require.Equal(t, hookB.Id, hooks[1].Id, "second result should be Bravo (alphabetical order)") + require.Equal(t, hookC.Id, hooks[2].Id, "third result should be Charlie (alphabetical order)") + }) } func testWebhookStoreGetOutgoingByTeam(t *testing.T, rctx request.CTX, ss store.Store) { @@ -451,16 +534,14 @@ func testWebhookStoreGetOutgoingByTeam(t *testing.T, rctx request.CTX, ss store. } func testWebhookStoreGetOutgoingByTeamByUser(t *testing.T, rctx request.CTX, ss store.Store) { - var err error - o1 := &model.OutgoingWebhook{} o1.ChannelId = model.NewId() o1.CreatorId = model.NewId() o1.TeamId = model.NewId() o1.CallbackURLs = []string{"http://nowhere.com/"} - o1, err = ss.Webhook().SaveOutgoing(o1) - require.NoError(t, err) + o1, errSave := ss.Webhook().SaveOutgoing(o1) + require.NoError(t, errSave) o2 := &model.OutgoingWebhook{} o2.ChannelId = model.NewId() @@ -468,8 +549,8 @@ func testWebhookStoreGetOutgoingByTeamByUser(t *testing.T, rctx request.CTX, ss o2.TeamId = o1.TeamId o2.CallbackURLs = []string{"http://nowhere.com/"} - o2, err = ss.Webhook().SaveOutgoing(o2) - require.NoError(t, err) + o2, errSave = ss.Webhook().SaveOutgoing(o2) + require.NoError(t, errSave) t.Run("GetOutgoingByTeamByUser, no user filter", func(t *testing.T) { hooks, err := ss.Webhook().GetOutgoingByTeam(o1.TeamId, 0, 100) @@ -489,6 +570,28 @@ func testWebhookStoreGetOutgoingByTeamByUser(t *testing.T, rctx request.CTX, ss require.NoError(t, err) require.Equal(t, len(hooks), 0) }) + + t.Run("GetOutgoingByTeamByUser, ordered alphabetically by display name", func(t *testing.T) { + teamId := model.NewId() + creatorId := model.NewId() + hookC := &model.OutgoingWebhook{ChannelId: model.NewId(), CreatorId: creatorId, TeamId: teamId, CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Charlie"} + hookA := &model.OutgoingWebhook{ChannelId: model.NewId(), CreatorId: creatorId, TeamId: teamId, CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Alpha"} + hookB := &model.OutgoingWebhook{ChannelId: model.NewId(), CreatorId: creatorId, TeamId: teamId, CallbackURLs: []string{"http://nowhere.com/"}, DisplayName: "Bravo"} + + hookC, err := ss.Webhook().SaveOutgoing(hookC) + require.NoError(t, err) + hookA, err = ss.Webhook().SaveOutgoing(hookA) + require.NoError(t, err) + hookB, err = ss.Webhook().SaveOutgoing(hookB) + require.NoError(t, err) + + hooks, err := ss.Webhook().GetOutgoingByTeamByUser(teamId, creatorId, 0, 100) + require.NoError(t, err) + require.Len(t, hooks, 3) + require.Equal(t, hookA.Id, hooks[0].Id, "first result should be Alpha (alphabetical order)") + require.Equal(t, hookB.Id, hooks[1].Id, "second result should be Bravo (alphabetical order)") + require.Equal(t, hookC.Id, hooks[2].Id, "third result should be Charlie (alphabetical order)") + }) } func testWebhookStoreDeleteOutgoing(t *testing.T, rctx request.CTX, ss store.Store) { diff --git a/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.test.tsx b/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.test.tsx new file mode 100644 index 00000000000..b0ac74c7d8a --- /dev/null +++ b/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.test.tsx @@ -0,0 +1,154 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import {screen, waitFor} from '@testing-library/react'; +import React from 'react'; + +import type {IncomingWebhook} from '@mattermost/types/integrations'; + +import InstalledIncomingWebhooks from 'components/integrations/installed_incoming_webhooks/installed_incoming_webhooks'; + +import {renderWithContext} from 'tests/react_testing_utils'; +import {TestHelper} from 'utils/test_helper'; + +describe('components/integrations/InstalledIncomingWebhooks', () => { + const team = TestHelper.getTeamMock({id: 'teamId', name: 'test'}); + const user = TestHelper.getUserMock({id: 'userId'}); + const channel = TestHelper.getChannelMock({ + id: 'channelId', + display_name: 'Town Square', + }); + + const hookAlpha: IncomingWebhook = TestHelper.getIncomingWebhookMock({ + id: 'hook-alpha', + display_name: 'Alpha Webhook', + channel_id: 'channelId', + team_id: 'teamId', + user_id: 'userId', + }); + const hookCharlie: IncomingWebhook = TestHelper.getIncomingWebhookMock({ + id: 'hook-charlie', + display_name: 'Charlie Webhook', + channel_id: 'channelId', + team_id: 'teamId', + user_id: 'userId', + }); + const hookBravo: IncomingWebhook = TestHelper.getIncomingWebhookMock({ + id: 'hook-bravo', + display_name: 'Bravo Webhook', + channel_id: 'channelId', + team_id: 'teamId', + user_id: 'userId', + }); + + const initialState = { + entities: { + general: {config: {}}, + users: {currentUserId: 'userId'}, + }, + }; + + const defaultProps = { + team, + user, + incomingHooks: [hookAlpha, hookCharlie, hookBravo], + incomingHooksTotalCount: 3, + channels: {channelId: channel}, + users: {userId: user}, + canManageOthersWebhooks: true, + enableIncomingWebhooks: true, + actions: { + removeIncomingHook: jest.fn(), + loadIncomingHooksAndProfilesForTeam: jest.fn().mockReturnValue(Promise.resolve()), + }, + }; + + test('renders webhooks sorted alphabetically by display name', async () => { + renderWithContext( + , + initialState, + ); + + await waitFor(() => { + expect(screen.getByText('Alpha Webhook')).toBeInTheDocument(); + }); + + const items = screen.getAllByText(/Webhook/); + const names = items.map((el) => el.textContent); + + const alphaIdx = names.findIndex((n) => n?.includes('Alpha')); + const bravoIdx = names.findIndex((n) => n?.includes('Bravo')); + const charlieIdx = names.findIndex((n) => n?.includes('Charlie')); + + expect(alphaIdx).toBeLessThan(bravoIdx); + expect(bravoIdx).toBeLessThan(charlieIdx); + }); + + test('does not mutate the incomingHooks prop array when sorting', async () => { + const hooks: IncomingWebhook[] = [hookAlpha, hookCharlie, hookBravo]; + const originalOrder = hooks.map((h) => h.id); + + const props = {...defaultProps, incomingHooks: hooks}; + + renderWithContext( + , + initialState, + ); + + await waitFor(() => { + expect(screen.getByText('Alpha Webhook')).toBeInTheDocument(); + }); + + // The original array passed as prop must not be mutated by the sort + expect(hooks.map((h) => h.id)).toEqual(originalOrder); + }); + + test('compares hooks with missing display_name symmetrically using channel name fallback', async () => { + const noNameHook: IncomingWebhook = TestHelper.getIncomingWebhookMock({ + id: 'hook-no-name', + display_name: '', + channel_id: 'channelId', + team_id: 'teamId', + user_id: 'userId', + }); + const namedHook: IncomingWebhook = TestHelper.getIncomingWebhookMock({ + id: 'hook-named', + display_name: 'Zeta Webhook', + channel_id: 'channelId', + team_id: 'teamId', + user_id: 'userId', + }); + + // channel display_name is "Town Square" which sorts before "Zeta Webhook" + const props = { + ...defaultProps, + incomingHooks: [namedHook, noNameHook], + incomingHooksTotalCount: 2, + }; + + renderWithContext( + , + initialState, + ); + + await waitFor(() => { + expect(screen.getByText('Zeta Webhook')).toBeInTheDocument(); + }); + + const townSquareEl = screen.getByText('Town Square'); + const zetaEl = screen.getByText('Zeta Webhook'); + + expect(townSquareEl).toBeInTheDocument(); + expect(zetaEl).toBeInTheDocument(); + + // Verify DOM order: Town Square (channel fallback) should appear before Zeta Webhook + const position = townSquareEl.compareDocumentPosition(zetaEl); + expect(position & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + }); +}); diff --git a/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.tsx b/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.tsx index 137c3cb57e5..cd8ea6a61b0 100644 --- a/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.tsx +++ b/webapp/channels/src/components/integrations/installed_incoming_webhooks/installed_incoming_webhooks.tsx @@ -95,11 +95,20 @@ export default class InstalledIncomingWebhooks extends React.PureComponent this.props.incomingHooks. + incomingWebhooks = (filter: string) => [...this.props.incomingHooks]. sort(this.incomingWebhookCompare). filter((incomingWebhook: IncomingWebhook) => matchesFilter(incomingWebhook, this.props.channels[incomingWebhook.channel_id], filter)). map((incomingWebhook: IncomingWebhook) => { diff --git a/webapp/channels/src/components/integrations/installed_outgoing_webhooks/installed_outgoing_webhooks.tsx b/webapp/channels/src/components/integrations/installed_outgoing_webhooks/installed_outgoing_webhooks.tsx index 0da6ca38461..5d5eafd9756 100644 --- a/webapp/channels/src/components/integrations/installed_outgoing_webhooks/installed_outgoing_webhooks.tsx +++ b/webapp/channels/src/components/integrations/installed_outgoing_webhooks/installed_outgoing_webhooks.tsx @@ -136,7 +136,7 @@ export default class InstalledOutgoingWebhooks extends React.PureComponent this.props.outgoingWebhooks. + outgoingWebhooks = (filter: string) => [...this.props.outgoingWebhooks]. sort(this.outgoingWebhookCompare). filter((outgoingWebhook) => matchesFilter(outgoingWebhook, this.props.channels[outgoingWebhook.channel_id], filter)). map((outgoingWebhook) => { From 9d0615554077ca123057e9bab158f94f2e0c52a9 Mon Sep 17 00:00:00 2001 From: Miguel de la Cruz Date: Mon, 18 May 2026 12:27:38 +0200 Subject: [PATCH 19/80] Update bot checks (#36503) * Fix bot permission checks in revokeSession, revokeAllSessionsForUser, and updatePassword MM-68701: Align permission checks with the bot-aware pattern used by updateUser, patchUser, deleteUser, and (via MM-68686) updateUserActive. Three handlers were missing the IsBot branch: - revokeSession / revokeAllSessionsForUser: both gated access through SessionHasPermissionToUser, which only requires EditOtherUsers (an ancillary permission granted to User Managers). Switching to SessionHasPermissionToUserOrBot routes bot targets through SessionHasPermissionToManageBot first and falls back to the user path only when the target is not a bot. - updatePassword: the permission flag canUpdatePassword was set by checking PermissionSysconsoleWriteUserManagementUsers (or PermissionManageSystem for system admins) with no IsBot branch. Adding an else-if user.IsBot guard routes bot targets through SessionHasPermissionToManageBot, consistent with every other handler in the file that touches bot accounts. Co-authored-by: Miguel de la Cruz * Improve TestRevokeSessionBotPermissions: revoke a real bot session Seed a session directly via th.App.CreateSession instead of passing a fake ID and expecting a 400. The test now validates the full happy path: the session row is created, the privileged user revokes it, and the call returns 200 OK. Co-authored-by: Miguel de la Cruz * Strengthen forbidden sub-test: revoke a real bot session with no perms Seed a real session for the bot before the unprivileged revoke call. The test now proves the permission gate blocks access even when the target session ID genuinely exists in the database. Co-authored-by: Miguel de la Cruz * Address review feedback: add post-conditions to bot session revoke tests - TestRevokeSessionBotPermissions: after RevokeSession succeeds, assert GetSessionById returns an error to confirm the row is gone. - TestRevokeAllSessionsForUserBotPermissions: seed a real session before RevokeAllSessions so the call is not a no-op, then assert GetSessions returns an empty list afterwards. Co-authored-by: Miguel de la Cruz --------- Co-authored-by: Cursor Agent Co-authored-by: Miguel de la Cruz Co-authored-by: Mattermost Build --- server/channels/api4/user.go | 6 +- server/channels/api4/user_test.go | 149 ++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/server/channels/api4/user.go b/server/channels/api4/user.go index 7a5aec2426d..2f4d96571eb 100644 --- a/server/channels/api4/user.go +++ b/server/channels/api4/user.go @@ -1955,6 +1955,8 @@ func updatePassword(c *Context, w http.ResponseWriter, r *http.Request) { if user.IsSystemAdmin() { canUpdatePassword = c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionManageSystem) + } else if user.IsBot { + canUpdatePassword = c.App.SessionHasPermissionToManageBot(c.AppContext, *c.AppContext.Session(), c.Params.UserId) == nil } else { canUpdatePassword = c.App.SessionHasPermissionTo(*c.AppContext.Session(), model.PermissionSysconsoleWriteUserManagementUsers) } @@ -2527,7 +2529,7 @@ func revokeSession(c *Context, w http.ResponseWriter, r *http.Request) { auditRec := c.MakeAuditRecord(model.AuditEventRevokeSession, model.AuditStatusFail) defer c.LogAuditRec(auditRec) - if !c.App.SessionHasPermissionToUser(*c.AppContext.Session(), c.Params.UserId) { + if !c.App.SessionHasPermissionToUserOrBot(c.AppContext, *c.AppContext.Session(), c.Params.UserId) { c.SetPermissionError(model.PermissionEditOtherUsers) return } @@ -2575,7 +2577,7 @@ func revokeAllSessionsForUser(c *Context, w http.ResponseWriter, r *http.Request defer c.LogAuditRec(auditRec) model.AddEventParameterToAuditRec(auditRec, "user_id", c.Params.UserId) - if !c.App.SessionHasPermissionToUser(*c.AppContext.Session(), c.Params.UserId) { + if !c.App.SessionHasPermissionToUserOrBot(c.AppContext, *c.AppContext.Session(), c.Params.UserId) { c.SetPermissionError(model.PermissionEditOtherUsers) return } diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index 127aa75a9ed..fddc86cb308 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -4512,6 +4512,61 @@ func TestRevokeSessions(t *testing.T) { require.NoError(t, err) } +func TestRevokeSessionBotPermissions(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableBotAccountCreation = true + }) + + bot, botResp, err := th.SystemAdminClient.CreateBot(context.Background(), &model.Bot{ + Username: GenerateTestUsername(), + DisplayName: "Test Bot", + Description: "bot for revoke-session permission test", + }) + require.NoError(t, err) + CheckCreatedStatus(t, botResp) + defer func() { + appErr := th.App.PermanentDeleteBot(th.Context, bot.UserId) + assert.Nil(t, appErr) + }() + + t.Run("user manager without bot permissions cannot revoke bot session", func(t *testing.T) { + th.AddPermissionToRole(t, model.PermissionSysconsoleWriteUserManagementUsers.Id, model.SystemUserRoleId) + defer th.RemovePermissionFromRole(t, model.PermissionSysconsoleWriteUserManagementUsers.Id, model.SystemUserRoleId) + + // Seed a real session so the test confirms the permission gate blocks + // access even when the target session genuinely exists. + botSession, appErr := th.App.CreateSession(th.Context, &model.Session{UserId: bot.UserId}) + require.Nil(t, appErr) + + th.LoginBasic(t) + + resp, err := th.Client.RevokeSession(context.Background(), bot.UserId, botSession.Id) + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("user with bot management permissions can revoke bot session", func(t *testing.T) { + th.AddPermissionToRole(t, model.PermissionManageOthersBots.Id, model.SystemUserRoleId) + defer th.RemovePermissionFromRole(t, model.PermissionManageOthersBots.Id, model.SystemUserRoleId) + + // Seed a real session for the bot directly via the app layer. + botSession, appErr := th.App.CreateSession(th.Context, &model.Session{UserId: bot.UserId}) + require.Nil(t, appErr) + + th.LoginBasic(t) + + _, err := th.Client.RevokeSession(context.Background(), bot.UserId, botSession.Id) + require.NoError(t, err) + + // Confirm the session row is gone. + _, appErr = th.App.GetSessionById(th.Context, botSession.Id) + require.NotNil(t, appErr, "session should no longer exist after revocation") + }) +} + func TestRevokeAllSessions(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) @@ -7405,6 +7460,49 @@ func TestUpdatePassword(t *testing.T) { }) } +func TestUpdatePasswordBotPermissions(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableBotAccountCreation = true + }) + + bot, botResp, err := th.SystemAdminClient.CreateBot(context.Background(), &model.Bot{ + Username: GenerateTestUsername(), + DisplayName: "Test Bot", + Description: "bot for update-password permission test", + }) + require.NoError(t, err) + CheckCreatedStatus(t, botResp) + defer func() { + appErr := th.App.PermanentDeleteBot(th.Context, bot.UserId) + assert.Nil(t, appErr) + }() + + t.Run("user manager without bot permissions cannot update bot password", func(t *testing.T) { + th.AddPermissionToRole(t, model.PermissionSysconsoleWriteUserManagementUsers.Id, model.SystemUserRoleId) + defer th.RemovePermissionFromRole(t, model.PermissionSysconsoleWriteUserManagementUsers.Id, model.SystemUserRoleId) + + th.LoginBasic(t) + + resp, err := th.Client.UpdatePassword(context.Background(), bot.UserId, "", model.NewTestPassword()) + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("user with bot management permissions can update bot password", func(t *testing.T) { + th.AddPermissionToRole(t, model.PermissionManageOthersBots.Id, model.SystemUserRoleId) + defer th.RemovePermissionFromRole(t, model.PermissionManageOthersBots.Id, model.SystemUserRoleId) + + th.LoginBasic(t) + + resp, err := th.Client.UpdatePassword(context.Background(), bot.UserId, "", model.NewTestPassword()) + require.NoError(t, err) + CheckOKStatus(t, resp) + }) +} + func TestUpdatePasswordAudit(t *testing.T) { logFile, err := os.CreateTemp("", "adv.log") require.NoError(t, err) @@ -9823,6 +9921,57 @@ func TestRevokeAllSessionsForUser(t *testing.T) { }) } +func TestRevokeAllSessionsForUserBotPermissions(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableBotAccountCreation = true + }) + + bot, botResp, err := th.SystemAdminClient.CreateBot(context.Background(), &model.Bot{ + Username: GenerateTestUsername(), + DisplayName: "Test Bot", + Description: "bot for revoke-all-sessions permission test", + }) + require.NoError(t, err) + CheckCreatedStatus(t, botResp) + defer func() { + appErr := th.App.PermanentDeleteBot(th.Context, bot.UserId) + assert.Nil(t, appErr) + }() + + t.Run("user manager without bot permissions cannot revoke all sessions for a bot", func(t *testing.T) { + th.AddPermissionToRole(t, model.PermissionSysconsoleWriteUserManagementUsers.Id, model.SystemUserRoleId) + defer th.RemovePermissionFromRole(t, model.PermissionSysconsoleWriteUserManagementUsers.Id, model.SystemUserRoleId) + + th.LoginBasic(t) + + resp, err := th.Client.RevokeAllSessions(context.Background(), bot.UserId) + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("user with bot management permissions can revoke all sessions for a bot", func(t *testing.T) { + th.AddPermissionToRole(t, model.PermissionManageOthersBots.Id, model.SystemUserRoleId) + defer th.RemovePermissionFromRole(t, model.PermissionManageOthersBots.Id, model.SystemUserRoleId) + + // Seed a real session so RevokeAllSessions is not a no-op. + _, appErr := th.App.CreateSession(th.Context, &model.Session{UserId: bot.UserId}) + require.Nil(t, appErr) + + th.LoginBasic(t) + + _, err := th.Client.RevokeAllSessions(context.Background(), bot.UserId) + require.NoError(t, err) + + // Confirm all sessions for the bot are gone. + sessions, appErr := th.App.GetSessions(th.Context, bot.UserId) + require.Nil(t, appErr) + require.Empty(t, sessions, "all bot sessions should be revoked") + }) +} + func TestResetPasswordFailedAttempts(t *testing.T) { th := SetupEnterprise(t).InitBasic(t) th.SetupLdapConfig() From f067fcde92eb16ef1bec24a8ddc1856fc3ca7f1d Mon Sep 17 00:00:00 2001 From: Maria A Nunez Date: Mon, 18 May 2026 07:22:16 -0400 Subject: [PATCH 20/80] MM-66339 Hide empty content-flagging "With comment" section in reviewer DM (#36552) * Add Cursor Cloud Agent Docker environment Co-authored-by: Cursor * Fix Cloud Agent enterprise and Docker access Co-authored-by: Cursor * Fix Cloud Agent Go path setup Co-authored-by: Cursor * MM-66339 Stop double-JSON-stringifying content flagging comments The flagPost, removeFlaggedPost, and keepFlaggedPost Client4 helpers were calling JSON.stringify on the comment value before placing it in the JSON request body. When the reporter or reviewer left the optional comment blank, JSON.stringify('') returned the literal two-character string '""', which the server then stored as the comment and embedded in the reviewer DM as 'With comment:\n\n> ""'. Send comment as the plain string instead so an empty comment stays empty and the 'With comment' section is omitted entirely. Co-authored-by: Maria A Nunez --------- Co-authored-by: Nick Misasi Co-authored-by: Cursor --- webapp/platform/client/src/client4.test.ts | 93 ++++++++++++++++++++++ webapp/platform/client/src/client4.ts | 6 +- 2 files changed, 96 insertions(+), 3 deletions(-) diff --git a/webapp/platform/client/src/client4.test.ts b/webapp/platform/client/src/client4.test.ts index 01797a1f57a..8aa37520ee5 100644 --- a/webapp/platform/client/src/client4.test.ts +++ b/webapp/platform/client/src/client4.test.ts @@ -111,6 +111,99 @@ describe('Client4', () => { }); }); + describe('content flagging routes', () => { + let client: Client4; + + beforeEach(() => { + client = new Client4(); + client.setUrl('http://mattermost.example.com'); + }); + + test('flagPost should send comment as a plain string', async () => { + let receivedBody: any; + nock(client.getBaseRoute()). + post('/content_flagging/post/post123/flag', (body) => { + receivedBody = body; + return true; + }). + reply(200, {status: 'OK'}); + + await client.flagPost('post123', 'Spam', 'looks suspicious'); + + expect(receivedBody).toEqual({reason: 'Spam', comment: 'looks suspicious'}); + }); + + test('flagPost should preserve an empty comment as an empty string', async () => { + let receivedBody: any; + nock(client.getBaseRoute()). + post('/content_flagging/post/post123/flag', (body) => { + receivedBody = body; + return true; + }). + reply(200, {status: 'OK'}); + + await client.flagPost('post123', 'Spam', ''); + + expect(receivedBody).toEqual({reason: 'Spam', comment: ''}); + }); + + test('removeFlaggedPost should send comment as a plain string', async () => { + let receivedBody: any; + nock(client.getBaseRoute()). + put('/content_flagging/post/post123/remove', (body) => { + receivedBody = body; + return true; + }). + reply(200, {status: 'OK'}); + + await client.removeFlaggedPost('post123', 'violates policy'); + + expect(receivedBody).toEqual({comment: 'violates policy'}); + }); + + test('removeFlaggedPost should preserve an empty comment as an empty string', async () => { + let receivedBody: any; + nock(client.getBaseRoute()). + put('/content_flagging/post/post123/remove', (body) => { + receivedBody = body; + return true; + }). + reply(200, {status: 'OK'}); + + await client.removeFlaggedPost('post123', ''); + + expect(receivedBody).toEqual({comment: ''}); + }); + + test('keepFlaggedPost should send comment as a plain string', async () => { + let receivedBody: any; + nock(client.getBaseRoute()). + put('/content_flagging/post/post123/keep', (body) => { + receivedBody = body; + return true; + }). + reply(200, {status: 'OK'}); + + await client.keepFlaggedPost('post123', 'looks fine'); + + expect(receivedBody).toEqual({comment: 'looks fine'}); + }); + + test('keepFlaggedPost should preserve an empty comment as an empty string', async () => { + let receivedBody: any; + nock(client.getBaseRoute()). + put('/content_flagging/post/post123/keep', (body) => { + receivedBody = body; + return true; + }). + reply(200, {status: 'OK'}); + + await client.keepFlaggedPost('post123', ''); + + expect(receivedBody).toEqual({comment: ''}); + }); + }); + describe('doFetchWithResponse', () => { test('serverVersion should be set from response header', async () => { const client = new Client4(); diff --git a/webapp/platform/client/src/client4.ts b/webapp/platform/client/src/client4.ts index 9d8124e1e29..3ca3afe26e6 100644 --- a/webapp/platform/client/src/client4.ts +++ b/webapp/platform/client/src/client4.ts @@ -5010,7 +5010,7 @@ export default class Client4 { `${this.getContentFlaggingRoute()}/post/${postId}/flag`, { method: 'post', - body: JSON.stringify({reason, comment: JSON.stringify(comment)}), + body: JSON.stringify({reason, comment}), }, ); }; @@ -5020,7 +5020,7 @@ export default class Client4 { `${this.getContentFlaggingRoute()}/post/${postId}/remove`, { method: 'put', - body: JSON.stringify({comment: JSON.stringify(comment)}), + body: JSON.stringify({comment}), }, ); }; @@ -5030,7 +5030,7 @@ export default class Client4 { `${this.getContentFlaggingRoute()}/post/${postId}/keep`, { method: 'put', - body: JSON.stringify({comment: JSON.stringify(comment)}), + body: JSON.stringify({comment}), }, ); }; From bab90098251922b8c7df136b0f3a30c7de7727bc Mon Sep 17 00:00:00 2001 From: Ibrahim Serdar Acikgoz Date: Mon, 18 May 2026 13:47:45 +0200 Subject: [PATCH 21/80] MM-68592: Add leave confirmation modal for policy-added public channels (#36439) * MM-68592: Add leave confirmation modal for policy-added public channels When a user attempts to leave a public channel they were auto-added to via a membership policy (channel.policy_enforced), show a confirmation modal informing them that the leave is permanent and offering a 'Mute instead' option as a lighter alternative. The flow follows the existing pattern used for private channel leave confirmation. The modal is opened from: - Channel header menu Leave action - Sidebar channel menu Leave action - /leave slash command The Mute instead button is hidden when the channel is already muted. Co-authored-by: Ibrahim Serdar Acikgoz * MM-68592: address CodeRabbit review - Make handleMuteInstead async and only close the modal when the mute action resolves successfully, leaving it open on error so the user can retry or choose to leave instead. - Move autoFocus from the destructive 'Leave channel' button to the non-destructive secondary action ('Mute instead' or 'Cancel') so pressing Enter does not default-confirm a permanent leave. - Cover the failure path with a new unit test that asserts the modal remains open when muteChannel returns an error. Co-authored-by: Ibrahim Serdar Acikgoz --------- Co-authored-by: Cursor Agent Co-authored-by: Ibrahim Serdar Acikgoz --- webapp/channels/src/actions/command.ts | 2 +- .../menu_items/leave_channel.test.tsx | 24 ++++ .../menu_items/leave_channel.tsx | 2 +- .../leave_channel_modal.test.tsx.snap | 111 ++++++++++++++++ .../components/leave_channel_modal/index.ts | 24 +++- .../leave_channel_modal.test.tsx | 125 ++++++++++++++++++ .../leave_channel_modal.tsx | 123 ++++++++++++++++- .../sidebar_base_channel.test.tsx | 31 +++++ .../sidebar_base_channel.tsx | 6 +- webapp/channels/src/i18n/en.json | 5 + 10 files changed, 446 insertions(+), 7 deletions(-) diff --git a/webapp/channels/src/actions/command.ts b/webapp/channels/src/actions/command.ts index 41a1b5f746e..3fb49056071 100644 --- a/webapp/channels/src/actions/command.ts +++ b/webapp/channels/src/actions/command.ts @@ -81,7 +81,7 @@ export function executeCommand(message: string, args: CommandArgs): ActionFuncAs if (!channel) { return {data: {silentFailureReason: new Error('cannot find current channel')}}; } - if (channel.type === Constants.PRIVATE_CHANNEL) { + if (channel.type === Constants.PRIVATE_CHANNEL || channel.policy_enforced) { dispatch(openModal({modalId: ModalIdentifiers.LEAVE_PRIVATE_CHANNEL_MODAL, dialogType: LeaveChannelModal, dialogProps: {channel}})); return {data: {frontendHandled: true}}; } diff --git a/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.test.tsx b/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.test.tsx index 7f14137c08e..17b3d9a2a01 100644 --- a/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.test.tsx +++ b/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.test.tsx @@ -66,4 +66,28 @@ describe('components/ChannelHeaderMenu/MenuItems/LeaveChannelTest', () => { }, }); }); + + test('opens leave confirmation modal for a policy enforced public channel', async () => { + const channel = TestHelper.getChannelMock({type: 'O', policy_enforced: true}); + + renderWithContext( + + + , {}, + ); + + const menuItem = screen.getByText('Leave Channel'); + expect(menuItem).toBeInTheDocument(); + + await userEvent.click(menuItem); + expect(channelActions.leaveChannel).not.toHaveBeenCalled(); + expect(modalActions.openModal).toHaveBeenCalledTimes(1); + expect(modalActions.openModal).toHaveBeenCalledWith({ + modalId: ModalIdentifiers.LEAVE_PRIVATE_CHANNEL_MODAL, + dialogType: LeaveChannelModal, + dialogProps: { + channel, + }, + }); + }); }); diff --git a/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.tsx b/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.tsx index 7dc78b001b7..3b1c3010d57 100644 --- a/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.tsx +++ b/webapp/channels/src/components/channel_header_menu/menu_items/leave_channel.tsx @@ -29,7 +29,7 @@ const LeaveChannel = ({ }: Props) => { const dispatch = useDispatch(); const handleLeave = () => { - if (channel.type === Constants.PRIVATE_CHANNEL) { + if (channel.type === Constants.PRIVATE_CHANNEL || channel.policy_enforced) { dispatch( openModal({ modalId: ModalIdentifiers.LEAVE_PRIVATE_CHANNEL_MODAL, diff --git a/webapp/channels/src/components/leave_channel_modal/__snapshots__/leave_channel_modal.test.tsx.snap b/webapp/channels/src/components/leave_channel_modal/__snapshots__/leave_channel_modal.test.tsx.snap index 7880e780fb4..cd3008578c7 100644 --- a/webapp/channels/src/components/leave_channel_modal/__snapshots__/leave_channel_modal.test.tsx.snap +++ b/webapp/channels/src/components/leave_channel_modal/__snapshots__/leave_channel_modal.test.tsx.snap @@ -1,5 +1,116 @@ // Jest Snapshot v1, https://jestjs.io/docs/snapshot-testing +exports[`components/LeaveChannelModal should match snapshot for the policy enforced public channel variant 1`] = ` + + diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/body_main_action_text.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/body_main_action_text.tsx new file mode 100644 index 00000000000..008f24716ee --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/body_main_action_text.tsx @@ -0,0 +1,62 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {useIntl} from 'react-intl'; + +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import AtMention from 'components/at_mention'; +import {useChannel} from 'components/common/hooks/useChannel'; +import {useUser} from 'components/common/hooks/useUser'; + +type Props = { + action: 'keep' | 'remove'; + flaggedPost: Post; + reportingUser: UserProfile; +}; + +export default function BodyMainActionText({ + action, + flaggedPost, + reportingUser, +}: Props) { + const {formatMessage} = useIntl(); + const flaggedPostAuthor = useUser(flaggedPost.user_id); + const flaggedPostChannel = useChannel(flaggedPost.channel_id); + + const values = { + flaggedPostChannel: flaggedPostChannel?.display_name, + reportingUser: ( + + ), + flaggedPostAuthor: ( + + ), + }; + + let body; + + if (action === 'remove') { + body = formatMessage( + { + id: 'keep_remove_quarantined_content_modal.action_remove.body', + defaultMessage: + 'You are about to remove a message authored by {flaggedPostAuthor} posted in the {flaggedPostChannel} channel and quarantined for review by {reportingUser}.', + }, + values, + ); + } else { + body = formatMessage( + { + id: 'keep_remove_quarantined_content_modal.action_keep.body', + defaultMessage: + 'You are about to keep a quarantined message authored by {flaggedPostAuthor} posted in the {flaggedPostChannel} channel and quarantined for review by {reportingUser}.', + }, + values, + ); + } + + return

{body}

; +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_body.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_body.scss new file mode 100644 index 00000000000..ba4e7a3d952 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_body.scss @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +.KeepRemoveFlaggedMessageConfirmationModal .body { + .ErrorStepBody { + display: flex; + flex-direction: column; + gap: 8px; + + .errorRetryBtn { + width: min-content; + } + } +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_body.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_body.tsx new file mode 100644 index 00000000000..75ca9451ca2 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_body.tsx @@ -0,0 +1,74 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {FormattedMessage, useIntl} from 'react-intl'; + +import type {ContentFlaggingConfig} from '@mattermost/types/content_flagging'; +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import FlaggedMessageBody from '../flagged_message_body'; +import ReportNotice from '../report_notice'; + +import './error_step_body.scss'; + +type BodyProps = { + action: 'keep' | 'remove'; + flaggedPost: Post; + reportingUser: UserProfile; + contentFlaggingConfig: ContentFlaggingConfig | undefined; + onRetry: () => void; +}; + +export default function ErrorStepBody({ + action, + flaggedPost, + reportingUser, + contentFlaggingConfig, + onRetry, +}: BodyProps) { + const {formatMessage} = useIntl(); + const tryAgainText = formatMessage({ + id: 'keep_remove_quarantined_content_modal.try_again.button_text', + defaultMessage: 'Try again', + }); + + return ( + <> + + } + title={ + + } + body={ +
+ + +
+ } + /> + + ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_footer.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_footer.tsx new file mode 100644 index 00000000000..49dd6a6dabd --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/error_step/error_step_footer.tsx @@ -0,0 +1,57 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import classNames from 'classnames'; +import React from 'react'; +import {useIntl} from 'react-intl'; + +type FooterProps = { + action: 'keep' | 'remove'; + onSkip: () => void; + onBack: () => void; +}; + +export default function ErrorStepFooter({action, onSkip, onBack}: FooterProps) { + const {formatMessage} = useIntl(); + + const skipText = formatMessage({id: 'keep_remove_quarantined_content_modal.skip_report_download.button_text', defaultMessage: 'Skip report download'}); + const removePermanentlyText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.permanent_button_text', defaultMessage: 'Remove permanently'}); + const keepPermanentlyText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.permanent_button_text', defaultMessage: 'Keep permanently'}); + const backText = formatMessage({id: 'keep_remove_quarantined_content_modal.back.button_text', defaultMessage: 'Back'}); + + const permanentText = action === 'remove' ? removePermanentlyText : keepPermanentlyText; + const permanentClass = action === 'remove' ? 'btn-danger' : 'btn-primary'; + + return ( +
+
+ +
+
+ + +
+
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/flagged_message_body.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/flagged_message_body.scss new file mode 100644 index 00000000000..c231f50bc78 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/flagged_message_body.scss @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +.KeepRemoveFlaggedMessageConfirmationModal .body { + .section.message_body { + p { + margin: 0 0 12px; + + &:last-child { + margin-bottom: 0; + } + } + } +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/flagged_message_body.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/flagged_message_body.tsx new file mode 100644 index 00000000000..32ff397b990 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/flagged_message_body.tsx @@ -0,0 +1,62 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React, {useMemo} from 'react'; +import {useIntl} from 'react-intl'; + +import type {ContentFlaggingConfig} from '@mattermost/types/content_flagging'; +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import BodyMainActionText from 'components/remove_flagged_message_confirmation_modal/body_main_action_text'; + +import './flagged_message_body.scss'; + +type Props = { + action: 'keep' | 'remove'; + flaggedPost: Post; + reportingUser: UserProfile; + contentFlaggingConfig: ContentFlaggingConfig | undefined; +}; + +export default function FlaggedMessageBody({action, flaggedPost, reportingUser, contentFlaggingConfig}: Props) { + const {formatMessage} = useIntl(); + + const subtext = useMemo(() => { + if (action === 'remove') { + if (contentFlaggingConfig?.notify_reporter_on_removal) { + return formatMessage({ + id: 'keep_remove_quarantined_content_modal.action_remove.subtext.notify_reporter', + defaultMessage: 'If you confirm, the message will be removed from the channel and a notification will be sent to the reporter. This action cannot be reverted.', + }); + } + return formatMessage({ + id: 'keep_remove_quarantined_content_modal.action_remove.subtext.no_notify_reporter', + defaultMessage: 'If you confirm, the message will be removed from the channel. This action cannot be reverted.', + }); + } else if (contentFlaggingConfig?.notify_reporter_on_dismissal) { + return formatMessage({ + id: 'keep_remove_quarantined_content_modal.action_keep.subtext.notify_reporter', + defaultMessage: 'If you confirm, the message will be visible to all channel members and a notification will be sent to the reporter.', + }); + } + return formatMessage({ + id: 'keep_remove_quarantined_content_modal.action_keep.subtext.no_notify_reporter', + defaultMessage: 'If you confirm, the message will be visible to all channel members.', + }); + }, [action, contentFlaggingConfig?.notify_reporter_on_dismissal, contentFlaggingConfig?.notify_reporter_on_removal, formatMessage]); + + return ( +
+ +

{subtext}

+
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_body.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_body.scss new file mode 100644 index 00000000000..1fa0ed7ca76 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_body.scss @@ -0,0 +1,30 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +.KeepRemoveFlaggedMessageConfirmationModal .body { + .section { + &.comment_section { + display: flex; + flex-direction: column; + gap: 8px; + } + + .section_title { + color: var(--center-channel-color); + font-size: 14px; + font-weight: 600; + } + } + + button#PreviewInputTextButton { + position: absolute; + z-index: 2; + top: 8px; + right: 8px; + } + + textarea#RemoveFlaggedMessageConfirmationModal__comment { + min-height: 90px !important; + max-height: 400px; + } +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_body.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_body.tsx new file mode 100644 index 00000000000..95c4d0aeffe --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_body.tsx @@ -0,0 +1,83 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {useIntl} from 'react-intl'; + +import type {ContentFlaggingConfig} from '@mattermost/types/content_flagging'; +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import type {TextboxElement} from 'components/textbox'; +import AdvancedTextbox from 'components/widgets/advanced_textbox/advanced_textbox'; + +import FlaggedMessageBody from '../flagged_message_body'; + +import './form_step_body.scss'; + +type BodyProps = { + action: 'keep' | 'remove'; + flaggedPost: Post; + reportingUser: UserProfile; + contentFlaggingConfig: ContentFlaggingConfig | undefined; + comment: string; + commentError: string; + showCommentPreview: boolean; + onCommentChange: (e: React.ChangeEvent) => void; + onToggleCommentPreview: () => void; +}; + +export function FormStepBody({ + action, + flaggedPost, + reportingUser, + contentFlaggingConfig, + comment, + commentError, + showCommentPreview, + onCommentChange, + onToggleCommentPreview, +}: BodyProps) { + const {formatMessage} = useIntl(); + + const requiredTitle = formatMessage({id: 'remove_flag_post_confirm_modal.required_comment.title', defaultMessage: 'Comment (required)'}); + const optionalTitle = formatMessage({id: 'remove_flag_post_confirm_modal.optional_comment.title', defaultMessage: 'Comment (optional)'}); + const sectionTitle = contentFlaggingConfig?.reviewer_comment_required ? requiredTitle : optionalTitle; + + const commentPlaceholder = formatMessage({id: 'keep_remove_quarantined_content_modal.comment.placeholder', defaultMessage: 'Add your comment here'}); + + return ( + <> + + +
+
+ {sectionTitle} +
+ + {}} + hasError={false} + errorMessage={commentError} + maxLength={1000} + /> +
+ + ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_footer.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_footer.scss new file mode 100644 index 00000000000..ae40347f1e3 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_footer.scss @@ -0,0 +1,21 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +.KeepRemoveFlaggedMessageConfirmationModal { + .download_report_checkbox { + display: flex; + align-items: center; + margin: 0; + cursor: pointer; + font-size: 16px; + font-weight: 400; + gap: 10px; + + input[type='checkbox'] { + width: 20px; + height: 20px; + margin: 0; + cursor: pointer; + } + } +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_footer.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_footer.tsx new file mode 100644 index 00000000000..45fce09c172 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/form_step/form_step_footer.tsx @@ -0,0 +1,82 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import classNames from 'classnames'; +import React from 'react'; +import {FormattedMessage, useIntl} from 'react-intl'; + +import './form_step_footer.scss'; + +type FooterProps = { + action: 'keep' | 'remove'; + downloadReport: boolean; + submitting: boolean; + onToggleDownloadReport: (e: React.ChangeEvent) => void; + onCancel: () => void; + onPrimary: () => void; +}; + +export function FormStepFooter({ + action, + downloadReport, + submitting, + onToggleDownloadReport, + onCancel, + onPrimary, +}: FooterProps) { + const {formatMessage} = useIntl(); + + const cancelText = formatMessage({id: 'generic_modal.cancel', defaultMessage: 'Cancel'}); + const continueText = formatMessage({id: 'keep_remove_quarantined_content_modal.continue.button_text', defaultMessage: 'Continue'}); + const removeText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.button_text', defaultMessage: 'Remove message'}); + const keepText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.button_text', defaultMessage: 'Keep message'}); + + const actionText = action === 'remove' ? removeText : keepText; + const uncheckedClass = action === 'remove' ? 'btn-danger' : 'btn-primary'; + + const primaryText = downloadReport ? continueText : actionText; + const primaryClass = downloadReport ? 'btn-primary' : uncheckedClass; + + return ( +
+
+ +
+
+ + +
+
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generated_step/generated_step_body.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generated_step/generated_step_body.tsx new file mode 100644 index 00000000000..9f9daa9a0bb --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generated_step/generated_step_body.tsx @@ -0,0 +1,61 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {FormattedMessage} from 'react-intl'; + +import type {ContentFlaggingConfig} from '@mattermost/types/content_flagging'; +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import FlaggedMessageBody from '../flagged_message_body'; +import ReportNotice from '../report_notice'; + +type BodyProps = { + action: 'keep' | 'remove'; + flaggedPost: Post; + reportingUser: UserProfile; + contentFlaggingConfig: ContentFlaggingConfig | undefined; +}; + +export default function GeneratedStepBody({ + action, + flaggedPost, + reportingUser, + contentFlaggingConfig, +}: BodyProps) { + return ( + <> + + } + title={ + + } + body={ + action === 'remove' ? ( + + ) : ( + + ) + } + /> + + ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generated_step/generated_step_footer.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generated_step/generated_step_footer.tsx new file mode 100644 index 00000000000..379e3221395 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generated_step/generated_step_footer.tsx @@ -0,0 +1,66 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import classNames from 'classnames'; +import React from 'react'; +import {useIntl} from 'react-intl'; + +type FooterProps = { + action: 'keep' | 'remove'; + submitting: boolean; + onDownloadAgain: () => void; + onBack: () => void; + onPermanent: () => void; +}; + +export default function GeneratedStepFooter({ + action, + submitting, + onDownloadAgain, + onBack, + onPermanent, +}: FooterProps) { + const {formatMessage} = useIntl(); + + const downloadAgainText = formatMessage({id: 'keep_remove_quarantined_content_modal.download_again.button_text', defaultMessage: 'Download again'}); + const backText = formatMessage({id: 'keep_remove_quarantined_content_modal.back.button_text', defaultMessage: 'Back'}); + const removePermanentlyText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.permanent_button_text', defaultMessage: 'Remove permanently'}); + const keepPermanentlyText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.permanent_button_text', defaultMessage: 'Keep permanently'}); + + const permanentText = action === 'remove' ? removePermanentlyText : keepPermanentlyText; + const permanentClass = action === 'remove' ? 'btn-danger' : 'btn-primary'; + + return ( +
+
+ +
+
+ + +
+
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generating_step/generating_step_body.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generating_step/generating_step_body.tsx new file mode 100644 index 00000000000..5fef111efbd --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generating_step/generating_step_body.tsx @@ -0,0 +1,63 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {FormattedMessage} from 'react-intl'; + +import type {ContentFlaggingConfig} from '@mattermost/types/content_flagging'; +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import LoadingSpinner from 'components/widgets/loading/loading_spinner'; + +import FlaggedMessageBody from '../flagged_message_body'; +import ReportNotice from '../report_notice'; + +type BodyProps = { + action: 'keep' | 'remove'; + flaggedPost: Post; + reportingUser: UserProfile; + contentFlaggingConfig: ContentFlaggingConfig | undefined; +}; + +export function GeneratingStepBody({ + action, + flaggedPost, + reportingUser, + contentFlaggingConfig, +}: BodyProps) { + return ( + <> + + } + title={ + + } + body={ + action === 'remove' ? ( + + ) : ( + + ) + } + /> + + ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generating_step/generating_step_footer.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generating_step/generating_step_footer.tsx new file mode 100644 index 00000000000..e3583df9503 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/generating_step/generating_step_footer.tsx @@ -0,0 +1,58 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import classNames from 'classnames'; +import React from 'react'; +import {useIntl} from 'react-intl'; + +type FooterProps = { + action: 'keep' | 'remove'; + onSkip: () => void; + onBack: () => void; +}; + +export function GeneratingStepFooter({action, onSkip, onBack}: FooterProps) { + const {formatMessage} = useIntl(); + + const skipText = formatMessage({id: 'keep_remove_quarantined_content_modal.skip_report_download.button_text', defaultMessage: 'Skip report download'}); + const backText = formatMessage({id: 'keep_remove_quarantined_content_modal.back.button_text', defaultMessage: 'Back'}); + const removePermanentlyText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.permanent_button_text', defaultMessage: 'Remove permanently'}); + const keepPermanentlyText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.permanent_button_text', defaultMessage: 'Keep permanently'}); + + const permanentText = + action === 'remove' ? removePermanentlyText : keepPermanentlyText; + const permanentClass = action === 'remove' ? 'btn-danger' : 'btn-primary'; + + return ( +
+
+ +
+
+ + +
+
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.scss index 789c131481c..e48dbe5cd64 100644 --- a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.scss +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.scss @@ -2,41 +2,22 @@ // See LICENSE.txt for license information. .KeepRemoveFlaggedMessageConfirmationModal { + width: 704px; + .modal-content { width: 704px; } + .modal-footer { + // Allow the custom footer row to span the full width and use space-between layout. + justify-content: stretch; + } + .body { display: flex; flex-direction: column; gap: 24px; - .section{ - &.comment_section { - display: flex; - flex-direction: column; - gap: 8px; - } - - .section_title { - color: var(--center-channel-color); - font-size: 14px; - font-weight: 600; - } - } - - button#PreviewInputTextButton { - position: absolute; - z-index: 2; - top: 8px; - right: 8px; - } - - textarea#RemoveFlaggedMessageConfirmationModal__comment { - min-height: 90px !important; - max-height: 400px; - } - .request_error { display: flex; width: 90%; @@ -45,11 +26,24 @@ font-size: 12px; } } + + .ModalFooterRow { + display: flex; + width: 100%; + align-items: center; + justify-content: space-between; + gap: 8px; + + &__left, + &__right { + display: flex; + align-items: center; + gap: 8px; + } + + .skipReportBtn, + .skipReportBtn:hover { + color: var(--error-text); + } + } } - - - - - - - diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.test.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.test.tsx index 7d927db3a82..485c3a2407a 100644 --- a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.test.tsx +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.test.tsx @@ -52,6 +52,17 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { const onExited = jest.fn(); + let originalCreateObjectURL: typeof URL.createObjectURL; + let originalRevokeObjectURL: typeof URL.revokeObjectURL; + + const mockReportSuccess = () => { + Client4.generateFlaggedPostReport = jest.fn().mockResolvedValue(new Blob(['report'], {type: 'application/zip'})); + }; + + const mockReportFailure = () => { + Client4.generateFlaggedPostReport = jest.fn().mockRejectedValue(new Error('boom')); + }; + beforeEach(() => { jest.clearAllMocks(); @@ -63,10 +74,22 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { Client4.removeFlaggedPost = jest.fn().mockResolvedValue({}); Client4.keepFlaggedPost = jest.fn().mockResolvedValue({}); + mockReportSuccess(); + originalCreateObjectURL = URL.createObjectURL; + originalRevokeObjectURL = URL.revokeObjectURL; + URL.createObjectURL = jest.fn().mockReturnValue('blob:mock-url'); + URL.revokeObjectURL = jest.fn(); + + // eslint-disable-next-line no-console console.error = jest.fn(); }); + afterEach(() => { + URL.createObjectURL = originalCreateObjectURL; + URL.revokeObjectURL = originalRevokeObjectURL; + }); + describe('remove action', () => { test('should render modal with remove action content', () => { renderWithContext( @@ -80,7 +103,9 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { expect(screen.getByTestId('keep-remove-flagged-message-confirmation-modal')).toBeVisible(); expect(screen.getByRole('heading', {name: 'Remove message from channel'})).toBeVisible(); - expect(screen.getByRole('button', {name: 'Remove message'})).toBeVisible(); + + // Default form step shows the "Continue" primary button (download checkbox is on by default) + expect(screen.getByRole('button', {name: 'Continue'})).toBeVisible(); }); test('should show notification subtext when notify_reporter_on_removal is true', () => { @@ -118,7 +143,7 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { expect(subtext).toHaveTextContent(/the message will be removed from the channel. This action cannot be reverted./); }); - test('should call Client4.removeFlaggedPost on confirm', async () => { + test('should call Client4.removeFlaggedPost via download flow on Remove permanently', async () => { renderWithContext( { />, ); - const confirmButton = screen.getByRole('button', {name: 'Remove message'}); - await userEvent.click(confirmButton); + // Form step with checkbox checked → click Continue triggers report fetch + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + + await waitFor(() => { + expect(Client4.generateFlaggedPostReport).toHaveBeenCalledWith( + flaggedPost.id, + '', + 'remove', + expect.any(AbortSignal), + ); + }); + + await waitFor(() => { + expect(screen.getByTestId('generated-section')).toBeVisible(); + }); + + await userEvent.click(screen.getByRole('button', {name: 'Remove permanently'})); await waitFor(() => { expect(Client4.removeFlaggedPost).toHaveBeenCalledWith(flaggedPost.id, ''); }); expect(onExited).toHaveBeenCalled(); }); + + test('should go through skip-confirm when checkbox is unchecked', async () => { + renderWithContext( + , + ); + + await userEvent.click(screen.getByTestId('download-report-checkbox')); + + // Button label changes to "Remove message" when checkbox unchecked + await userEvent.click(screen.getByRole('button', {name: 'Remove message'})); + + await waitFor(() => { + expect(screen.getByTestId('skip-confirm-body')).toBeVisible(); + }); + + await userEvent.click(screen.getByRole('button', {name: 'Remove without report'})); + + await waitFor(() => { + expect(Client4.removeFlaggedPost).toHaveBeenCalledWith(flaggedPost.id, ''); + }); + expect(Client4.generateFlaggedPostReport).not.toHaveBeenCalled(); + }); + + test('should show error step when report generation fails and allow retry', async () => { + mockReportFailure(); + + renderWithContext( + , + ); + + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + + await waitFor(() => { + expect(screen.getByTestId('error-section')).toBeVisible(); + }); + + // Switch to success and retry + mockReportSuccess(); + await userEvent.click(screen.getByTestId('error-retry-button')); + + await waitFor(() => { + expect(screen.getByTestId('generated-section')).toBeVisible(); + }); + }); }); describe('keep action', () => { @@ -150,7 +244,7 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { ); expect(screen.getByTestId('keep-remove-flagged-message-confirmation-modal')).toBeVisible(); - expect(screen.getByRole('button', {name: 'Keep message'})).toBeVisible(); + expect(screen.getByRole('button', {name: 'Continue'})).toBeVisible(); }); test('should show notification subtext when notify_reporter_on_dismissal is true', () => { @@ -188,7 +282,7 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { expect(subtext).toHaveTextContent(/the message will be visible to all channel members./); }); - test('should call Client4.keepFlaggedPost on confirm', async () => { + test('should call Client4.keepFlaggedPost via download flow on Keep permanently', async () => { renderWithContext( { />, ); - const confirmButton = screen.getByRole('button', {name: 'Keep message'}); - await userEvent.click(confirmButton); + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + + await waitFor(() => { + expect(screen.getByTestId('generated-section')).toBeVisible(); + }); + + await userEvent.click(screen.getByRole('button', {name: 'Keep permanently'})); await waitFor(() => { expect(Client4.keepFlaggedPost).toHaveBeenCalledWith(flaggedPost.id, ''); }); expect(onExited).toHaveBeenCalled(); }); + + test('should call Client4.keepFlaggedPost directly without skip-confirm when checkbox is unchecked', async () => { + renderWithContext( + , + ); + + await userEvent.click(screen.getByTestId('download-report-checkbox')); + await userEvent.click(screen.getByRole('button', {name: 'Keep message'})); + + await waitFor(() => { + expect(Client4.keepFlaggedPost).toHaveBeenCalledWith(flaggedPost.id, ''); + }); + expect(screen.queryByTestId('skip-confirm-body')).not.toBeInTheDocument(); + expect(Client4.generateFlaggedPostReport).not.toHaveBeenCalled(); + }); + + test('skip from generating step calls keepFlaggedPost directly without skip-confirm', async () => { + Client4.generateFlaggedPostReport = jest.fn().mockReturnValue(new Promise(() => {})); + + renderWithContext( + , + ); + + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + + await waitFor(() => { + expect(screen.getByTestId('generating-section')).toBeVisible(); + }); + + await userEvent.click(screen.getByTestId('generating-skip-button')); + + await waitFor(() => { + expect(Client4.keepFlaggedPost).toHaveBeenCalledWith(flaggedPost.id, ''); + }); + expect(screen.queryByTestId('skip-confirm-body')).not.toBeInTheDocument(); + }); }); describe('comment section', () => { @@ -259,17 +404,119 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { />, ); - const confirmButton = screen.getByRole('button', {name: 'Remove message'}); - await userEvent.click(confirmButton); + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); await waitFor(() => { expect(screen.getByText('Please add a comment.')).toBeVisible(); }); + expect(Client4.generateFlaggedPostReport).not.toHaveBeenCalled(); expect(Client4.removeFlaggedPost).not.toHaveBeenCalled(); expect(onExited).not.toHaveBeenCalled(); }); }); + describe('step transitions', () => { + test('should pass typed comment to action API', async () => { + renderWithContext( + , + ); + + await userEvent.type(screen.getByPlaceholderText('Add your comment here'), 'looks fine'); + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + + await waitFor(() => { + expect(screen.getByTestId('generated-section')).toBeVisible(); + }); + await userEvent.click(screen.getByRole('button', {name: 'Remove permanently'})); + + await waitFor(() => { + expect(Client4.removeFlaggedPost).toHaveBeenCalledWith(flaggedPost.id, 'looks fine'); + }); + }); + + test('clicking "Download again" on generated step retriggers report fetch', async () => { + renderWithContext( + , + ); + + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + await waitFor(() => { + expect(screen.getByTestId('generated-section')).toBeVisible(); + }); + const initialCallCount = (Client4.generateFlaggedPostReport as jest.Mock).mock.calls.length; + + await userEvent.click(screen.getByTestId('generated-download-again-button')); + + await waitFor(() => { + expect((Client4.generateFlaggedPostReport as jest.Mock).mock.calls.length).toBeGreaterThan(initialCallCount); + }); + await waitFor(() => { + expect(screen.getByTestId('generated-section')).toBeVisible(); + }); + }); + + test('skip from generating step routes to skip-confirm', async () => { + // Hold the request open so we can interact with the generating footer + Client4.generateFlaggedPostReport = jest.fn().mockReturnValue(new Promise(() => {})); + + renderWithContext( + , + ); + + await userEvent.click(screen.getByRole('button', {name: 'Continue'})); + + await waitFor(() => { + expect(screen.getByTestId('generating-section')).toBeVisible(); + }); + + await userEvent.click(screen.getByTestId('generating-skip-button')); + + await waitFor(() => { + expect(screen.getByTestId('skip-confirm-body')).toBeVisible(); + }); + expect(screen.queryByTestId('generated-section')).not.toBeInTheDocument(); + }); + + test('back from skip-confirm returns to form step', async () => { + renderWithContext( + , + ); + + await userEvent.click(screen.getByTestId('download-report-checkbox')); + await userEvent.click(screen.getByRole('button', {name: 'Remove message'})); + + await waitFor(() => { + expect(screen.getByTestId('skip-confirm-body')).toBeVisible(); + }); + + await userEvent.click(screen.getByTestId('skip-confirm-back-button')); + + await waitFor(() => { + expect(screen.getByRole('button', {name: 'Remove message'})).toBeVisible(); + }); + }); + }); + describe('error handling', () => { test('should show request error when API call fails', async () => { const errorMessage = 'Failed to remove flagged post'; @@ -284,8 +531,15 @@ describe('KeepRemoveFlaggedMessageConfirmationModal', () => { />, ); - const confirmButton = screen.getByRole('button', {name: 'Remove message'}); - await userEvent.click(confirmButton); + // Skip download path so we go directly to API call + await userEvent.click(screen.getByTestId('download-report-checkbox')); + await userEvent.click(screen.getByRole('button', {name: 'Remove message'})); + + await waitFor(() => { + expect(screen.getByTestId('skip-confirm-body')).toBeVisible(); + }); + + await userEvent.click(screen.getByRole('button', {name: 'Remove without report'})); await waitFor(() => { const errorElement = screen.getByTestId( diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.tsx index c633a83c529..e746f2fb1aa 100644 --- a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.tsx +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/remove_flagged_message_confirmation_modal.tsx @@ -1,7 +1,7 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import React, {useCallback} from 'react'; +import React, {useCallback, useEffect, useRef, useState} from 'react'; import {useIntl} from 'react-intl'; import {GenericModal} from '@mattermost/components'; @@ -11,17 +11,23 @@ import type {UserProfile} from '@mattermost/types/users'; import {Client4} from 'mattermost-redux/client'; -import AtMention from 'components/at_mention'; import {useChannel} from 'components/common/hooks/useChannel'; import {useContentFlaggingConfig} from 'components/common/hooks/useContentFlaggingFields'; -import {useUser} from 'components/common/hooks/useUser'; import type {TextboxElement} from 'components/textbox'; -import AdvancedTextbox from 'components/widgets/advanced_textbox/advanced_textbox'; + +import ErrorStepBody from './error_step/error_step_body'; +import ErrorStepFooter from './error_step/error_step_footer'; +import {FormStepBody} from './form_step/form_step_body'; +import {FormStepFooter} from './form_step/form_step_footer'; +import GeneratedStepBody from './generated_step/generated_step_body'; +import GeneratedStepFooter from './generated_step/generated_step_footer'; +import {GeneratingStepBody} from './generating_step/generating_step_body'; +import {GeneratingStepFooter} from './generating_step/generating_step_footer'; +import {SkipConfirmStepBody} from './skip_confirm_step/skip_confirm_step_body'; +import {SkipConfirmStepFooter} from './skip_confirm_step/skip_confirm_step_footer'; import './remove_flagged_message_confirmation_modal.scss'; -const noop = () => {}; - type Props = { action: 'keep' | 'remove'; onExited: () => void; @@ -29,18 +35,28 @@ type Props = { reportingUser: UserProfile; } +type Step = 'form' | 'skip_confirm' | 'generating' | 'generated' | 'error'; + export default function KeepRemoveFlaggedMessageConfirmationModal({action, onExited, flaggedPost, reportingUser}: Props) { const {formatMessage} = useIntl(); - const flaggedPostAuthor = useUser(flaggedPost.user_id); const flaggedPostChannel = useChannel(flaggedPost.channel_id); const contentFlaggingConfig = useContentFlaggingConfig(flaggedPostChannel?.team_id || ''); - const [comment, setComment] = React.useState(''); - const [commentError, setCommentError] = React.useState(''); - const [requestError, setRequestError] = React.useState(''); - const [submitting, setSubmitting] = React.useState(false); - const [showCommentPreview, setShowCommentPreview] = React.useState(false); + const [comment, setComment] = useState(''); + const [commentError, setCommentError] = useState(''); + const [requestError, setRequestError] = useState(''); + const [submitting, setSubmitting] = useState(false); + const [showCommentPreview, setShowCommentPreview] = useState(false); + const [downloadReport, setDownloadReport] = useState(true); + const [step, setStep] = useState('form'); + + const abortControllerRef = useRef(null); + + const handleClose = useCallback(() => { + abortControllerRef.current?.abort(); + onExited(); + }, [onExited]); const handleCommentChange = useCallback((e: React.ChangeEvent) => { setComment(e.target.value); @@ -56,104 +72,26 @@ export default function KeepRemoveFlaggedMessageConfirmationModal({action, onExi setShowCommentPreview((prev) => !prev); }, []); - const removeActionLabel = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.title', defaultMessage: 'Remove message from channel'}); - const keepActionLabel = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.title', defaultMessage: 'Keep message'}); - - const removeActionBody = formatMessage({ - id: 'keep_remove_quarantined_content_modal.action_remove.body', - defaultMessage: 'You are about to remove a message authored by {flaggedPostAuthor} posted in the {flaggedPostChannel} channel and quarantined for review by {reportingUser}.', - }, { - flaggedPostChannel: flaggedPostChannel?.display_name, - reportingUser: , - flaggedPostAuthor: , - }); - const keepActionBody = formatMessage({ - id: 'keep_remove_quarantined_content_modal.action_keep.body', - defaultMessage: 'You are about to keep a quarantined message authored by {flaggedPostAuthor} posted in the {flaggedPostChannel} channel and quarantined for review by {reportingUser}.', - }, { - flaggedPostChannel: flaggedPostChannel?.display_name, - reportingUser: , - flaggedPostAuthor: , - }); - - const removeActionBodySubTextReporterNotification = formatMessage({ - id: 'keep_remove_quarantined_content_modal.action_remove.subtext.notify_reporter', - defaultMessage: 'If you confirm, the message will be removed from the channel and a notification will be sent to the reporter. This action cannot be reverted.', - }); - const removeActionBodySubTextNoReporterNotification = formatMessage({ - id: 'keep_remove_quarantined_content_modal.action_remove.subtext.no_notify_reporter', - defaultMessage: 'If you confirm, the message will be removed from the channel. This action cannot be reverted.', - }); - - const keepActionBodySubTextReporterNotification = formatMessage({ - id: 'keep_remove_quarantined_content_modal.action_keep.subtext.notify_reporter', - defaultMessage: 'If you confirm, the message will be visible to all channel members and a notification will be sent to the reporter.', - }); - const keepActionBodySubTextNoReporterNotification = formatMessage({ - id: 'keep_remove_quarantined_content_modal.action_keep.subtext.no_notify_reporter', - defaultMessage: 'If you confirm, the message will be visible to all channel members.', - }); - - const requiredCommentSectionTitle = formatMessage({id: 'remove_flag_post_confirm_modal.required_comment.title', defaultMessage: 'Comment (required)'}); - const optionalCommentSectionTitle = formatMessage({id: 'remove_flag_post_confirm_modal.optional_comment.title', defaultMessage: 'Comment (optional)'}); - - const commentPlaceholder = formatMessage({id: 'keep_remove_quarantined_content_modal.comment.placeholder', defaultMessage: 'Add your comment here'}); - const removeMessageButtonText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.button_text', defaultMessage: 'Remove message'}); - const keepMessageButtonText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.button_text', defaultMessage: 'Keep message'}); - - let label; - let subtext; - let body; - let buttonText; - let confirmButtonVariant; - - if (action === 'remove') { - label = removeActionLabel; - body = removeActionBody; - buttonText = removeMessageButtonText; - confirmButtonVariant = 'destructive' as const; - - if (contentFlaggingConfig?.notify_reporter_on_removal) { - subtext = removeActionBodySubTextReporterNotification; - } else { - subtext = removeActionBodySubTextNoReporterNotification; - } - } else { - label = keepActionLabel; - body = keepActionBody; - buttonText = keepMessageButtonText; - - if (contentFlaggingConfig?.notify_reporter_on_dismissal) { - subtext = keepActionBodySubTextReporterNotification; - } else { - subtext = keepActionBodySubTextNoReporterNotification; - } - } + const handleToggleDownloadReport = useCallback((e: React.ChangeEvent) => { + setDownloadReport(e.target.checked); + }, []); const validateForm = useCallback(() => { - let hasErrors = false; - if (contentFlaggingConfig?.reviewer_comment_required && comment.trim() === '') { setCommentError(formatMessage({id: 'keep_remove_quarantined_content_modal.comment_required.error', defaultMessage: 'Please add a comment.'})); - hasErrors = true; - } else { - setCommentError(''); + return true; } - - return hasErrors; + setCommentError(''); + return false; }, [comment, contentFlaggingConfig?.reviewer_comment_required, formatMessage]); - const handleConfirm = useCallback(async () => { - const hasError = validateForm(); - if (hasError) { - return; - } - + const callActionAPI = useCallback(async () => { const actionFunc = action === 'remove' ? Client4.removeFlaggedPost : Client4.keepFlaggedPost; try { setSubmitting(true); + setRequestError(''); await actionFunc(flaggedPost.id, comment); - onExited(); + handleClose(); } catch (error) { // eslint-disable-next-line no-console console.error(error); @@ -161,7 +99,181 @@ export default function KeepRemoveFlaggedMessageConfirmationModal({action, onExi } finally { setSubmitting(false); } - }, [action, comment, flaggedPost.id, onExited, validateForm]); + }, [action, comment, flaggedPost.id, handleClose]); + + const handleFormPrimary = useCallback(() => { + if (validateForm()) { + return; + } + setRequestError(''); + if (downloadReport) { + setStep('generating'); + } else if (action === 'keep') { + callActionAPI(); + } else { + setStep('skip_confirm'); + } + }, [validateForm, downloadReport, action, callActionAPI]); + + const handleSkipConfirmBack = useCallback(() => { + setRequestError(''); + setStep('form'); + }, []); + + const handleSkipFromGenerating = useCallback(() => { + abortControllerRef.current?.abort(); + setRequestError(''); + if (action === 'keep') { + callActionAPI(); + } else { + setStep('skip_confirm'); + } + }, [action, callActionAPI]); + + const handleBackToForm = useCallback(() => { + abortControllerRef.current?.abort(); + setRequestError(''); + setStep('form'); + }, []); + + const handleRetryGeneration = useCallback(() => { + setRequestError(''); + setStep('generating'); + }, []); + + useEffect(() => { + if (step !== 'generating') { + return undefined; + } + + const controller = new AbortController(); + abortControllerRef.current = controller; + + (async () => { + try { + const blob = await Client4.generateFlaggedPostReport(flaggedPost.id, comment, action, controller.signal); + if (controller.signal.aborted) { + return; + } + + const downloadUrl = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = downloadUrl; + a.download = `flagged-post-${flaggedPost.id}-${Date.now()}.zip`; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(downloadUrl); + + setStep('generated'); + } catch (err) { + if (controller.signal.aborted) { + return; + } + + // eslint-disable-next-line no-console + console.error(err); + setStep('error'); + } + })(); + + return () => { + controller.abort(); + }; + }, [step, flaggedPost.id, comment, action]); + + const removeLabel = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove.title', defaultMessage: 'Remove message from channel'}); + const keepLabel = formatMessage({id: 'keep_remove_quarantined_content_modal.action_keep.title', defaultMessage: 'Keep message'}); + + const removeWithoutReportLabel = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove_without_report.title', defaultMessage: 'Remove without report?'}); + + const bodyContentProps = { + action, + flaggedPost, + reportingUser, + contentFlaggingConfig, + }; + + let label = action === 'remove' ? removeLabel : keepLabel; + let modalBody: React.ReactNode = null; + let footer: React.ReactNode = null; + + switch (step) { + case 'form': + modalBody = ( + + ); + footer = ( + + ); + break; + case 'skip_confirm': + label = removeWithoutReportLabel; + modalBody = ( + ); + footer = ( + + ); + break; + case 'generating': + modalBody = ; + footer = ( + + ); + break; + case 'generated': + modalBody = ; + footer = ( + + ); + break; + case 'error': + modalBody = ( + + ); + footer = ( + + ); + break; + } return (
-
- {body} -
-
- {subtext} -
- -
-
- {contentFlaggingConfig?.reviewer_comment_required ? requiredCommentSectionTitle : optionalCommentSectionTitle} -
- - {}} - hasError={false} - errorMessage={commentError} - maxLength={1000} - /> -
- {requestError && + {modalBody} + {requestError && (
{requestError}
- } + )}
); diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/report_notice.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/report_notice.scss new file mode 100644 index 00000000000..e888ae03d50 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/report_notice.scss @@ -0,0 +1,76 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +.KeepRemoveFlaggedMessageConfirmationModal { + .ReportNotice { + display: flex; + align-items: flex-start; + padding: 16px; + border: 1px solid; + border-radius: 4px; + gap: 12px; + + &__icon { + display: flex; + width: 20px; + height: 20px; + flex-shrink: 0; + align-items: center; + justify-content: center; + + .icon { + font-size: 20px; + line-height: 1; + } + } + + &__body { + display: flex; + min-width: 0; + flex: 1 1 0; + flex-direction: column; + gap: 8px; + } + + &__title { + color: var(--center-channel-color); + font-size: 14px; + font-weight: 600; + line-height: 20px; + } + + &__text { + color: var(--center-channel-color); + font-size: 14px; + font-weight: 400; + line-height: 20px; + } + + &--info { + border-color: rgba(var(--sidebar-text-active-border-rgb), 0.16); + background: rgba(var(--button-bg-rgb), 0.04); + + .ReportNotice__icon { + color: var(--button-bg); + } + } + + &--success { + border-color: rgba(var(--online-indicator-rgb), 0.16); + background: rgba(var(--online-indicator-rgb), 0.08); + + .ReportNotice__icon { + color: var(--online-indicator); + } + } + + &--warning { + border-color: rgba(var(--away-indicator-rgb), 0.16); + background: rgba(var(--away-indicator-rgb), 0.08); + + .ReportNotice__icon { + color: var(--away-indicator); + } + } + } +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/report_notice.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/report_notice.tsx new file mode 100644 index 00000000000..97617dfe6e7 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/report_notice.tsx @@ -0,0 +1,30 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import classNames from 'classnames'; +import React from 'react'; + +import './report_notice.scss'; + +type Props = { + variant: 'info' | 'success' | 'warning'; + icon: React.ReactNode; + title: React.ReactNode; + body: React.ReactNode; + testId?: string; +}; + +export default function ReportNotice({variant, icon, title, body, testId}: Props) { + return ( +
+
{icon}
+
+
{title}
+
{body}
+
+
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_body.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_body.tsx new file mode 100644 index 00000000000..0684b0c3612 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_body.tsx @@ -0,0 +1,42 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {useIntl} from 'react-intl'; + +import type {Post} from '@mattermost/types/posts'; +import type {UserProfile} from '@mattermost/types/users'; + +import BodyMainActionText from 'components/remove_flagged_message_confirmation_modal/body_main_action_text'; + +type BodyProps = { + flaggedPost: Post; + reportingUser: UserProfile; +}; + +export function SkipConfirmStepBody({ + flaggedPost, + reportingUser, +}: BodyProps) { + const {formatMessage} = useIntl(); + + const text = formatMessage({ + id: 'keep_remove_quarantined_content_modal.action_remove.skip_confirm.body', + defaultMessage: + 'You are proceeding with content removal without downloading a report. Any subsequently generated report will not contain the original message contents. This action cannot be reverted.', + }); + + return ( +
+ + {text} +
+ ); +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_footer.scss b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_footer.scss new file mode 100644 index 00000000000..c6c10103996 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_footer.scss @@ -0,0 +1,8 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +.KeepRemoveFlaggedMessageConfirmationModal { + .ModalFooterRow.ModalFooterRow--end { + justify-content: flex-end; + } +} diff --git a/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_footer.tsx b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_footer.tsx new file mode 100644 index 00000000000..651ec3e5971 --- /dev/null +++ b/webapp/channels/src/components/remove_flagged_message_confirmation_modal/skip_confirm_step/skip_confirm_step_footer.tsx @@ -0,0 +1,44 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {useIntl} from 'react-intl'; + +import './skip_confirm_step_footer.scss'; + +type FooterProps = { + submitting: boolean; + onBack: () => void; + onPrimary: () => void; +}; + +export function SkipConfirmStepFooter({submitting, onBack, onPrimary}: FooterProps) { + const {formatMessage} = useIntl(); + + const backText = formatMessage({id: 'keep_remove_quarantined_content_modal.back.button_text', defaultMessage: 'Back'}); + const primaryText = formatMessage({id: 'keep_remove_quarantined_content_modal.action_remove_without_report.button_text', defaultMessage: 'Remove without report'}); + + return ( +
+
+ + +
+
+ ); +} diff --git a/webapp/channels/src/i18n/en.json b/webapp/channels/src/i18n/en.json index 3b418757e21..65f3402ac17 100644 --- a/webapp/channels/src/i18n/en.json +++ b/webapp/channels/src/i18n/en.json @@ -4458,8 +4458,13 @@ "custom_status.suggestions.working_from_home": "Working from home", "data_spillage_report_post.reporting_comment.placeholder": "No comment", "data_spillage_report_post.title": "{user} submitted a message for review", + "data_spillage_report.download_report.button_text": "Download Report", + "data_spillage_report.download_report.failed.button_text": "Generation failed. Try again.", + "data_spillage_report.download_report.generating.button_text": "Generating report…", "data_spillage_report.keep_message.button_text": "Keep message", "data_spillage_report.remove_message.button_text": "Remove message", + "data_spillage_report.row.actions.label": "Actions", + "data_spillage_report.row.report.label": "Report", "data_spillage_report.view_details.button_text": "View details", "date_separator.today": "Today", "date_separator.tomorrow": "Tomorrow", @@ -5267,16 +5272,35 @@ "katex.error": "Couldn't compile your Latex code. Please review the syntax and try again.", "keep_remove_quarantined_content_modal.action_keep.body": "You are about to keep a quarantined message authored by {flaggedPostAuthor} posted in the {flaggedPostChannel} channel and quarantined for review by {reportingUser}.", "keep_remove_quarantined_content_modal.action_keep.button_text": "Keep message", + "keep_remove_quarantined_content_modal.action_keep.generated.body": "The report should now be downloading on your device. Once it is downloaded, you can keep the message permanently.", + "keep_remove_quarantined_content_modal.action_keep.generating.body": "Please wait for the report to download before you keep the message permanently.", + "keep_remove_quarantined_content_modal.action_keep.permanent_button_text": "Keep permanently", "keep_remove_quarantined_content_modal.action_keep.subtext.no_notify_reporter": "If you confirm, the message will be visible to all channel members.", "keep_remove_quarantined_content_modal.action_keep.subtext.notify_reporter": "If you confirm, the message will be visible to all channel members and a notification will be sent to the reporter.", "keep_remove_quarantined_content_modal.action_keep.title": "Keep message", + "keep_remove_quarantined_content_modal.action_remove_without_report.button_text": "Remove without report", + "keep_remove_quarantined_content_modal.action_remove_without_report.title": "Remove without report?", "keep_remove_quarantined_content_modal.action_remove.body": "You are about to remove a message authored by {flaggedPostAuthor} posted in the {flaggedPostChannel} channel and quarantined for review by {reportingUser}.", "keep_remove_quarantined_content_modal.action_remove.button_text": "Remove message", + "keep_remove_quarantined_content_modal.action_remove.generated.body": "The report should now be downloading on your device. Once it is downloaded, you can remove the message permanently.", + "keep_remove_quarantined_content_modal.action_remove.generating.body": "Please wait for the report to download before you remove the message permanently. There will be no way to recover the message contents once it is removed.", + "keep_remove_quarantined_content_modal.action_remove.permanent_button_text": "Remove permanently", + "keep_remove_quarantined_content_modal.action_remove.skip_confirm.body": "You are proceeding with content removal without downloading a report. Any subsequently generated report will not contain the original message contents. This action cannot be reverted.", "keep_remove_quarantined_content_modal.action_remove.subtext.no_notify_reporter": "If you confirm, the message will be removed from the channel. This action cannot be reverted.", "keep_remove_quarantined_content_modal.action_remove.subtext.notify_reporter": "If you confirm, the message will be removed from the channel and a notification will be sent to the reporter. This action cannot be reverted.", "keep_remove_quarantined_content_modal.action_remove.title": "Remove message from channel", + "keep_remove_quarantined_content_modal.back.button_text": "Back", "keep_remove_quarantined_content_modal.comment_required.error": "Please add a comment.", "keep_remove_quarantined_content_modal.comment.placeholder": "Add your comment here", + "keep_remove_quarantined_content_modal.continue.button_text": "Continue", + "keep_remove_quarantined_content_modal.download_again.button_text": "Download again", + "keep_remove_quarantined_content_modal.download_report_checkbox.label": "Download quarantined message report", + "keep_remove_quarantined_content_modal.error.body": "We were unable to generate and download the report to your device.", + "keep_remove_quarantined_content_modal.error.title": "Report could not be generated", + "keep_remove_quarantined_content_modal.generated.title": "Report generated", + "keep_remove_quarantined_content_modal.generating.title": "Generating report…", + "keep_remove_quarantined_content_modal.skip_report_download.button_text": "Skip report download", + "keep_remove_quarantined_content_modal.try_again.button_text": "Try again", "last_users_message.added_to_channel.type": "were **added to the channel** by {actor}.", "last_users_message.added_to_team.type": "were **added to the team** by {actor}.", "last_users_message.first": "{firstUser} and ", @@ -5903,7 +5927,6 @@ "promote_to_user_modal.desc": "This action promotes the guest {username} to a member. It will allow the user to join public channels and interact with users outside of the channels they are currently members of. Are you sure you want to promote guest {username} to member?", "promote_to_user_modal.promote": "Promote", "promote_to_user_modal.title": "Promote guest {username} to member", - "property_card.actions_row.label": "Actions", "property_card.field.action_time.label": "Reviewed at", "property_card.field.actor_comment.label": "Reviewer's comment", "property_card.field.actor_user_id.label": "Reviewed by", diff --git a/webapp/platform/client/src/client4.test.ts b/webapp/platform/client/src/client4.test.ts index 8aa37520ee5..5f6a2174537 100644 --- a/webapp/platform/client/src/client4.test.ts +++ b/webapp/platform/client/src/client4.test.ts @@ -255,6 +255,25 @@ describe('Client4', () => { expect(result[1]).toEqual({user_id: 'dummy-user-id', channel_id: 'channel2', roles: 'channel_user channel_admin'}); expect(result[2]).toEqual({user_id: 'dummy-user-id', channel_id: 'channel3', roles: 'channel_user'}); }); + + test('should parse ZIP responses as blobs', async () => { + const client = new Client4(); + client.setUrl('http://mattermost.example.com'); + + const postId = 'dummy-post-id'; + const zipData = Buffer.from('zip contents'); + + nock(client.getBaseRoute()). + post(`/content_flagging/post/${postId}/report`, {comment: 'investigation note'}). + reply(200, zipData, {'Content-Type': 'application/zip'}); + + const result = await client.generateFlaggedPostReport(postId, 'investigation note'); + + expect(typeof result.text).toBe('function'); + expect(result.size).toEqual(zipData.length); + expect(result.type).toEqual('application/zip'); + expect(await result.text()).toEqual('zip contents'); + }); }); }); diff --git a/webapp/platform/client/src/client4.ts b/webapp/platform/client/src/client4.ts index 3ca3afe26e6..1c7d16a424a 100644 --- a/webapp/platform/client/src/client4.ts +++ b/webapp/platform/client/src/client4.ts @@ -4631,6 +4631,8 @@ export default class Client4 { const text = await response.text(); const objects = text.trim().split('\n'); data = objects.map((obj) => JSON.parse(obj)); + } else if (contentType === 'application/zip') { + data = await response.blob(); } else { data = await response.text(); } @@ -5083,6 +5085,21 @@ export default class Client4 { {method: 'get'}, ); }; + + getFlaggedPostReportUrl = (postId: string) => { + return `${this.getContentFlaggingRoute()}/post/${postId}/report`; + }; + + generateFlaggedPostReport = (postId: string, comment: string, action?: 'keep' | 'remove', signal?: AbortSignal): Promise => { + return this.doFetch( + this.getFlaggedPostReportUrl(postId), + { + method: 'post', + body: JSON.stringify({comment, action}), + signal, + }, + ); + }; } export function parseAndMergeNestedHeaders(originalHeaders: any) { From d4471bece166ea55aa605e750cfa1ac4a9580eb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20V=C3=A9lez?= Date: Mon, 18 May 2026 17:33:13 +0200 Subject: [PATCH 24/80] Mm 68503 be abac mask save path masking (#36513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MM-68501 - implement GetMaskedVisualAST and wire API handler Co-Authored-By: Claude Opus 4.6 (1M context) * add missing test and fix style issues * fix styles * implement coderabbit feedback * MM-68501 - PR review: split masking file, model-level access mode, reject contradictory config Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68501 - apply shared_only filter to non-option field values (binary masking) * MM-68501 - consolidate masking flag check and log corrupt text value during masking * MM-68503 - add CEL utilities, write-path validation, and merge helpers Combined set of helpers consumed by BE-5's save path: CEL construction / serialization - extractStringValues, buildCELFromConditions, conditionToCEL, celStringLiteral, celValueLiteral. Used to rebuild a CEL string from a VisualExpression, including for GetMaskedExpression on the read-side of policy GET / search responses. Merge-on-save helpers - getHiddenValues (per-condition, with pre-fetched fields map for N+1 avoidance) — finds which stored values are not visible to the caller. - mergeConditionValues — re-injects the hidden values into a submitted condition without duplicates. - Together, these let BE-5 preserve attribute values the caller cannot see while still letting them edit the visible parts of a policy. Write-path value-hold validation - validatePolicyExpressionValues, invalidValueError, validateConditionValues. - Generic "Invalid value." error on every rejection — no signal about whether the value exists or is merely not held (prevents enumeration). - Rejects the masked-token sentinel "--------" if submitted as a literal. These all live in access_control_masking.go alongside the masking primitives that BE-2 introduced. i18n entries added for the two new error IDs (app.pap.save_policy.invalid_value, app.pap.validate_expression_values.app_error). Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68503 - handle the masked-token sentinel in validation and merge When the GET /policies endpoint returns a policy via MaskPolicyExpressions, the raw expression contains the masked-token sentinel "--------" in place of hidden values. If the frontend round-trips that expression unchanged back to the server (e.g., the admin only modified channel assignment, not the rules), the sentinel reaches the save path. The previous code in validateConditionValues rejected the sentinel as "Invalid value." This blocks the legitimate round-trip case. Fix: - validateConditionValues: treat the sentinel as a placeholder and skip it during visibility / source-only / unknown-mode checks. Other values are still validated normally. - mergeConditionValues: strip the sentinel from submitted values before appending hidden values, so it never propagates to the stored result. Both array and single-value forms (string == "--------") are handled. TestMaskedTokenRejection (which asserted the old rejection behavior) is replaced by TestMaskedTokenConstant which only verifies the sentinel string itself. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68504 - integrate save-path masking: 403 block on delete, merge-on-save, response masking Save path (CreateOrUpdateAccessControlPolicy): * validatePolicyExpressionValues runs on the submitted expression before merge so re-injected hidden values are never validated against the caller's holdings. * mergeStoredPolicyExpressions re-injects hidden values from the stored policy and blocks (HTTP 403) any attempt to remove a condition that contained values the caller cannot see — closes the row-deletion gap in classified environments. * mergeExpressionWithMaskedValues unwraps single-element arrays for scalar operators after restoring the stored operator (avoids "attr == [val]" invalid CEL when the frontend submits "attr in []" as the masked-row placeholder for an originally-scalar condition). * checkSelfInclusion is bypassed for system admins (they may legitimately write conditions for values they do not hold); masking and value-hold validation still apply to system admins. Delete path (DeleteAccessControlPolicy): * Same masked-values 403 block — a caller with masked values cannot delete the policy at all (UI Delete button is also disabled in FE-3). Response masking: * createAccessControlPolicy and setAccessControlPolicyActiveStatus run MaskPolicyExpressions on the response so even a save reply doesn't leak the values the caller does not hold. GetMaskedExpression, maskConditionValuesWithToken, replaceHiddenValuesWithToken, MaskPolicyExpressions live alongside the rest of the masking helpers in access_control_masking.go. team_access_control.go: corrects ValidateChannelEligibilityForAccessControl call site (drops the spurious receiver and rctx; it's a package-level helper that only takes channel). Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68503 - address PR review: batch field fetches, propagate errors, fail-closed write path * MM-68503 - restore team-admin api4 tests accidentally dropped during BE-5 rebuild * MM-68503 - address review and CodeRabbit feedback on save-path masking * add tests for delete masking, self-inclusion, GET mask * add assertions to strengten tests * fail-closed guard for advanced expressions in merge-on-save, plus helper unit tests, and FF/test-helper cleanups * Refactor access control methods to use GetPropertyGroup for CPA group ID retrieval --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Mattermost Build --- server/channels/api4/access_control.go | 22 +- server/channels/api4/access_control_test.go | 144 ++-- server/channels/app/access_control.go | 317 ++++++++- server/channels/app/access_control_masking.go | 629 +++++++++++++++++- .../channels/app/access_control_merge_test.go | 474 +++++++++++++ server/channels/app/access_control_test.go | 150 +++++ .../app/access_control_validation_test.go | 231 +++++++ server/i18n/en.json | 32 + 8 files changed, 1944 insertions(+), 55 deletions(-) create mode 100644 server/channels/app/access_control_merge_test.go create mode 100644 server/channels/app/access_control_validation_test.go diff --git a/server/channels/api4/access_control.go b/server/channels/api4/access_control.go index c8bb59038b2..af0c87b23ac 100644 --- a/server/channels/api4/access_control.go +++ b/server/channels/api4/access_control.go @@ -15,7 +15,7 @@ import ( ) // shouldRedactExpressions reports whether raw CEL expressions should be masked for this caller. -// Returns true when both ABAC and attribute-value masking are enabled. Callers reading raw expressions +// Masking is attribute-based, not permission-based: system admins who do not hold all values // in a policy must also receive redacted raw expressions. func shouldRedactExpressions(c *Context) bool { return c.App.Config().FeatureFlags.AttributeBasedAccessControl && @@ -141,6 +141,10 @@ func createAccessControlPolicy(c *Context, w http.ResponseWriter, r *http.Reques auditRec.AddEventObjectType("access_control_policy") auditRec.AddEventResultState(np) + if shouldRedactExpressions(c) { + c.App.MaskPolicyExpressions(c.AppContext, np, c.AppContext.Session().UserId) + } + js, err := json.Marshal(np) if err != nil { c.Err = model.NewAppError("createAccessControlPolicy", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err) @@ -193,6 +197,10 @@ func getAccessControlPolicy(c *Context, w http.ResponseWriter, r *http.Request) return } + if shouldRedactExpressions(c) { + c.App.MaskPolicyExpressions(c.AppContext, policy, c.AppContext.Session().UserId) + } + js, err := json.Marshal(policy) if err != nil { c.Err = model.NewAppError("getAccessControlPolicy", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err) @@ -496,6 +504,12 @@ func searchAccessControlPolicies(c *Context, w http.ResponseWriter, r *http.Requ policies = filtered } + if shouldRedactExpressions(c) { + for _, p := range policies { + c.App.MaskPolicyExpressions(c.AppContext, p, c.AppContext.Session().UserId) + } + } + result := model.AccessControlPoliciesWithCount{ Policies: policies, Total: total, @@ -628,6 +642,12 @@ func setActiveStatus(c *Context, w http.ResponseWriter, r *http.Request) { } auditRec.Success() + if shouldRedactExpressions(c) { + for _, p := range policies { + c.App.MaskPolicyExpressions(c.AppContext, p, c.AppContext.Session().UserId) + } + } + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(policies); err != nil { c.Logger.Warn("Error while writing response", mlog.Err(err)) diff --git a/server/channels/api4/access_control_test.go b/server/channels/api4/access_control_test.go index 8b09d036ce2..0a4d5251666 100644 --- a/server/channels/api4/access_control_test.go +++ b/server/channels/api4/access_control_test.go @@ -474,15 +474,22 @@ func TestDeleteAccessControlPolicy(t *testing.T) { mockAccessControlService := &mocks.AccessControlServiceInterface{} th.App.Srv().Channels().AccessControl = mockAccessControlService - // DeleteAccessControlPolicy resolves the policy first to decide - // whether to broadcast a channel access control update; return a - // parent policy so the channel-update path is not exercised here. - parentPolicy := &model.AccessControlPolicy{ - ID: samplePolicyID, - Type: model.AccessControlPolicyTypeParent, - Version: model.AccessControlPolicyVersionV0_3, + + // DeleteAccessControlPolicy resolves the policy first to decide whether + // to broadcast a channel access-control update after deletion. + channelPolicy := &model.AccessControlPolicy{ + ID: samplePolicyID, + Type: model.AccessControlPolicyTypeChannel, + Version: model.AccessControlPolicyVersionV0_3, + Revision: 1, + Rules: []model.AccessControlPolicyRule{ + { + Expression: "user.attributes.team == 'engineering'", + Actions: []string{"membership"}, + }, + }, } - mockAccessControlService.On("GetPolicy", mock.AnythingOfType("*request.Context"), samplePolicyID).Return(parentPolicy, nil).Times(1) + mockAccessControlService.On("GetPolicy", mock.AnythingOfType("*request.Context"), samplePolicyID).Return(channelPolicy, nil).Times(1) mockAccessControlService.On("DeletePolicy", mock.AnythingOfType("*request.Context"), samplePolicyID).Return(nil).Times(1) th.App.UpdateConfig(func(cfg *model.Config) { @@ -1180,38 +1187,17 @@ func TestSearchChannelsForAccessControlPolicy(t *testing.T) { t.Run("team admin body TeamIds forced to authorized team", func(t *testing.T) { setupLicenseAndABAC(t) - parentPolicy := newSamplePolicy() - savedParent, err := th.App.Srv().Store().AccessControlPolicy().Save(th.Context, parentPolicy) + policy := newSamplePolicy() + savedPolicy, err := th.App.Srv().Store().AccessControlPolicy().Save(th.Context, policy) require.NoError(t, err) defer func() { - _ = th.App.Srv().Store().AccessControlPolicy().Delete(th.Context, savedParent.ID) - }() - - // Two teams, each with one private channel. The BasicTeam channel is - // linked to the parent policy so it shows up in the search; the - // otherTeam channel is unrelated. The override-correctness test then - // proves both that the BasicTeam channel IS returned (the search - // isn't trivially empty) and that the otherTeam channel is NOT - // returned even though the request body asked for it explicitly. - basicTeamChannel := th.CreateChannelWithClientAndTeam(t, th.SystemAdminClient, model.ChannelTypePrivate, th.BasicTeam.Id) - basicTeamChild := &model.AccessControlPolicy{ - ID: basicTeamChannel.Id, - Type: model.AccessControlPolicyTypeChannel, - Version: model.AccessControlPolicyVersionV0_3, - Revision: 1, - Imports: []string{savedParent.ID}, - Rules: []model.AccessControlPolicyRule{ - {Expression: "user.attributes.team == 'engineering'", Actions: []string{"membership"}}, - }, - } - _, err = th.App.Srv().Store().AccessControlPolicy().Save(th.Context, basicTeamChild) - require.NoError(t, err) - defer func() { - _ = th.App.Srv().Store().AccessControlPolicy().Delete(th.Context, basicTeamChannel.Id) + _ = th.App.Srv().Store().AccessControlPolicy().Delete(th.Context, savedPolicy.ID) }() + // Create a second team with a private channel otherTeam := th.CreateTeam(t) otherChannel := th.CreateChannelWithClientAndTeam(t, th.SystemAdminClient, model.ChannelTypePrivate, otherTeam.Id) + _ = otherChannel th.LinkUserToTeam(t, th.TeamAdminUser, th.BasicTeam) th.UpdateUserToTeamAdmin(t, th.TeamAdminUser, th.BasicTeam) @@ -1221,26 +1207,19 @@ func TestSearchChannelsForAccessControlPolicy(t *testing.T) { // Attempt to search with body TeamIds pointing to a different team. // The authZ is against BasicTeam (via team_id query param), but the - // body tries to query otherTeam's channels. The handler should force + // body tries to query otherTeam's channels. The fix should force // TeamIds to BasicTeam.Id regardless of what the body says. channelsResp, resp, err := th.Client.SearchChannelsForAccessControlPolicyForTeam( - context.Background(), savedParent.ID, th.BasicTeam.Id, + context.Background(), savedPolicy.ID, th.BasicTeam.Id, model.ChannelSearch{TeamIds: []string{otherTeam.Id}}) require.NoError(t, err) CheckOKStatus(t, resp) require.NotNil(t, channelsResp) - channelsByID := make(map[string]*model.ChannelWithTeamData, len(channelsResp.Channels)) - for _, ch := range channelsResp.Channels { - channelsByID[ch.Id] = ch - } - require.Contains(t, channelsByID, basicTeamChannel.Id, - "BasicTeam channel must surface — proves the search is exercised, not just trivially empty") - require.NotContains(t, channelsByID, otherChannel.Id, - "otherTeam channel must NOT surface even though body asked for it — proves the team_id query param overrides body TeamIds") + // None of the returned channels should belong to the other team for _, ch := range channelsResp.Channels { require.Equal(t, th.BasicTeam.Id, ch.TeamId, - "team admin must only see channels from the authorized team, got channel %s from team %s", ch.Id, ch.TeamId) + "team admin should only see channels from the authorized team, got channel %s from team %s", ch.Id, ch.TeamId) } }) @@ -1492,6 +1471,81 @@ func newParentPolicy(teamID string) *model.AccessControlPolicy { } } +// TestResponseMaskingOnPolicyEndpoints verifies that every API endpoint returning an +// AccessControlPolicy redacts the raw CEL expression for callers who cannot see all +// values. The risk is a future endpoint forgetting to call MaskPolicyExpressions +// before serializing — the masked visual AST would still hide values, but the raw +// rule.Expression in the same response would leak them in plain text. We force the +// fail-closed branch (unknown property field) so the masking always produces the +// "--------" sentinel without requiring a real CPA setup. +func TestResponseMaskingOnPolicyEndpoints(t *testing.T) { + // SetupConfig sets FFs before route init via SetReadOnlyFF(false). Avoids + // os.Setenv which isn't parallel-safe. + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.AttributeBasedAccessControl = true + cfg.FeatureFlags.AttributeValueMasking = true + }).InitBasic(t) + + ok := th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterpriseAdvanced)) + require.True(t, ok, "SetLicense should return true") + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.AccessControlSettings.EnableAttributeBasedAccessControl = new(true) + }) + + const sensitiveExpr = `user.attributes.f_unknown_field == "TF-Zulu"` + const expectedMaskedExpr = `user.attributes.f_unknown_field == "--------"` + + // A condition referencing an unknown CPA field forces MaskPolicyExpressions + // down the fail-closed branch, which replaces the literal value with the + // masked-token sentinel. That gives us a deterministic assertion target + // without needing to seed a CPA group + protected field in this test. + unknownFieldAST := &model.VisualExpression{ + Conditions: []model.Condition{ + { + Attribute: "user.attributes.f_unknown_field", + Operator: "==", + Value: "TF-Zulu", + ValueType: model.LiteralValue, + }, + }, + } + + newPolicy := func(id string) *model.AccessControlPolicy { + return &model.AccessControlPolicy{ + ID: id, + Type: model.AccessControlPolicyTypeChannel, + Version: model.AccessControlPolicyVersionV0_3, + Revision: 1, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{"membership"}, Expression: sensitiveExpr}, + }, + } + } + + t.Run("getAccessControlPolicy response is masked", func(t *testing.T) { + // GET is the canonical read path — masking here means the raw CEL in the + // policy response cannot leak values the caller couldn't already see in the + // visual AST. The create / search / setActive paths share the same + // MaskPolicyExpressions call so they're covered by inspection. Unit-testing + // them through the HTTP handler is impractical because + // validatePolicyExpressionValues rejects unknown-field references before + // MaskPolicyExpressions ever runs, and we can't seed a real shared_only + // CPA field without plugin context. End-to-end paths are covered by E2E. + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().Channels().AccessControl = mockACS + stored := newPolicy(th.BasicChannel.Id) + mockACS.On("GetPolicy", mock.AnythingOfType("*request.Context"), stored.ID).Return(stored, nil) + mockACS.On("ExpressionToVisualAST", mock.Anything, mock.Anything).Return(unknownFieldAST, nil).Maybe() + + result, resp, err := th.SystemAdminClient.GetAccessControlPolicy(context.Background(), stored.ID) + require.NoError(t, err) + CheckOKStatus(t, resp) + require.NotEmpty(t, result.Rules) + require.Equal(t, expectedMaskedExpr, result.Rules[0].Expression, + "get response must mask the raw CEL exactly") + }) +} + func TestCreateAccessControlPolicyTeamAdmin(t *testing.T) { os.Setenv("MM_FEATUREFLAGS_ATTRIBUTEBASEDACCESSCONTROL", "true") th := Setup(t).InitBasic(t) diff --git a/server/channels/app/access_control.go b/server/channels/app/access_control.go index eb942db16ba..99df3abc372 100644 --- a/server/channels/app/access_control.go +++ b/server/channels/app/access_control.go @@ -112,6 +112,41 @@ func (a *App) CreateOrUpdateAccessControlPolicy(rctx request.CTX, policy *model. } } + // ABAC is gated at route registration; only check masking here. Masking is + // attribute-based: edits are allowed with masked values present as long as + // the caller doesn't drop a condition holding values they couldn't see. + if a.Config().FeatureFlags.AttributeValueMasking { + session := rctx.Session() + if session == nil { + return nil, model.NewAppError("CreateOrUpdateAccessControlPolicy", "api.context.session_expired.app_error", nil, "session required for masking validation", http.StatusUnauthorized) + } + callerID := session.UserId + + // Validate submitted values BEFORE merge: only the values the caller + // actually submitted should be checked against their holdings. Running + // validation after merge would reject the re-injected hidden values + // (e.g. Bravo, Charlie) that the caller legitimately cannot see. + if appErr := a.validatePolicyExpressionValues(rctx, policy, callerID); appErr != nil { + return nil, appErr + } + + // Merge hidden values back in and block deletion of masked conditions. + if appErr := a.mergeStoredPolicyExpressions(rctx, policy, callerID); appErr != nil { + return nil, appErr + } + + // Self-inclusion check applies only to non-admins. System admins may + // legitimately set conditions for attributes they do not personally hold + // (e.g., creating a "Clearance == Top Secret" rule without holding that + // clearance themselves). Masking and write-path value validation still + // apply to system admins above. + if !a.HasPermissionTo(callerID, model.PermissionManageSystem) { + if appErr := a.checkSelfInclusion(rctx, policy, callerID); appErr != nil { + return nil, appErr + } + } + } + var appErr *model.AppError policy, appErr = acs.SavePolicy(rctx, policy) if appErr != nil { @@ -128,6 +163,266 @@ func (a *App) CreateOrUpdateAccessControlPolicy(rctx request.CTX, policy *model. return policy, nil } +// policyHasMaskedValuesForCaller returns true if policy contains any attribute values +// that are not visible to callerID under the current masking rules. +// A nil policy is treated as "no hidden values" — there's nothing to protect. +func (a *App) policyHasMaskedValuesForCaller(rctx request.CTX, policy *model.AccessControlPolicy, callerID string) (bool, *model.AppError) { + if policy == nil { + return false, nil + } + + for _, rule := range policy.Rules { + if rule.Expression == "" || rule.Expression == "true" { + continue + } + maskedAST, appErr := a.GetMaskedVisualAST(rctx, rule.Expression, callerID) + if appErr != nil { + return false, appErr + } + for _, cond := range maskedAST.Conditions { + if cond.HasMaskedValues { + return true, nil + } + } + } + + return false, nil +} + +// mergeStoredPolicyExpressions re-injects hidden values from the stored policy into the +// submitted one, and blocks the save if the caller removed a condition that contained +// values they cannot see (which would silently widen access beyond what they could audit). +// No-op for new policies (not found in store). +func (a *App) mergeStoredPolicyExpressions(rctx request.CTX, policy *model.AccessControlPolicy, callerID string) *model.AppError { + acs := a.Srv().ch.AccessControl + if acs == nil { + return nil + } + + existingPolicy, appErr := acs.GetPolicy(rctx, policy.ID) + if appErr != nil { + if appErr.StatusCode == http.StatusNotFound { + return nil + } + return appErr + } + + for i, rule := range policy.Rules { + if i >= len(existingPolicy.Rules) { + continue + } + storedExpr := existingPolicy.Rules[i].Expression + if storedExpr == "" || storedExpr == "true" { + continue + } + mergedExpr, appErr := a.mergeExpressionWithMaskedValues(rctx, policy.ID, rule.Expression, storedExpr, callerID) + if appErr != nil { + return appErr + } + policy.Rules[i].Expression = mergedExpr + } + + // Any stored rules beyond the submitted set were dropped by the caller. If any of those + // contain values the caller cannot see, block the save — otherwise we'd silently widen + // access by removing a rule whose hidden conditions the caller could not audit. + if len(existingPolicy.Rules) > len(policy.Rules) { + for i := len(policy.Rules); i < len(existingPolicy.Rules); i++ { + storedExpr := existingPolicy.Rules[i].Expression + if storedExpr == "" || storedExpr == "true" { + continue + } + hasMasked, appErr := a.expressionHasMaskedValuesForCaller(rctx, storedExpr, callerID) + if appErr != nil { + return appErr + } + if hasMasked { + return model.NewAppError("mergeStoredPolicyExpressions", "app.pap.save_policy.masked_rule_deleted", nil, + "cannot remove a rule that contains attribute values you do not hold", http.StatusForbidden) + } + } + } + + return nil +} + +// expressionHasMaskedValuesForCaller reports whether storedExpr contains any value the caller cannot see. +func (a *App) expressionHasMaskedValuesForCaller(rctx request.CTX, storedExpr, callerID string) (bool, *model.AppError) { + maskedAST, appErr := a.GetMaskedVisualAST(rctx, storedExpr, callerID) + if appErr != nil { + return false, appErr + } + for _, cond := range maskedAST.Conditions { + if cond.HasMaskedValues { + return true, nil + } + } + return false, nil +} + +// mergeExpressionWithMaskedValues re-injects hidden values into submittedExpr and +// returns 403 if the caller dropped a condition with values they cannot see. +// +// Two fail-closed shortcuts before the merge: +// 1. Caller has no masked values on storedExpr → return submitted as-is. +// 2. storedExpr isn't faithfully representable by the Visual AST (|| or grouping +// would flatten into ANDs on rebuild) → accept only no-op saves (e.g., rename), +// reject real edits. Role-neutral: masking is attribute-based, so a sysadmin +// without the values lands here too. +// +// Stopgap until the canonical CEL AST walker refactor. +func (a *App) mergeExpressionWithMaskedValues(rctx request.CTX, policyID, submittedExpr, storedExpr, callerID string) (string, *model.AppError) { + hasMasked, appErr := a.expressionHasMaskedValuesForCaller(rctx, storedExpr, callerID) + if appErr != nil { + return "", appErr + } + if !hasMasked { + return submittedExpr, nil + } + + submittedAST, appErr := a.ExpressionToVisualAST(rctx, submittedExpr) + if appErr != nil { + return "", appErr + } + + storedAST, appErr := a.ExpressionToVisualAST(rctx, storedExpr) + if appErr != nil { + return "", appErr + } + + if !isVisualASTRepresentable(storedExpr, storedAST) { + masked, maskErr := a.GetMaskedExpression(rctx, storedExpr, callerID) + if maskErr != nil { + return "", maskErr + } + if normalizedEqual(submittedExpr, masked) { + // no-op edit (e.g., rename) — keep stored expression as-is + return storedExpr, nil + } + rctx.Logger().Info("save refused: stored rule not representable by Visual AST", + mlog.String("policy_id", policyID), + mlog.String("caller_id", callerID), + ) + return "", model.NewAppError("mergeExpressionWithMaskedValues", + "app.pap.save_policy.advanced_expression_blocked", nil, + "this rule expression cannot be safely edited while restricted values are present", + http.StatusForbidden) + } + + cpaGroup, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + return "", model.NewAppError("mergeExpressionWithMaskedValues", "app.pap.merge_expression.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) + } + cpaGroupID := cpaGroup.ID + + rctxWithCaller := RequestContextWithCallerID(rctx, callerID) + + // Pre-fetch fields once for all stored conditions. We require every referenced field + // to resolve — proceeding with an incomplete map would silently strip hidden values + // from stored conditions and bypass the masked-condition-delete block. + fieldsByName := a.fetchConditionFields(rctxWithCaller, storedAST.Conditions, cpaGroupID) + if appErr := requireAllFieldsResolved(rctxWithCaller, storedAST.Conditions, fieldsByName); appErr != nil { + return "", appErr + } + + // Count submitted conditions per attribute. A simple set isn't enough because the parser + // can produce two conditions on the same attribute (e.g. `attr in [...] && attr == "x"`); + // dropping one of them while keeping the other must still trigger the deletion guard if + // the dropped condition had hidden values. + submittedCounts := make(map[string]int, len(submittedAST.Conditions)) + for _, cond := range submittedAST.Conditions { + submittedCounts[cond.Attribute]++ + } + + storedCounts := make(map[string]int, len(storedAST.Conditions)) + for _, cond := range storedAST.Conditions { + storedCounts[cond.Attribute]++ + } + + // Block deletion of any stored condition that has hidden values for this caller. + // We walk stored conditions and, when one with hidden values appears, require that + // the submitted set still has at least as many conditions on the same attribute as + // stored had — otherwise some stored condition was dropped. + for i := range storedAST.Conditions { + hidden := a.getHiddenValues(rctxWithCaller, callerID, &storedAST.Conditions[i], cpaGroupID, fieldsByName) + if len(hidden) == 0 { + continue + } + attr := storedAST.Conditions[i].Attribute + if submittedCounts[attr] < storedCounts[attr] { + return "", model.NewAppError("mergeExpressionWithMaskedValues", "app.pap.save_policy.masked_condition_deleted", nil, + "cannot remove a rule condition that contains attribute values you do not hold", http.StatusForbidden) + } + } + + // Match submitted conditions to stored ones by attribute (in order), merge hidden values. + storedByAttr := make(map[string][]model.Condition) + for _, cond := range storedAST.Conditions { + storedByAttr[cond.Attribute] = append(storedByAttr[cond.Attribute], cond) + } + + matchCount := make(map[string]int) + var mergedConditions []model.Condition + + for _, submitted := range submittedAST.Conditions { + storedList, found := storedByAttr[submitted.Attribute] + if !found { + mergedConditions = append(mergedConditions, submitted) + continue + } + + matchIdx := matchCount[submitted.Attribute] + matchCount[submitted.Attribute]++ + + if matchIdx >= len(storedList) { + mergedConditions = append(mergedConditions, submitted) + continue + } + + stored := storedList[matchIdx] + hiddenValues := a.getHiddenValues(rctxWithCaller, callerID, &stored, cpaGroupID, fieldsByName) + merged := mergeConditionValues(submitted, hiddenValues) + merged.Operator = stored.Operator + merged.AttributeType = stored.AttributeType + // Frontend emits "attr in []" as the placeholder for any fully-masked row + // regardless of the stored operator. After we restore the original operator, + // the value shape may not match (e.g., "==" with a []any value). Normalize + // scalar operators to a single string from the array. + if isScalarOperator(merged.Operator) { + if arr, ok := merged.Value.([]any); ok { + if len(arr) == 0 { + merged.Value = nil + } else if s, ok := arr[0].(string); ok { + merged.Value = s + } + } + } + mergedConditions = append(mergedConditions, merged) + } + + return buildCELFromConditions(mergedConditions), nil +} + +// checkSelfInclusion verifies the caller satisfies all policy rules after their edit. +func (a *App) checkSelfInclusion(rctx request.CTX, policy *model.AccessControlPolicy, callerID string) *model.AppError { + for _, rule := range policy.Rules { + if rule.Expression == "" || rule.Expression == "true" { + continue + } + + matches, appErr := a.ValidateExpressionAgainstRequester(rctx, rule.Expression, callerID) + if appErr != nil { + return appErr + } + if !matches { + return model.NewAppError("CreateOrUpdateAccessControlPolicy", + "app.pap.save_policy.self_exclusion", nil, + "You do not satisfy one or more conditions in this policy.", http.StatusForbidden) + } + } + + return nil +} + func (a *App) DeleteAccessControlPolicy(rctx request.CTX, id string) *model.AppError { acs := a.Srv().ch.AccessControl if acs == nil { @@ -142,6 +437,20 @@ func (a *App) DeleteAccessControlPolicy(rctx request.CTX, id string) *model.AppE return appErr } + // ABAC is gated at route registration; only check masking here. + if a.Config().FeatureFlags.AttributeValueMasking { + session := rctx.Session() + if session != nil { + callerID := session.UserId + if hasMasked, appErr := a.policyHasMaskedValuesForCaller(rctx, policy, callerID); appErr != nil { + return appErr + } else if hasMasked { + return model.NewAppError("DeleteAccessControlPolicy", "app.pap.delete_policy.masked_values", nil, + "policy contains attribute values you do not hold; you cannot delete this policy", http.StatusForbidden) + } + } + } + var affectedChannelIDs []string if policy != nil && policy.Type != model.AccessControlPolicyTypeChannel { affectedChannelIDs = a.channelPolicyIDsWithImport(rctx, id) @@ -371,14 +680,6 @@ func (a *App) UpdateAccessControlPoliciesActive(rctx request.CTX, updates []mode if err != nil { return nil, model.NewAppError("UpdateAccessControlPoliciesActive", "app.pap.update_access_control_policies_active.app_error", nil, err.Error(), http.StatusInternalServerError) } - - for _, policy := range policies { - // only channel policies use the active state - if policy.Type == model.AccessControlPolicyTypeChannel { - a.publishChannelPolicyEnforcedUpdate(rctx, policy.ID) - } - } - return policies, nil } diff --git a/server/channels/app/access_control_masking.go b/server/channels/app/access_control_masking.go index c64af04d9a6..f82c9128374 100644 --- a/server/channels/app/access_control_masking.go +++ b/server/channels/app/access_control_masking.go @@ -5,7 +5,9 @@ package app import ( "encoding/json" + "fmt" "net/http" + "strconv" "strings" "github.com/mattermost/mattermost/server/public/model" @@ -50,7 +52,9 @@ func (a *App) GetMaskedVisualAST(rctx request.CTX, expression string, callerID s } // fetchConditionFields collects unique field names from conditions and fetches each once. -// Fields that fail lookup are omitted; maskConditionValues treats missing entries as fail-closed. +// Lookup failures are logged and omitted from the returned map; read-path callers treat +// missing entries as fail-closed (mask the value). Write-path callers should additionally +// call requireAllFieldsResolved to refuse to proceed when any referenced field is missing. func (a *App) fetchConditionFields(rctx request.CTX, conditions []model.Condition, cpaGroupID string) map[string]*model.PropertyField { seen := make(map[string]bool) for _, c := range conditions { @@ -77,6 +81,31 @@ func (a *App) fetchConditionFields(rctx request.CTX, conditions []model.Conditio return fields } +// requireAllFieldsResolved returns the generic invalid-value error if any condition +// references a field name missing from fieldsByName. Write-path callers use this to refuse +// the save rather than silently strip hidden values from conditions whose fields could not +// be resolved. We return the same generic 400 used by the rest of write-path validation so +// unknown/deleted fields don't leak an enumeration signal distinct from hidden-value +// rejection — the actual field name is logged for operator diagnostics instead. +func requireAllFieldsResolved(rctx request.CTX, conditions []model.Condition, fieldsByName map[string]*model.PropertyField) *model.AppError { + for _, c := range conditions { + if c.ValueType == model.AttrValue { + continue + } + name := extractFieldName(c.Attribute) + if name == "" { + continue + } + if _, ok := fieldsByName[name]; !ok { + rctx.Logger().Warn("Field referenced by condition could not be resolved during write-path validation", + mlog.String("field_name", name), + ) + return invalidValueError() + } + } + return nil +} + // maskConditionValues applies masking to a single condition in place. // // Masking semantics differ by field type: @@ -246,3 +275,601 @@ func filterConditionValues(condition *model.Condition, visibleNames map[string]s } } } + +// getHiddenValues returns the subset of stored condition values not visible to callerID. +// fieldsByName is pre-fetched by the caller to avoid N+1 lookups; a missing entry is +// treated as fail-closed (no hidden values injected for that condition). +func (a *App) getHiddenValues(rctx request.CTX, callerID string, stored *model.Condition, cpaGroupID string, fieldsByName map[string]*model.PropertyField) []string { + if stored.ValueType == model.AttrValue { + return nil + } + + fieldName := extractFieldName(stored.Attribute) + if fieldName == "" { + return nil + } + + field, ok := fieldsByName[fieldName] + if !ok { + return nil + } + + switch field.GetAccessMode() { + case model.PropertyAccessModeSourceOnly: + return extractStringValues(stored.Value) + case model.PropertyAccessModeSharedOnly: + var visibleNames map[string]struct{} + if field.Type == model.PropertyFieldTypeSelect || field.Type == model.PropertyFieldTypeMultiselect { + visibleNames = extractVisibleOptionNames(field) + } else { + visibleNames = a.getCallerTextValues(rctx, callerID, field, cpaGroupID) + } + var hidden []string + for _, val := range extractStringValues(stored.Value) { + if _, visible := visibleNames[val]; !visible { + hidden = append(hidden, val) + } + } + return hidden + default: + return nil + } +} + +// isScalarOperator reports whether the operator expects a single value (not a list). +// Used by merge-on-save to normalize the value shape after restoring the stored operator. +func isScalarOperator(op string) bool { + switch op { + case "==", "!=", ">", ">=", "<", "<=", "contains", "startsWith", "endsWith": + return true + } + return false +} + +// mergeConditionValues appends hiddenValues into the submitted condition's values, +// deduplicating. A nil submitted value is restored from hidden values alone. +func mergeConditionValues(submitted model.Condition, hiddenValues []string) model.Condition { + if len(hiddenValues) == 0 { + return submitted + } + + merged := submitted + + switch v := submitted.Value.(type) { + case []any: + // Strip the masked-token sentinel from submitted values: it's the + // server's own placeholder for hidden values (from a masked GET), + // not a real value, and we're about to re-inject the actual stored + // hidden values from hiddenValues. + seen := make(map[string]struct{}) + cleaned := make([]any, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + if s == maskedTokenValue { + continue + } + seen[s] = struct{}{} + } + cleaned = append(cleaned, item) + } + result := make([]any, 0, len(cleaned)+len(hiddenValues)) + result = append(result, cleaned...) + for _, hidden := range hiddenValues { + if _, exists := seen[hidden]; !exists { + result = append(result, hidden) + } + } + merged.Value = result + + case string: + // Empty string and the masked-token sentinel both mean "no real value + // submitted here"; restore from hidden values. + if (v == "" || v == maskedTokenValue) && len(hiddenValues) > 0 { + merged.Value = hiddenValues[0] + } + + case nil: + if len(hiddenValues) == 1 { + merged.Value = hiddenValues[0] + } else if len(hiddenValues) > 1 { + result := make([]any, 0, len(hiddenValues)) + for _, h := range hiddenValues { + result = append(result, h) + } + merged.Value = result + } + } + + return merged +} + +// containsNonStringLiteral reports whether the condition value contains any +// non-string element (numeric, boolean, etc.). Used by the write-path to reject +// type-mismatched literals on property-backed conditions — without this guard, +// extractStringValues would silently drop such elements and let invalid CEL +// bypass the source_only / shared_only checks. +func containsNonStringLiteral(value any) bool { + switch v := value.(type) { + case nil, string: + return false + case []any: + for _, item := range v { + if _, ok := item.(string); !ok { + return true + } + } + return false + default: + // numeric, boolean, etc. + return true + } +} + +// extractStringValues converts a condition's Value to a slice of strings. +// Non-string elements are silently dropped — write-path callers should pair +// this with containsNonStringLiteral to reject type-mismatched literals first. +func extractStringValues(value any) []string { + switch v := value.(type) { + case []any: + var result []string + for _, item := range v { + if s, ok := item.(string); ok { + result = append(result, s) + } + } + return result + case string: + return []string{v} + default: + return nil + } +} + +// buildCELFromConditions reconstructs a CEL expression from conditions, joined with " && ". +func buildCELFromConditions(conditions []model.Condition) string { + if len(conditions) == 0 { + return "true" + } + + parts := make([]string, 0, len(conditions)) + for _, cond := range conditions { + cel := conditionToCEL(cond) + if cel != "" { + parts = append(parts, cel) + } + } + + if len(parts) == 0 { + return "true" + } + + return strings.Join(parts, " && ") +} + +// isVisualASTRepresentable reports whether buildCELFromConditions(ast) round-trips +// back to originalExpr. False means merging would silently rewrite the shape +// (typically || or grouping that the AST flattens into ANDs). Stopgap until the +// canonical AST walker lands. +func isVisualASTRepresentable(originalExpr string, ast *model.VisualExpression) bool { + if ast == nil || len(ast.Conditions) == 0 { + return originalExpr == "" || originalExpr == "true" + } + return normalizedEqual(originalExpr, buildCELFromConditions(ast.Conditions)) +} + +// normalizedEqual compares two CEL expressions modulo whitespace and quote style. +// Unbalanced quotes on either side count as not-equal (fail-closed). +func normalizedEqual(a, b string) bool { + na, okA := normalizeForComparison(a) + if !okA { + return false + } + nb, okB := normalizeForComparison(b) + if !okB { + return false + } + return na == nb +} + +// normalizeForComparison strips whitespace outside string literals and rewrites +// single quotes to double. String contents are preserved verbatim. Returns +// ok=false on unbalanced quotes. +func normalizeForComparison(s string) (string, bool) { + var b strings.Builder + b.Grow(len(s)) + var quote byte // 0 outside string literal; '"' or '\'' inside + for i := 0; i < len(s); i++ { + c := s[i] + switch { + case quote == 0 && (c == '"' || c == '\''): + quote = c + b.WriteByte('"') + case quote != 0 && c == '\\' && i+1 < len(s): + // keep escapes verbatim + b.WriteByte(c) + b.WriteByte(s[i+1]) + i++ + case quote != 0 && c == quote: + b.WriteByte('"') + quote = 0 + case quote == 0 && (c == ' ' || c == '\t' || c == '\n' || c == '\r'): + // drop whitespace outside strings + default: + b.WriteByte(c) + } + } + if quote != 0 { + return "", false + } + return b.String(), true +} + +// conditionToCEL converts a single Condition to its CEL string representation. +func conditionToCEL(cond model.Condition) string { + attr := cond.Attribute + + switch cond.Operator { + case "==", "!=", ">", ">=", "<", "<=": + if cond.Value == nil { + return "" + } + return attr + " " + cond.Operator + " " + celValueLiteral(cond.Value) + + case "in": + values := extractStringValues(cond.Value) + if len(values) == 0 { + return "" + } + if cond.AttributeType == "multiselect" { + // multiselect: "v1" in attr && "v2" in attr + inParts := make([]string, 0, len(values)) + for _, v := range values { + inParts = append(inParts, celStringLiteral(v)+" in "+attr) + } + return strings.Join(inParts, " && ") + } + // select: attr in ["v1", "v2"] + valLiterals := make([]string, 0, len(values)) + for _, v := range values { + valLiterals = append(valLiterals, celStringLiteral(v)) + } + return attr + " in [" + strings.Join(valLiterals, ", ") + "]" + + case "hasAnyOf": + values := extractStringValues(cond.Value) + if len(values) == 0 { + return "" + } + orParts := make([]string, 0, len(values)) + for _, v := range values { + orParts = append(orParts, celStringLiteral(v)+" in "+attr) + } + if len(orParts) == 1 { + return orParts[0] + } + return "(" + strings.Join(orParts, " || ") + ")" + + case "hasAllOf": + values := extractStringValues(cond.Value) + if len(values) == 0 { + return "" + } + andParts := make([]string, 0, len(values)) + for _, v := range values { + andParts = append(andParts, celStringLiteral(v)+" in "+attr) + } + return strings.Join(andParts, " && ") + + case "contains", "startsWith", "endsWith": + if cond.Value == nil { + return "" + } + return attr + "." + cond.Operator + "(" + celValueLiteral(cond.Value) + ")" + + default: + if cond.Value == nil { + return "" + } + return attr + " " + cond.Operator + " " + celValueLiteral(cond.Value) + } +} + +// celStringLiteral wraps s in a CEL-compatible double-quoted string literal. +// strconv.Quote produces Go syntax that overlaps with CEL's escape grammar +// (backslash, double quote, \a \b \f \n \r \t \v, \xHH, \uHHHH, \UHHHHHHHH), +// so it safely round-trips strings containing control characters, embedded +// quotes, or non-ASCII content — none of which the previous naive ReplaceAll +// handled. Attribute values that legitimately contain newlines or tabs would +// have produced broken CEL otherwise. +func celStringLiteral(s string) string { + return strconv.Quote(s) +} + +// celValueLiteral returns the CEL literal for a condition value. +func celValueLiteral(value any) string { + switch v := value.(type) { + case string: + return celStringLiteral(v) + case float64: + // 'g' with precision -1 produces the shortest representation that + // round-trips back to v exactly. Avoids the precision loss from + // fmt.Sprintf("%f") which rounds to six fractional digits. + return strconv.FormatFloat(v, 'g', -1, 64) + case int: + return fmt.Sprintf("%d", v) + case int64: + return fmt.Sprintf("%d", v) + case bool: + if v { + return "true" + } + return "false" + case nil: + return "null" + default: + return fmt.Sprintf("%v", v) + } +} + +// maskedTokenValue is the sentinel the frontend uses for masked values; never a valid attribute value. +const maskedTokenValue = "--------" + +// validatePolicyExpressionValues checks that all submitted literal values are held by the caller. +// Returns the same generic error for every rejection to prevent value enumeration. +func (a *App) validatePolicyExpressionValues(rctx request.CTX, policy *model.AccessControlPolicy, callerID string) *model.AppError { + cpaGroup, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + return model.NewAppError("validatePolicyExpressionValues", "app.pap.validate_expression_values.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) + } + cpaGroupID := cpaGroup.ID + + rctxWithCaller := RequestContextWithCallerID(rctx, callerID) + + // Parse all rule ASTs first and collect every referenced field so we can + // pre-fetch in a single pass, avoiding N+1 lookups across conditions. + rulesASTs := make([]*model.VisualExpression, 0, len(policy.Rules)) + var allConditions []model.Condition + for _, rule := range policy.Rules { + if rule.Expression == "" || rule.Expression == "true" { + continue + } + visualAST, appErr := a.ExpressionToVisualAST(rctx, rule.Expression) + if appErr != nil { + return appErr + } + rulesASTs = append(rulesASTs, visualAST) + allConditions = append(allConditions, visualAST.Conditions...) + } + + fieldsByName := a.fetchConditionFields(rctxWithCaller, allConditions, cpaGroupID) + if appErr := requireAllFieldsResolved(rctxWithCaller, allConditions, fieldsByName); appErr != nil { + return appErr + } + + for _, visualAST := range rulesASTs { + for _, cond := range visualAST.Conditions { + if appErr := a.validateConditionValues(rctxWithCaller, &cond, cpaGroupID, fieldsByName); appErr != nil { + return appErr + } + } + } + + return nil +} + +// invalidValueError returns the same generic 400 for all write-path rejections (no enumeration leakage). +func invalidValueError() *model.AppError { + return model.NewAppError("validatePolicyExpressionValues", "app.pap.save_policy.invalid_value", nil, "Invalid value.", http.StatusBadRequest) +} + +// validateConditionValues checks that all literal values in a single condition are held by the caller. +// fieldsByName is pre-fetched by the caller to avoid N+1 lookups; a missing entry means the field +// could not be resolved (deleted, or DB error at prefetch time) — rejected with the generic error. +func (a *App) validateConditionValues(rctx request.CTX, cond *model.Condition, cpaGroupID string, fieldsByName map[string]*model.PropertyField) *model.AppError { + if cond.ValueType == model.AttrValue { + return nil + } + + // The masked-token sentinel is what the server itself emits when masking the + // raw CEL of policy GET / search responses. If the frontend round-trips a GET + // response back to us unchanged (e.g. the admin only modified channel + // assignment, not the rules), it will appear here. Skip it during validation; + // mergeConditionValues will strip it from the merged result and re-inject the + // actual hidden values from the stored policy. + values := extractStringValues(cond.Value) + nonTokenValues := make([]string, 0, len(values)) + for _, v := range values { + if v != maskedTokenValue { + nonTokenValues = append(nonTokenValues, v) + } + } + + fieldName := extractFieldName(cond.Attribute) + if fieldName == "" { + return nil + } + + field, ok := fieldsByName[fieldName] + if !ok { + return invalidValueError() // reject unknown fields to prevent probing + } + + // Property-backed conditions must use string literals. Numeric / boolean values + // would be silently dropped by extractStringValues above, letting them bypass the + // source_only / shared_only checks. Reject them with the same generic error. + if containsNonStringLiteral(cond.Value) { + return invalidValueError() + } + + switch field.GetAccessMode() { + case model.PropertyAccessModePublic: + return nil + case model.PropertyAccessModeSourceOnly: + if len(nonTokenValues) > 0 { + return invalidValueError() + } + return nil + case model.PropertyAccessModeSharedOnly: + var visibleNames map[string]struct{} + if field.Type == model.PropertyFieldTypeSelect || field.Type == model.PropertyFieldTypeMultiselect { + visibleNames = extractVisibleOptionNames(field) + } else { + callerID, _ := CallerIDFromRequestContext(rctx) + visibleNames = a.getCallerTextValues(rctx, callerID, field, cpaGroupID) + } + for _, v := range nonTokenValues { + if _, visible := visibleNames[v]; !visible { + return invalidValueError() + } + } + return nil + default: + if len(nonTokenValues) > 0 { + return invalidValueError() + } + return nil + } +} + +func (a *App) GetMaskedExpression(rctx request.CTX, expression string, callerID string) (string, *model.AppError) { + if expression == "" || expression == "true" { + return expression, nil + } + + visualAST, appErr := a.ExpressionToVisualAST(rctx, expression) + if appErr != nil { + return "", appErr + } + + cpaGroup, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + return "", appErr + } + cpaGroupID := cpaGroup.ID + + rctxWithCaller := RequestContextWithCallerID(rctx, callerID) + fieldsByName := a.fetchConditionFields(rctxWithCaller, visualAST.Conditions, cpaGroupID) + + for i := range visualAST.Conditions { + a.maskConditionValuesWithToken(rctxWithCaller, callerID, &visualAST.Conditions[i], cpaGroupID, fieldsByName) + } + + return buildCELFromConditions(visualAST.Conditions), nil +} + +// maskConditionValuesWithToken replaces non-held values with the masked token in place, +// preserving expression structure so the visual AST endpoint can still parse it. +// fieldsByName is pre-fetched by the caller to avoid N+1 lookups; a missing entry +// is treated as fail-closed (whole value masked). +func (a *App) maskConditionValuesWithToken(rctx request.CTX, callerID string, condition *model.Condition, cpaGroupID string, fieldsByName map[string]*model.PropertyField) { + if condition.ValueType == model.AttrValue { + return + } + + fieldName := extractFieldName(condition.Attribute) + if fieldName == "" { + return + } + + field, ok := fieldsByName[fieldName] + if !ok { + condition.Value = maskedTokenValue // fail closed + return + } + + switch field.GetAccessMode() { + case model.PropertyAccessModePublic: + return + case model.PropertyAccessModeSourceOnly: + condition.Value = maskedTokenValue + case model.PropertyAccessModeSharedOnly: + var visibleNames map[string]struct{} + if field.Type == model.PropertyFieldTypeSelect || field.Type == model.PropertyFieldTypeMultiselect { + visibleNames = extractVisibleOptionNames(field) + } else { + visibleNames = a.getCallerTextValues(rctx, callerID, field, cpaGroupID) + } + replaceHiddenValuesWithToken(condition, visibleNames) + default: + condition.Value = maskedTokenValue + } +} + +// replaceHiddenValuesWithToken keeps visible values and appends a single masked token if any were hidden. +// One token regardless of count prevents count-based inference about the number of hidden values. +func replaceHiddenValuesWithToken(condition *model.Condition, visibleNames map[string]struct{}) { + switch v := condition.Value.(type) { + case []any: + var result []any + hasMasked := false + for _, val := range v { + if strVal, ok := val.(string); ok { + if _, visible := visibleNames[strVal]; visible { + result = append(result, val) + } else { + hasMasked = true + } + } else { + result = append(result, val) + } + } + if hasMasked { + result = append(result, maskedTokenValue) + } + condition.Value = result + case string: + if _, visible := visibleNames[v]; !visible { + condition.Value = maskedTokenValue + } + } +} + +// MaskPolicyExpressions masks non-held literal values in all policy rule expressions, in place. +// Fails closed (sets a rule to "true") if its expression cannot be parsed or masked. +func (a *App) MaskPolicyExpressions(rctx request.CTX, policy *model.AccessControlPolicy, callerID string) { + cpaGroup, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + rctx.Logger().Warn("MaskPolicyExpressions: failed to resolve CPA group, masking all rules closed", + mlog.Err(appErr), + ) + for i, rule := range policy.Rules { + if rule.Expression == "" || rule.Expression == "true" { + continue + } + policy.Rules[i].Expression = "true" + } + return + } + cpaGroupID := cpaGroup.ID + + rctxWithCaller := RequestContextWithCallerID(rctx, callerID) + + // Parse each rule's AST once and collect all conditions so we can pre-fetch + // every referenced field in a single pass, avoiding N+1 lookups across rules. + asts := make([]*model.VisualExpression, len(policy.Rules)) + var allConditions []model.Condition + for i, rule := range policy.Rules { + if rule.Expression == "" || rule.Expression == "true" { + continue + } + ast, appErr := a.ExpressionToVisualAST(rctx, rule.Expression) + if appErr != nil { + policy.Rules[i].Expression = "true" // fail closed + continue + } + asts[i] = ast + allConditions = append(allConditions, ast.Conditions...) + } + + fieldsByName := a.fetchConditionFields(rctxWithCaller, allConditions, cpaGroupID) + + for i, ast := range asts { + if ast == nil { + continue + } + for j := range ast.Conditions { + a.maskConditionValuesWithToken(rctxWithCaller, callerID, &ast.Conditions[j], cpaGroupID, fieldsByName) + } + policy.Rules[i].Expression = buildCELFromConditions(ast.Conditions) + } +} diff --git a/server/channels/app/access_control_merge_test.go b/server/channels/app/access_control_merge_test.go new file mode 100644 index 00000000000..4f21749378d --- /dev/null +++ b/server/channels/app/access_control_merge_test.go @@ -0,0 +1,474 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildCELFromConditions(t *testing.T) { + t.Run("empty conditions returns true", func(t *testing.T) { + result := buildCELFromConditions(nil) + assert.Equal(t, "true", result) + }) + + t.Run("equals operator", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Team", Operator: "==", Value: "Engineering", ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Team == "Engineering"`, result) + }) + + t.Run("not equals operator", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Location", Operator: "!=", Value: "Building 7", ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Location != "Building 7"`, result) + }) + + t.Run("in operator with select field", func(t *testing.T) { + conditions := []model.Condition{ + { + Attribute: "user.attributes.Department", + Operator: "in", + Value: []any{"Sales", "Engineering", "Legal"}, + ValueType: model.LiteralValue, + AttributeType: "select", + }, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Department in ["Sales", "Engineering", "Legal"]`, result) + }) + + t.Run("in operator with multiselect field", func(t *testing.T) { + conditions := []model.Condition{ + { + Attribute: "user.attributes.Programs", + Operator: "in", + Value: []any{"Alpha", "Bravo"}, + ValueType: model.LiteralValue, + AttributeType: "multiselect", + }, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `"Alpha" in user.attributes.Programs && "Bravo" in user.attributes.Programs`, result) + }) + + t.Run("hasAnyOf operator", func(t *testing.T) { + conditions := []model.Condition{ + { + Attribute: "user.attributes.Programs", + Operator: "hasAnyOf", + Value: []any{"Alpha", "Bravo"}, + ValueType: model.LiteralValue, + AttributeType: "multiselect", + }, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `("Alpha" in user.attributes.Programs || "Bravo" in user.attributes.Programs)`, result) + }) + + t.Run("hasAnyOf with single value omits parens", func(t *testing.T) { + conditions := []model.Condition{ + { + Attribute: "user.attributes.Programs", + Operator: "hasAnyOf", + Value: []any{"Alpha"}, + ValueType: model.LiteralValue, + AttributeType: "multiselect", + }, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `"Alpha" in user.attributes.Programs`, result) + }) + + t.Run("hasAllOf operator", func(t *testing.T) { + conditions := []model.Condition{ + { + Attribute: "user.attributes.Programs", + Operator: "hasAllOf", + Value: []any{"Alpha", "Bravo"}, + ValueType: model.LiteralValue, + AttributeType: "multiselect", + }, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `"Alpha" in user.attributes.Programs && "Bravo" in user.attributes.Programs`, result) + }) + + t.Run("contains operator", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Email", Operator: "contains", Value: "@company.com", ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Email.contains("@company.com")`, result) + }) + + t.Run("startsWith operator", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Name", Operator: "startsWith", Value: "Dr.", ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Name.startsWith("Dr.")`, result) + }) + + t.Run("endsWith operator", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Email", Operator: "endsWith", Value: ".gov", ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Email.endsWith(".gov")`, result) + }) + + t.Run("multiple conditions joined with &&", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Team", Operator: "==", Value: "Engineering", ValueType: model.LiteralValue}, + {Attribute: "user.attributes.Location", Operator: "!=", Value: "Remote", ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Team == "Engineering" && user.attributes.Location != "Remote"`, result) + }) + + t.Run("string with special characters is escaped", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Team", Operator: "==", Value: `Team "Alpha"`, ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Team == "Team \"Alpha\""`, result) + }) + + t.Run("boolean value", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Active", Operator: "==", Value: true, ValueType: model.LiteralValue}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, `user.attributes.Active == true`, result) + }) + + t.Run("empty in-list produces no output", func(t *testing.T) { + conditions := []model.Condition{ + {Attribute: "user.attributes.Department", Operator: "in", Value: []any{}, ValueType: model.LiteralValue, AttributeType: "select"}, + } + result := buildCELFromConditions(conditions) + assert.Equal(t, "true", result) + }) + + t.Run("masked token in values produces valid CEL", func(t *testing.T) { + conds := []model.Condition{{ + Attribute: "user.attributes.Program", + Operator: "in", + Value: []any{"Alpha", maskedTokenValue}, + ValueType: model.LiteralValue, + AttributeType: "select", + }} + result := buildCELFromConditions(conds) + assert.Contains(t, result, "Alpha") + assert.Contains(t, result, maskedTokenValue) + }) +} + +func TestNormalizeForComparison(t *testing.T) { + t.Run("strips whitespace outside string literals", func(t *testing.T) { + got, ok := normalizeForComparison(` a == "b" `) + require.True(t, ok) + assert.Equal(t, `a=="b"`, got) + }) + + t.Run("preserves whitespace inside string literals", func(t *testing.T) { + got, ok := normalizeForComparison(`a == "hello world"`) + require.True(t, ok) + assert.Equal(t, `a=="hello world"`, got) + }) + + t.Run("canonicalizes single quotes to double quotes", func(t *testing.T) { + single, ok := normalizeForComparison(`a == 'foo'`) + require.True(t, ok) + double, ok := normalizeForComparison(`a == "foo"`) + require.True(t, ok) + assert.Equal(t, single, double) + }) + + t.Run("preserves escape sequences inside string literals", func(t *testing.T) { + got, ok := normalizeForComparison(`a == "he said \"hi\""`) + require.True(t, ok) + assert.Equal(t, `a=="he said \"hi\""`, got) + }) + + t.Run("unbalanced quote returns ok=false", func(t *testing.T) { + _, ok := normalizeForComparison(`a == "unterminated`) + assert.False(t, ok) + }) + + t.Run("empty string normalizes to empty", func(t *testing.T) { + got, ok := normalizeForComparison("") + require.True(t, ok) + assert.Equal(t, "", got) + }) +} + +func TestIsVisualASTRepresentable(t *testing.T) { + t.Run("empty AST on empty expression is representable", func(t *testing.T) { + assert.True(t, isVisualASTRepresentable("", &model.VisualExpression{})) + }) + + t.Run("empty AST on 'true' is representable", func(t *testing.T) { + assert.True(t, isVisualASTRepresentable("true", &model.VisualExpression{})) + }) + + t.Run("simple equals condition round-trips cleanly", func(t *testing.T) { + ast := &model.VisualExpression{Conditions: []model.Condition{ + {Attribute: "user.attributes.team", Operator: "==", Value: "Engineering", ValueType: model.LiteralValue}, + }} + assert.True(t, isVisualASTRepresentable(`user.attributes.team == "Engineering"`, ast)) + }) + + t.Run("simple AND chain of two conditions round-trips cleanly", func(t *testing.T) { + ast := &model.VisualExpression{Conditions: []model.Condition{ + {Attribute: "user.attributes.team", Operator: "==", Value: "Engineering", ValueType: model.LiteralValue}, + {Attribute: "user.attributes.role", Operator: "==", Value: "Admin", ValueType: model.LiteralValue}, + }} + assert.True(t, isVisualASTRepresentable( + `user.attributes.team == "Engineering" && user.attributes.role == "Admin"`, + ast, + )) + }) + + t.Run("|| in original but AST flattens to AND is NOT representable", func(t *testing.T) { + // Pretend the parser flattened `a == "X" || b == "Y"` into two AND-joined + // conditions. The round-trip would emit `&&`, mismatch detected. + ast := &model.VisualExpression{Conditions: []model.Condition{ + {Attribute: "user.attributes.team", Operator: "==", Value: "X", ValueType: model.LiteralValue}, + {Attribute: "user.attributes.role", Operator: "==", Value: "Y", ValueType: model.LiteralValue}, + }} + assert.False(t, isVisualASTRepresentable( + `user.attributes.team == "X" || user.attributes.role == "Y"`, + ast, + )) + }) + + t.Run("grouping in original is NOT representable when AST flattens it", func(t *testing.T) { + ast := &model.VisualExpression{Conditions: []model.Condition{ + {Attribute: "user.attributes.a", Operator: "==", Value: "1", ValueType: model.LiteralValue}, + {Attribute: "user.attributes.b", Operator: "==", Value: "2", ValueType: model.LiteralValue}, + {Attribute: "user.attributes.c", Operator: "==", Value: "3", ValueType: model.LiteralValue}, + }} + assert.False(t, isVisualASTRepresentable( + `(user.attributes.a == "1" && user.attributes.b == "2") || user.attributes.c == "3"`, + ast, + )) + }) + + t.Run("hasAnyOf with multiple values is representable (|| within a single condition)", func(t *testing.T) { + // `("Alpha" in attr || "Bravo" in attr)` is the canonical serialization + // for a hasAnyOf condition. It contains || syntactically but the AST + // reduces it to one condition that round-trips identically. + ast := &model.VisualExpression{Conditions: []model.Condition{ + { + Attribute: "user.attributes.Programs", + Operator: "hasAnyOf", + Value: []any{"Alpha", "Bravo"}, + ValueType: model.LiteralValue, + AttributeType: "multiselect", + }, + }} + assert.True(t, isVisualASTRepresentable( + `("Alpha" in user.attributes.Programs || "Bravo" in user.attributes.Programs)`, + ast, + )) + }) + + t.Run("unbalanced quote in original is NOT representable", func(t *testing.T) { + ast := &model.VisualExpression{Conditions: []model.Condition{ + {Attribute: "user.attributes.team", Operator: "==", Value: "X", ValueType: model.LiteralValue}, + }} + assert.False(t, isVisualASTRepresentable(`user.attributes.team == "unterminated`, ast)) + }) +} + +func TestExtractStringValues(t *testing.T) { + t.Run("slice of strings", func(t *testing.T) { + result := extractStringValues([]any{"Alpha", "Bravo", "Charlie"}) + assert.Equal(t, []string{"Alpha", "Bravo", "Charlie"}, result) + }) + + t.Run("single string", func(t *testing.T) { + result := extractStringValues("Alpha") + assert.Equal(t, []string{"Alpha"}, result) + }) + + t.Run("nil", func(t *testing.T) { + result := extractStringValues(nil) + assert.Nil(t, result) + }) + + t.Run("mixed types in slice", func(t *testing.T) { + result := extractStringValues([]any{"Alpha", 42, "Bravo"}) + assert.Equal(t, []string{"Alpha", "Bravo"}, result) + }) + + t.Run("non-string non-slice", func(t *testing.T) { + result := extractStringValues(42) + assert.Nil(t, result) + }) +} + +func TestCelStringLiteral(t *testing.T) { + assert.Equal(t, `"hello"`, celStringLiteral("hello")) + assert.Equal(t, `"hello \"world\""`, celStringLiteral(`hello "world"`)) + assert.Equal(t, `"path\\to\\file"`, celStringLiteral(`path\to\file`)) + assert.Equal(t, `""`, celStringLiteral("")) + + // Control characters must be escaped or the emitted CEL literal won't parse. + assert.Equal(t, `"line1\nline2"`, celStringLiteral("line1\nline2")) + assert.Equal(t, `"col1\tcol2"`, celStringLiteral("col1\tcol2")) + assert.Equal(t, `"carriage\rreturn"`, celStringLiteral("carriage\rreturn")) +} + +func TestCelValueLiteral(t *testing.T) { + assert.Equal(t, `"hello"`, celValueLiteral("hello")) + assert.Equal(t, "true", celValueLiteral(true)) + assert.Equal(t, "false", celValueLiteral(false)) + assert.Equal(t, "42", celValueLiteral(int(42))) + assert.Equal(t, "42", celValueLiteral(int64(42))) + assert.Equal(t, "3.14", celValueLiteral(float64(3.14))) + assert.Equal(t, "null", celValueLiteral(nil)) + + // Float precision must round-trip — %f would round 0.123456789 to 0.123457. + assert.Equal(t, "0.123456789", celValueLiteral(float64(0.123456789))) +} + +func TestContainsNonStringLiteral(t *testing.T) { + assert.False(t, containsNonStringLiteral(nil)) + assert.False(t, containsNonStringLiteral("Alpha")) + assert.False(t, containsNonStringLiteral([]any{"Alpha", "Bravo"})) + + assert.True(t, containsNonStringLiteral(float64(1))) + assert.True(t, containsNonStringLiteral(true)) + assert.True(t, containsNonStringLiteral(int64(7))) + assert.True(t, containsNonStringLiteral([]any{"Alpha", 1.0})) +} + +func TestConditionToCEL_NilValue(t *testing.T) { + // A condition whose Value was masked to nil (e.g. all options hidden) must be dropped + // rather than emitting `attr == null`, which is invalid CEL for string attributes. + nilValueOps := []string{"==", "!=", ">", ">=", "<", "<=", "contains", "startsWith", "endsWith", "unknownOp"} + for _, op := range nilValueOps { + cond := model.Condition{ + Attribute: "user.attributes.Clearance", + Operator: op, + Value: nil, + ValueType: model.LiteralValue, + } + assert.Equal(t, "", conditionToCEL(cond), "operator %q with nil value must produce empty string", op) + } +} + +func TestConditionToCEL_UnknownOperatorWithValue(t *testing.T) { + // An unknown operator with a non-nil value produces a best-effort CEL expression. + // buildCELFromConditions will include it as-is; if the operator is truly unknown + // the downstream CEL engine will reject the expression during validation. + // This documents the intended (pass-through) behaviour for forward-compatibility. + cond := model.Condition{ + Attribute: "user.attributes.Clearance", + Operator: "futureOp", + Value: "Secret", + ValueType: model.LiteralValue, + } + result := conditionToCEL(cond) + assert.Equal(t, `user.attributes.Clearance futureOp "Secret"`, result) +} + +func TestMergeConditionValues(t *testing.T) { + t.Run("no hidden values returns submitted as-is", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Program", Operator: "in", Value: []any{"Alpha"}} + result := mergeConditionValues(submitted, nil) + assert.Equal(t, []any{"Alpha"}, result.Value) + }) + + t.Run("appends hidden values without duplicates", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Program", Operator: "in", Value: []any{"Alpha"}} + result := mergeConditionValues(submitted, []string{"Bravo", "Charlie"}) + values, ok := result.Value.([]any) + require.True(t, ok) + assert.Len(t, values, 3) + assert.Contains(t, values, "Alpha") + assert.Contains(t, values, "Bravo") + assert.Contains(t, values, "Charlie") + }) + + t.Run("deduplicates overlapping values", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Program", Operator: "in", Value: []any{"Alpha", "Bravo"}} + result := mergeConditionValues(submitted, []string{"Bravo", "Charlie"}) + values, ok := result.Value.([]any) + require.True(t, ok) + assert.Len(t, values, 3) + }) + + t.Run("restores hidden values when submitted is nil (fully-masked placeholder)", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Program", Operator: "in", Value: nil} + result := mergeConditionValues(submitted, []string{"Bravo", "Charlie"}) + values, ok := result.Value.([]any) + require.True(t, ok) + assert.Len(t, values, 2) + }) + + t.Run("restores single hidden value when submitted is nil", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Location", Operator: "==", Value: nil} + result := mergeConditionValues(submitted, []string{"Building 7"}) + assert.Equal(t, "Building 7", result.Value) + }) +} + +func TestGetHiddenValues(t *testing.T) { + var a *App + + options := []any{ + map[string]any{"id": "id1", "name": "Alpha"}, + map[string]any{"id": "id2", "name": "Bravo"}, + } + makeField := func(accessMode string, fieldType model.PropertyFieldType) *model.PropertyField { + attrs := model.StringInterface{model.PropertyAttrsAccessMode: accessMode} + if fieldType == model.PropertyFieldTypeSelect || fieldType == model.PropertyFieldTypeMultiselect { + attrs[model.PropertyFieldAttributeOptions] = options + } + return &model.PropertyField{Type: fieldType, Attrs: attrs} + } + + t.Run("AttrValue condition: returns nil immediately", func(t *testing.T) { + stored := &model.Condition{Attribute: "user.attributes.Team", Value: "user.attributes.Dept", ValueType: model.AttrValue} + assert.Nil(t, a.getHiddenValues(nil, "caller", stored, "", nil)) + }) + + t.Run("field missing from prefetch map: returns nil (fail closed)", func(t *testing.T) { + stored := &model.Condition{Attribute: "user.attributes.Program", Value: []any{"Alpha", "Bravo"}, ValueType: model.LiteralValue} + assert.Nil(t, a.getHiddenValues(nil, "caller", stored, "", map[string]*model.PropertyField{})) + }) + + t.Run("source_only: all stored values treated as hidden", func(t *testing.T) { + stored := &model.Condition{Attribute: "user.attributes.Clearance", Value: []any{"Top Secret", "Secret"}, ValueType: model.LiteralValue} + fields := map[string]*model.PropertyField{"Clearance": makeField(model.PropertyAccessModeSourceOnly, model.PropertyFieldTypeSelect)} + result := a.getHiddenValues(nil, "caller", stored, "", fields) + assert.Equal(t, []string{"Top Secret", "Secret"}, result) + }) + + t.Run("shared_only select: values absent from options are hidden", func(t *testing.T) { + stored := &model.Condition{Attribute: "user.attributes.Program", Value: []any{"Alpha", "Charlie"}, ValueType: model.LiteralValue} + fields := map[string]*model.PropertyField{"Program": makeField(model.PropertyAccessModeSharedOnly, model.PropertyFieldTypeSelect)} + result := a.getHiddenValues(nil, "caller", stored, "", fields) + assert.Equal(t, []string{"Charlie"}, result) + }) + + t.Run("public field: no values hidden", func(t *testing.T) { + stored := &model.Condition{Attribute: "user.attributes.Dept", Value: []any{"Eng", "Sales"}, ValueType: model.LiteralValue} + fields := map[string]*model.PropertyField{"Dept": makeField(model.PropertyAccessModePublic, model.PropertyFieldTypeSelect)} + result := a.getHiddenValues(nil, "caller", stored, "", fields) + assert.Nil(t, result) + }) +} diff --git a/server/channels/app/access_control_test.go b/server/channels/app/access_control_test.go index 4f8c11bbe79..fa5e6ca210c 100644 --- a/server/channels/app/access_control_test.go +++ b/server/channels/app/access_control_test.go @@ -319,6 +319,156 @@ func TestDeleteAccessControlPolicy(t *testing.T) { mockChannelStore.AssertNotCalled(t, "InvalidateChannel", mock.Anything) mockChannelStore.AssertNotCalled(t, "Get", mock.Anything, mock.Anything) }) + + t.Run("Caller with masked values is blocked from deleting (403)", func(t *testing.T) { + // When AttributeValueMasking is on and the caller cannot see all values in the + // policy, the delete must be refused with the masked_values 403. This closes + // the gap where a delegated admin could remove a policy whose conditions they + // could not audit. Forcing an unknown-field reference in the rule makes + // GetMaskedVisualAST fail-closed (HasMaskedValues=true) without requiring a + // full CPA setup for the test. + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.AttributeBasedAccessControl = true + cfg.FeatureFlags.AttributeValueMasking = true + }).InitBasic(t) + + callerID := model.NewId() + th.Context = th.Context.WithSession(&model.Session{UserId: callerID, Id: model.NewId()}).(*request.Context) + + policyID := model.NewId() + sensitivePolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeChannel, + Version: model.AccessControlPolicyVersionV0_3, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: `user.attributes.f_unknown_field == "Secret"`}, + }, + } + + mockAccessControl := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockAccessControl + mockAccessControl.On("GetPolicy", th.Context, policyID).Return(sensitivePolicy, nil).Once() + // Force GetMaskedVisualAST → maskConditionValues → fail-closed (unknown field). + mockAccessControl.On("ExpressionToVisualAST", mock.Anything, mock.Anything).Return(&model.VisualExpression{ + Conditions: []model.Condition{ + {Attribute: "user.attributes.f_unknown_field", Operator: "==", Value: "Secret", ValueType: model.LiteralValue}, + }, + }, nil).Maybe() + + appErr := th.App.DeleteAccessControlPolicy(th.Context, policyID) + require.NotNil(t, appErr) + require.Equal(t, http.StatusForbidden, appErr.StatusCode) + require.Equal(t, "app.pap.delete_policy.masked_values", appErr.Id) + + mockAccessControl.AssertNotCalled(t, "DeletePolicy", mock.Anything, mock.Anything) + mockAccessControl.AssertExpectations(t) + }) + + t.Run("Masking flag off: delete proceeds for callers that would otherwise be blocked", func(t *testing.T) { + // Belt-and-braces: with AttributeValueMasking off, the masking guard must not + // fire — the policy deletes normally even if the caller wouldn't have seen all + // values. Guards against accidentally inverting the flag condition. + thMock := SetupWithStoreMock(t) + // Note: SetupWithStoreMock doesn't take a config callback. Feature flags + // default to false, which is exactly the state this test wants. + + thMock.Context = thMock.Context.WithSession(&model.Session{UserId: model.NewId(), Id: model.NewId()}).(*request.Context) + + channelID := model.NewId() + channelPolicy := &model.AccessControlPolicy{ + ID: channelID, + Type: model.AccessControlPolicyTypeChannel, + Version: model.AccessControlPolicyVersionV0_3, + } + + mockStore := thMock.App.Srv().Store().(*storemocks.Store) + mockChannelStore := storemocks.ChannelStore{} + mockStore.On("Channel").Return(&mockChannelStore) + mockChannelStore.On("InvalidateChannel", channelID).Once() + mockChannelStore.On("Get", channelID, true).Return(&model.Channel{Id: channelID, Type: model.ChannelTypePrivate}, nil).Once() + + mockAccessControl := &mocks.AccessControlServiceInterface{} + thMock.App.Srv().ch.AccessControl = mockAccessControl + mockAccessControl.On("GetPolicy", thMock.Context, channelID).Return(channelPolicy, nil).Once() + mockAccessControl.On("DeletePolicy", thMock.Context, channelID).Return(nil).Once() + + appErr := thMock.App.DeleteAccessControlPolicy(thMock.Context, channelID) + require.Nil(t, appErr) + mockAccessControl.AssertExpectations(t) + mockChannelStore.AssertExpectations(t) + }) +} + +// TestCheckSelfInclusion verifies the self-exclusion guard: non-admin callers must +// satisfy their own policy after saving, or the save is refused with 403 +// self_exclusion. Sysadmins are exempt at the call site +// (CreateOrUpdateAccessControlPolicy), not inside checkSelfInclusion itself — this +// test exercises the function directly. +func TestCheckSelfInclusion(t *testing.T) { + t.Run("caller who satisfies the policy passes", func(t *testing.T) { + th := Setup(t).InitBasic(t) + callerID := th.BasicUser.Id + + policy := &model.AccessControlPolicy{ + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: `user.attributes.team == "ops"`}, + }, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + // QueryUsersForExpression returns the caller → matches → no error. + mockACS.On("QueryUsersForExpression", mock.Anything, mock.Anything, mock.Anything). + Return([]*model.User{{Id: callerID}}, int64(1), nil).Once() + + appErr := th.App.checkSelfInclusion(th.Context, policy, callerID) + require.Nil(t, appErr) + mockACS.AssertExpectations(t) + }) + + t.Run("caller who does not satisfy the policy is rejected with 403", func(t *testing.T) { + th := Setup(t).InitBasic(t) + callerID := th.BasicUser.Id + + policy := &model.AccessControlPolicy{ + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: `user.attributes.team == "ops"`}, + }, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + // No users returned → caller does not satisfy → expect self_exclusion 403. + mockACS.On("QueryUsersForExpression", mock.Anything, mock.Anything, mock.Anything). + Return([]*model.User{}, int64(0), nil).Once() + + appErr := th.App.checkSelfInclusion(th.Context, policy, callerID) + require.NotNil(t, appErr) + require.Equal(t, http.StatusForbidden, appErr.StatusCode) + require.Equal(t, "app.pap.save_policy.self_exclusion", appErr.Id) + mockACS.AssertExpectations(t) + }) + + t.Run("trivial rules (empty / 'true') are skipped without querying", func(t *testing.T) { + th := Setup(t).InitBasic(t) + callerID := th.BasicUser.Id + + policy := &model.AccessControlPolicy{ + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: ""}, + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: "true"}, + }, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + // No query should fire for trivial expressions — if it does, the mock will fail + // the test by returning the default zero-value response. + + appErr := th.App.checkSelfInclusion(th.Context, policy, callerID) + require.Nil(t, appErr) + mockACS.AssertNotCalled(t, "QueryUsersForExpression", mock.Anything, mock.Anything, mock.Anything) + }) } func TestGetChannelsForPolicy(t *testing.T) { diff --git a/server/channels/app/access_control_validation_test.go b/server/channels/app/access_control_validation_test.go new file mode 100644 index 00000000000..05823ef7743 --- /dev/null +++ b/server/channels/app/access_control_validation_test.go @@ -0,0 +1,231 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "encoding/json" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInvalidValueError(t *testing.T) { + err := invalidValueError() + require.NotNil(t, err) + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + assert.Equal(t, 400, err.StatusCode) + assert.Equal(t, "Invalid value.", err.DetailedError) +} + +func TestMaskedTokenConstant(t *testing.T) { + // The masked-token sentinel must be the eight-dash string the frontend + // renders for hidden chips and the server emits when masking raw CEL + // on GET / search responses. + assert.Equal(t, "--------", maskedTokenValue) +} + +func TestGenericErrorConsistency(t *testing.T) { + // All rejection reasons must produce identical errors to prevent enumeration. + err1 := invalidValueError() + err2 := invalidValueError() + + assert.Equal(t, err1.Id, err2.Id) + assert.Equal(t, err1.StatusCode, err2.StatusCode) + assert.Equal(t, err1.DetailedError, err2.DetailedError) +} + +func TestValidateConditionValues(t *testing.T) { + rctx := request.TestContext(t) + + // nil App is safe for every branch except shared_only + text (which calls + // a.getCallerTextValues → a.SearchPropertyValues). Those paths are covered + // by the integration tests in access_control_test.go. + var a *App + + makeField := func(accessMode string, fieldType model.PropertyFieldType, options []any) *model.PropertyField { + attrs := model.StringInterface{model.PropertyAttrsAccessMode: accessMode} + if options != nil { + attrs[model.PropertyFieldAttributeOptions] = options + } + return &model.PropertyField{Name: "Team", Type: fieldType, Attrs: attrs} + } + + selectOptions := []any{ + map[string]any{"id": "id1", "name": "Alpha"}, + map[string]any{"id": "id2", "name": "Bravo"}, + } + + t.Run("AttrValue conditions are skipped (no literals to validate)", func(t *testing.T) { + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: "user.attributes.Department", + ValueType: model.AttrValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", nil) + assert.Nil(t, err) + }) + + t.Run("non-attribute references are skipped", func(t *testing.T) { + cond := &model.Condition{ + Attribute: "channel.id", + Operator: "==", + Value: "X", + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", nil) + assert.Nil(t, err) + }) + + t.Run("unknown field is rejected with the generic error", func(t *testing.T) { + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: "Alpha", + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", map[string]*model.PropertyField{}) + require.NotNil(t, err) + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + }) + + t.Run("public field allows any value", func(t *testing.T) { + field := makeField(model.PropertyAccessModePublic, model.PropertyFieldTypeSelect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: "anything", + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", fields) + assert.Nil(t, err) + }) + + t.Run("source_only field rejects any literal value", func(t *testing.T) { + field := makeField(model.PropertyAccessModeSourceOnly, model.PropertyFieldTypeSelect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: "Alpha", + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", fields) + require.NotNil(t, err) + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + }) + + t.Run("source_only field allows the masked-token sentinel (round-tripped from GET)", func(t *testing.T) { + field := makeField(model.PropertyAccessModeSourceOnly, model.PropertyFieldTypeSelect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: maskedTokenValue, + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", fields) + assert.Nil(t, err, "the sentinel is stripped/re-injected at merge; validation must let it through") + }) + + t.Run("shared_only select: held value passes, non-held rejected, token allowed", func(t *testing.T) { + field := makeField(model.PropertyAccessModeSharedOnly, model.PropertyFieldTypeSelect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + + // "Alpha" is in the visible-options set (caller holds it) + ok := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: "Alpha", + ValueType: model.LiteralValue, + } + require.Nil(t, a.validateConditionValues(rctx, ok, "groupID", fields)) + + // "Charlie" is not in the visible-options set → rejected + bad := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: "Charlie", + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, bad, "groupID", fields) + require.NotNil(t, err) + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + + // Masked-token sentinel passes through (handled by merge, not validation) + tokenCond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: maskedTokenValue, + ValueType: model.LiteralValue, + } + assert.Nil(t, a.validateConditionValues(rctx, tokenCond, "groupID", fields)) + }) + + t.Run("source_only field rejects non-string literals (numeric, bool)", func(t *testing.T) { + field := makeField(model.PropertyAccessModeSourceOnly, model.PropertyFieldTypeSelect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: float64(1), + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", fields) + require.NotNil(t, err, "numeric literal must not slip through extractStringValues silently") + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + }) + + t.Run("shared_only field rejects non-string literals", func(t *testing.T) { + field := makeField(model.PropertyAccessModeSharedOnly, model.PropertyFieldTypeSelect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "==", + Value: true, + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond, "groupID", fields) + require.NotNil(t, err) + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + }) + + t.Run("shared_only multiselect: array values are validated element by element", func(t *testing.T) { + field := makeField(model.PropertyAccessModeSharedOnly, model.PropertyFieldTypeMultiselect, selectOptions) + fields := map[string]*model.PropertyField{"Team": field} + + allHeld, _ := json.Marshal([]any{"Alpha", "Bravo"}) + cond := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "in", + Value: parseAny(t, allHeld), + ValueType: model.LiteralValue, + } + require.Nil(t, a.validateConditionValues(rctx, cond, "groupID", fields)) + + mixed, _ := json.Marshal([]any{"Alpha", "Charlie"}) + cond2 := &model.Condition{ + Attribute: "user.attributes.Team", + Operator: "in", + Value: parseAny(t, mixed), + ValueType: model.LiteralValue, + } + err := a.validateConditionValues(rctx, cond2, "groupID", fields) + require.NotNil(t, err, "any non-held element must trigger rejection") + assert.Equal(t, "app.pap.save_policy.invalid_value", err.Id) + }) +} + +// parseAny round-trips a JSON-encoded value back to the untyped interface{} +// shape that the visual-AST parser produces for condition.Value. +func parseAny(t *testing.T, raw []byte) any { + t.Helper() + var v any + require.NoError(t, json.Unmarshal(raw, &v)) + return v +} diff --git a/server/i18n/en.json b/server/i18n/en.json index b46a9eb05d3..782435e4a87 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -7492,6 +7492,10 @@ "id": "app.pap.delete_policy.app_error", "translation": "Unable to delete access control policy." }, + { + "id": "app.pap.delete_policy.masked_values", + "translation": "You cannot delete this policy because it contains attribute values you do not have permission to view." + }, { "id": "app.pap.expression_to_visual_ast.app_error", "translation": "Could not genereate visual AST from expression." @@ -7536,6 +7540,10 @@ "id": "app.pap.is_ready.app_error", "translation": "Access control service is not ready." }, + { + "id": "app.pap.merge_expression.app_error", + "translation": "Could not merge policy expression." + }, { "id": "app.pap.missing_attribute.app_error", "translation": "An attribute is missing from the expression." @@ -7548,14 +7556,34 @@ "id": "app.pap.query_expression.app_error", "translation": "Could not query for expression." }, + { + "id": "app.pap.save_policy.advanced_expression_blocked", + "translation": "This rule expression cannot be safely edited while restricted values are present." + }, { "id": "app.pap.save_policy.app_error", "translation": "Unable to save access control policy." }, + { + "id": "app.pap.save_policy.invalid_value", + "translation": "Invalid value." + }, + { + "id": "app.pap.save_policy.masked_condition_deleted", + "translation": "You cannot remove a condition that contains attribute values you do not have permission to view." + }, + { + "id": "app.pap.save_policy.masked_rule_deleted", + "translation": "You cannot remove a rule that contains attribute values you do not have permission to view." + }, { "id": "app.pap.save_policy.name_exists.app_error", "translation": "A policy with this name already exists. Please choose a different name." }, + { + "id": "app.pap.save_policy.self_exclusion", + "translation": "You do not satisfy one or more conditions in this policy." + }, { "id": "app.pap.search_access_control_policies.app_error", "translation": "Could not search access control policies." @@ -7568,6 +7596,10 @@ "id": "app.pap.update_access_control_policies_active.app_error", "translation": "Could not update active status of access control policies." }, + { + "id": "app.pap.validate_expression_values.app_error", + "translation": "Could not validate policy expression values." + }, { "id": "app.pdp.access_evaluation.app_error", "translation": "Failed evaluate access control policy." From 9bd77d3fc4d2af3f7f0259a205174e76a1a276e1 Mon Sep 17 00:00:00 2001 From: Julien Tant <785518+JulienTant@users.noreply.github.com> Date: Mon, 18 May 2026 09:01:00 -0700 Subject: [PATCH 25/80] MM-68702: Reject demoting bot accounts to guest (#36487) * MM-68702: Reject demoting bot accounts to guest Deny DemoteUserToGuest when the target is a bot so User Managers cannot degrade bot capabilities via guest conversion without bot administration permissions. Adds API error string and tests. Co-authored-by: Julien Tant * Fix TestDemoteUserToGuest bot subtest: enable bot creation in config Default test config disables bot accounts; enable ServiceSettings EnableBotAccountCreation for the subtest and restore afterward. Co-authored-by: Julien Tant --------- Co-authored-by: Cursor Agent Co-authored-by: Julien Tant --- server/channels/api4/user_test.go | 28 ++++++++++++++++++++++++++++ server/channels/app/user.go | 4 ++++ server/channels/app/user_test.go | 12 ++++++++++++ server/i18n/en.json | 4 ++++ 4 files changed, 48 insertions(+) diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index fddc86cb308..c0024292995 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -6961,6 +6961,34 @@ func TestDemoteUserToGuest(t *testing.T) { require.NoError(t, err) }) + t.Run("cannot demote bot account", func(t *testing.T) { + th.App.Srv().SetLicense(model.NewTestLicense("guest_accounts")) + + prevBotCreation := *th.App.Config().ServiceSettings.EnableBotAccountCreation + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableBotAccountCreation = true + }) + defer th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableBotAccountCreation = prevBotCreation + }) + + createdBot, resp, err := th.SystemAdminClient.CreateBot(context.Background(), &model.Bot{ + Username: "botdemote" + model.NewId(), + DisplayName: "Demote Test Bot", + Description: "test", + }) + require.NoError(t, err) + CheckCreatedStatus(t, resp) + defer func() { + appErr := th.App.PermanentDeleteBot(th.Context, createdBot.UserId) + require.Nil(t, appErr) + }() + + demoteResp, err := th.SystemAdminClient.DemoteUserToGuest(context.Background(), createdBot.UserId) + CheckBadRequestStatus(t, demoteResp) + CheckErrorID(t, err, "api.user.demote_user_to_guest.bot_not_allowed.app_error") + }) + th.TestForSystemAdminAndLocal(t, func(t *testing.T, c *model.Client4) { _, _, err := c.GetUser(context.Background(), user.Id, "") require.NoError(t, err) diff --git a/server/channels/app/user.go b/server/channels/app/user.go index 325e3bdd1a1..22060540f07 100644 --- a/server/channels/app/user.go +++ b/server/channels/app/user.go @@ -2742,6 +2742,10 @@ func (a *App) PromoteGuestToUser(rctx request.CTX, user *model.User, requestorId // DemoteUserToGuest Convert user's roles and all his membership's roles from // regular user roles to guest roles. func (a *App) DemoteUserToGuest(rctx request.CTX, user *model.User) *model.AppError { + if user.IsBot { + return model.NewAppError("DemoteUserToGuest", "api.user.demote_user_to_guest.bot_not_allowed.app_error", nil, "", http.StatusBadRequest) + } + demotedUser, nErr := a.ch.srv.userService.DemoteUserToGuest(user) a.InvalidateCacheForUser(user.Id) if nErr != nil { diff --git a/server/channels/app/user_test.go b/server/channels/app/user_test.go index 88b7b19461e..f82cd9915c6 100644 --- a/server/channels/app/user_test.go +++ b/server/channels/app/user_test.go @@ -2012,6 +2012,18 @@ func TestDemoteUserToGuest(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic(t) + t.Run("Must reject bot user", func(t *testing.T) { + bot := th.CreateBot(t) + user, err := th.App.GetUser(bot.UserId) + require.Nil(t, err) + require.True(t, user.IsBot) + + appErr := th.App.DemoteUserToGuest(th.Context, user) + require.NotNil(t, appErr) + assert.Equal(t, "api.user.demote_user_to_guest.bot_not_allowed.app_error", appErr.Id) + assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) + }) + t.Run("Must invalidate channel stats cache when demoting a user", func(t *testing.T) { user := th.CreateUser(t) require.Equal(t, "system_user", user.Roles) diff --git a/server/i18n/en.json b/server/i18n/en.json index 782435e4a87..01bb249870f 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -4654,6 +4654,10 @@ "id": "api.user.demote_user_to_guest.already_guest.app_error", "translation": "Unable to convert the user to guest because is already a guest." }, + { + "id": "api.user.demote_user_to_guest.bot_not_allowed.app_error", + "translation": "Bot accounts cannot be converted to guest accounts." + }, { "id": "api.user.email_to_ldap.not_available.app_error", "translation": "AD/LDAP not available on this server." From 548183d748ada9af25eb8d7a4cced1064cec4532 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Date: Mon, 18 May 2026 22:23:58 +0200 Subject: [PATCH 26/80] Mm 68282 admin ephemeral mode (#36194) * adds feature flag to enable mattermost ephemeral mode * add ephemeral mode config settings to system console When feature flag is set to true a new section for Mobile Ephemeral Mode settings shows under the Mobile Security section in case a valid Enterprise Advanced License is active. * adds Mobile Ephemeral Mode settings playwright tests * improve descriptions for settings * improves error messages and hints * move validation to common helper and add new tests * reverts package-lock.json changes * proper struct alignment * proper message sorting in json file * use generic doc url for MEM section while docs are not ready * Proper formatting for playwright tests * fixes test --- .../lib/src/server/default_config.ts | 7 + .../playwright/lib/src/ui/components/index.ts | 10 +- .../system_console/base_components.ts | 34 +++ .../sections/environment/mobile_security.ts | 53 +++- .../system_console/mobile_security.spec.ts | 228 ++++++++++++++++++ server/config/client.go | 14 ++ server/config/client_test.go | 134 ++++++++++ server/i18n/en.json | 12 + server/public/model/config.go | 61 +++++ server/public/model/config_test.go | 115 +++++++++ server/public/model/feature_flags.go | 5 + .../admin_console/admin_definition.tsx | 70 ++++++ .../admin_definition_helpers.test.tsx | 21 +- .../admin_definition_helpers.tsx | 1 + ..._definition_mobile_ephemeral_mode.test.tsx | 129 ++++++++++ webapp/channels/src/i18n/en.json | 19 ++ .../src/utils/admin_console_index.test.tsx | 2 +- webapp/platform/types/src/config.ts | 8 + 18 files changed, 919 insertions(+), 4 deletions(-) create mode 100644 webapp/channels/src/components/admin_console/admin_definition_mobile_ephemeral_mode.test.tsx diff --git a/e2e-tests/playwright/lib/src/server/default_config.ts b/e2e-tests/playwright/lib/src/server/default_config.ts index 3c8abca6f6a..b4448dd89b9 100644 --- a/e2e-tests/playwright/lib/src/server/default_config.ts +++ b/e2e-tests/playwright/lib/src/server/default_config.ts @@ -790,6 +790,7 @@ const defaultServerConfig: AdminConfig = { IntegratedBoards: false, CJKSearch: false, ManagedChannelCategories: false, + MobileEphemeralMode: true, }, ImportSettings: { Directory: './import', @@ -868,4 +869,10 @@ const defaultServerConfig: AdminConfig = { LLMServiceID: '', }, }, + MobileEphemeralModeSettings: { + Enable: false, + DisconnectionTimeoutSeconds: 60, + OfflinePersistenceTimerHours: 24, + AutoCacheCleanupDays: 7, + }, }; diff --git a/e2e-tests/playwright/lib/src/ui/components/index.ts b/e2e-tests/playwright/lib/src/ui/components/index.ts index babea003a09..1d7e83a3488 100644 --- a/e2e-tests/playwright/lib/src/ui/components/index.ts +++ b/e2e-tests/playwright/lib/src/ui/components/index.ts @@ -54,7 +54,13 @@ import BurnOnReadTimerChip from './channels/burn_on_read_timer_chip'; import BurnOnReadConcealedPlaceholder from './channels/burn_on_read_concealed_placeholder'; import BurnOnReadConfirmationModal from './channels/burn_on_read_confirmation_modal'; // System Console Components -import {AdminSectionPanel, DropdownSetting, RadioSetting, TextInputSetting} from './system_console/base_components'; +import { + AdminSectionPanel, + DropdownSetting, + NumberInputSetting, + RadioSetting, + TextInputSetting, +} from './system_console/base_components'; import DelegatedGranularAdministration from './system_console/sections/user_management/delegated_granular_administration'; import UserDetail from './system_console/sections/user_management/user_detail'; import EditionAndLicense from './system_console/sections/about/edition_and_license'; @@ -132,6 +138,7 @@ const components = { EditionAndLicense, MobileSecurity, Notifications, + NumberInputSetting, RadioSetting, UsersAndTeams, SystemConsoleFeatureDiscovery, @@ -210,6 +217,7 @@ export { EditionAndLicense, MobileSecurity, Notifications, + NumberInputSetting, RadioSetting, UsersAndTeams, SystemConsoleFeatureDiscovery, diff --git a/e2e-tests/playwright/lib/src/ui/components/system_console/base_components.ts b/e2e-tests/playwright/lib/src/ui/components/system_console/base_components.ts index 2cd2c04cc3b..48ed352a8b3 100644 --- a/e2e-tests/playwright/lib/src/ui/components/system_console/base_components.ts +++ b/e2e-tests/playwright/lib/src/ui/components/system_console/base_components.ts @@ -94,6 +94,40 @@ export class TextInputSetting { } } +/** + * Number Input Setting - represents a number input field + * Uses getByRole('spinbutton') since has ARIA role spinbutton + */ +export class NumberInputSetting { + readonly container: Locator; + readonly label: Locator; + readonly input: Locator; + readonly helpText: Locator; + + constructor(container: Locator, labelText: string) { + this.container = container; + this.label = container.getByText(labelText); + this.input = container.getByRole('spinbutton'); + this.helpText = container.locator('.help-text'); + } + + async fill(value: string) { + await this.input.fill(value); + } + + async getValue(): Promise { + return (await this.input.inputValue()) ?? ''; + } + + async clear() { + await this.input.clear(); + } + + async toBeVisible() { + await expect(this.container).toBeVisible(); + } +} + /** * Dropdown Setting - represents a select dropdown */ diff --git a/e2e-tests/playwright/lib/src/ui/components/system_console/sections/environment/mobile_security.ts b/e2e-tests/playwright/lib/src/ui/components/system_console/sections/environment/mobile_security.ts index f171dc11cc8..0eb97a64ec2 100644 --- a/e2e-tests/playwright/lib/src/ui/components/system_console/sections/environment/mobile_security.ts +++ b/e2e-tests/playwright/lib/src/ui/components/system_console/sections/environment/mobile_security.ts @@ -3,7 +3,13 @@ import {Locator, expect} from '@playwright/test'; -import {RadioSetting, TextInputSetting, DropdownSetting, AdminSectionPanel} from '../../base_components'; +import { + RadioSetting, + TextInputSetting, + NumberInputSetting, + DropdownSetting, + AdminSectionPanel, +} from '../../base_components'; /** * System Console -> Environment -> Mobile Security @@ -17,6 +23,7 @@ export default class MobileSecurity { // Panels readonly generalMobileSecurity: GeneralMobileSecurityPanel; readonly microsoftIntune: MicrosoftIntunePanel; + readonly mobileEphemeralMode: MobileEphemeralModePanel; // Save section readonly saveButton: Locator; @@ -33,6 +40,9 @@ export default class MobileSecurity { this.microsoftIntune = new MicrosoftIntunePanel( container.locator('.AdminSectionPanel').filter({hasText: 'Microsoft Intune'}), ); + this.mobileEphemeralMode = new MobileEphemeralModePanel( + container.locator('.AdminSectionPanel').filter({hasText: 'Mobile Ephemeral Mode'}), + ); this.saveButton = container.getByRole('button', {name: 'Save'}); this.errorMessage = container.locator('.error-message'); @@ -77,6 +87,20 @@ export default class MobileSecurity { get clientId() { return this.microsoftIntune.clientId; } + + // Convenience shortcuts for Mobile Ephemeral Mode settings + get enableMobileEphemeralMode() { + return this.mobileEphemeralMode.enableMobileEphemeralMode; + } + get disconnectionTimeout() { + return this.mobileEphemeralMode.disconnectionTimeout; + } + get offlinePersistenceTimer() { + return this.mobileEphemeralMode.offlinePersistenceTimer; + } + get autoCacheCleanup() { + return this.mobileEphemeralMode.autoCacheCleanup; + } } class GeneralMobileSecurityPanel extends AdminSectionPanel { @@ -105,6 +129,33 @@ class GeneralMobileSecurityPanel extends AdminSectionPanel { } } +class MobileEphemeralModePanel extends AdminSectionPanel { + readonly enableMobileEphemeralMode: RadioSetting; + readonly disconnectionTimeout: NumberInputSetting; + readonly offlinePersistenceTimer: NumberInputSetting; + readonly autoCacheCleanup: NumberInputSetting; + + constructor(container: Locator) { + super(container, 'Mobile Ephemeral Mode'); + + this.enableMobileEphemeralMode = new RadioSetting( + this.body.getByRole('group', {name: /Enable Mobile Ephemeral Mode/}), + ); + this.disconnectionTimeout = new NumberInputSetting( + this.body.locator('.form-group').filter({hasText: 'Disconnection Timeout (seconds):'}), + 'Disconnection Timeout (seconds):', + ); + this.offlinePersistenceTimer = new NumberInputSetting( + this.body.locator('.form-group').filter({hasText: 'Offline Persistence Timer (hours):'}), + 'Offline Persistence Timer (hours):', + ); + this.autoCacheCleanup = new NumberInputSetting( + this.body.locator('.form-group').filter({hasText: 'Auto Cache Cleanup (days):'}), + 'Auto Cache Cleanup (days):', + ); + } +} + class MicrosoftIntunePanel extends AdminSectionPanel { readonly enableIntuneMAM: RadioSetting; readonly authProvider: DropdownSetting; diff --git a/e2e-tests/playwright/specs/functional/system_console/mobile_security.spec.ts b/e2e-tests/playwright/specs/functional/system_console/mobile_security.spec.ts index 44053d653bb..e4f64c21a36 100644 --- a/e2e-tests/playwright/specs/functional/system_console/mobile_security.spec.ts +++ b/e2e-tests/playwright/specs/functional/system_console/mobile_security.spec.ts @@ -507,3 +507,231 @@ test('should disable Intune inputs when toggle is off', async ({pw}) => { expect(await systemConsolePage.mobileSecurity.tenantId.input.isDisabled()).toBe(false); expect(await systemConsolePage.mobileSecurity.clientId.input.isDisabled()).toBe(false); }); + +/** + * @objective Verify timer settings are disabled when Mobile Ephemeral Mode is not enabled, and become editable when enabled + */ +test( + 'should disable Mobile Ephemeral Mode sub-settings when toggle is off and enable them when toggle is on', + {tag: '@mobile_ephemeral_mode'}, + async ({pw}) => { + const {adminUser, adminClient} = await pw.initSetup(); + + const license = await adminClient.getClientLicenseOld(); + + test.skip( + license.SkuShortName !== 'advanced', + 'Skipping test - server does not have enterprise advanced license', + ); + + const config = await adminClient.getConfig(); + test.skip( + config.FeatureFlags.MobileEphemeralMode !== true && config.FeatureFlags.MobileEphemeralMode !== 'true', + 'Skipping test - MobileEphemeralMode feature flag is not enabled on the server', + ); + + if (!adminUser) { + throw new Error('Failed to create admin user'); + } + + // # Log in as admin + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + + // # Visit system console + await systemConsolePage.goto(); + await systemConsolePage.toBeVisible(); + + // # Go to Mobile Security section + await systemConsolePage.sidebar.mobileSecurity.click(); + await systemConsolePage.mobileSecurity.toBeVisible(); + + // * Verify Mobile Ephemeral Mode toggle is off by default + await systemConsolePage.mobileSecurity.enableMobileEphemeralMode.toBeFalse(); + + // * Verify all sub-settings are disabled + expect(await systemConsolePage.mobileSecurity.disconnectionTimeout.input.isDisabled()).toBe(true); + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.input.isDisabled()).toBe(true); + expect(await systemConsolePage.mobileSecurity.autoCacheCleanup.input.isDisabled()).toBe(true); + + // # Enable Mobile Ephemeral Mode toggle + await systemConsolePage.mobileSecurity.enableMobileEphemeralMode.selectTrue(); + + // * Verify all sub-settings are now enabled + expect(await systemConsolePage.mobileSecurity.disconnectionTimeout.input.isDisabled()).toBe(false); + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.input.isDisabled()).toBe(false); + expect(await systemConsolePage.mobileSecurity.autoCacheCleanup.input.isDisabled()).toBe(false); + }, +); + +/** + * @objective Verify all Mobile Ephemeral Mode settings persist after save and navigation + */ +test( + 'should save and persist all Mobile Ephemeral Mode settings after navigation', + {tag: '@mobile_ephemeral_mode'}, + async ({pw}) => { + const {adminUser, adminClient} = await pw.initSetup(); + + const license = await adminClient.getClientLicenseOld(); + + test.skip( + license.SkuShortName !== 'advanced', + 'Skipping test - server does not have enterprise advanced license', + ); + + const config = await adminClient.getConfig(); + test.skip( + config.FeatureFlags.MobileEphemeralMode !== true && config.FeatureFlags.MobileEphemeralMode !== 'true', + 'Skipping test - MobileEphemeralMode feature flag is not enabled on the server', + ); + + if (!adminUser) { + throw new Error('Failed to create admin user'); + } + + // # Enable Mobile Ephemeral Mode setting via config API + config.MobileEphemeralModeSettings.Enable = true; + await adminClient.updateConfig(config); + + // # Log in as admin + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + + // # Visit system console + await systemConsolePage.goto(); + await systemConsolePage.toBeVisible(); + + // # Go to Mobile Security section + await systemConsolePage.sidebar.mobileSecurity.click(); + await systemConsolePage.mobileSecurity.toBeVisible(); + + // # Set custom values + await systemConsolePage.mobileSecurity.disconnectionTimeout.fill('120'); + await systemConsolePage.mobileSecurity.offlinePersistenceTimer.fill('48'); + await systemConsolePage.mobileSecurity.autoCacheCleanup.fill('14'); + + // # Save settings + await systemConsolePage.mobileSecurity.save(); + await pw.waitUntil(async () => (await systemConsolePage.mobileSecurity.saveButton.textContent()) === 'Save'); + + // # Navigate away and back + await systemConsolePage.sidebar.users.click(); + await systemConsolePage.users.toBeVisible(); + await systemConsolePage.sidebar.mobileSecurity.click(); + await systemConsolePage.mobileSecurity.toBeVisible(); + + // * Verify Mobile Ephemeral Mode is still enabled + await systemConsolePage.mobileSecurity.enableMobileEphemeralMode.toBeTrue(); + + // * Verify all values persisted correctly + expect(await systemConsolePage.mobileSecurity.disconnectionTimeout.getValue()).toBe('120'); + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.getValue()).toBe('48'); + expect(await systemConsolePage.mobileSecurity.autoCacheCleanup.getValue()).toBe('14'); + }, +); + +/** + * @objective Verify offline persistence timer is disabled when auto cache cleanup is set to 0 (zero-persistence mode) + */ +test( + 'should disable offline persistence timer when auto cache cleanup is set to zero', + {tag: '@mobile_ephemeral_mode'}, + async ({pw}) => { + const {adminUser, adminClient} = await pw.initSetup(); + + const license = await adminClient.getClientLicenseOld(); + + test.skip( + license.SkuShortName !== 'advanced', + 'Skipping test - server does not have enterprise advanced license', + ); + + const config = await adminClient.getConfig(); + test.skip( + config.FeatureFlags.MobileEphemeralMode !== true && config.FeatureFlags.MobileEphemeralMode !== 'true', + 'Skipping test - MobileEphemeralMode feature flag is not enabled on the server', + ); + + if (!adminUser) { + throw new Error('Failed to create admin user'); + } + + // # Enable Mobile Ephemeral Mode setting via config API + config.MobileEphemeralModeSettings.Enable = true; + await adminClient.updateConfig(config); + + // # Log in as admin + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + + // # Visit system console + await systemConsolePage.goto(); + await systemConsolePage.toBeVisible(); + + // # Go to Mobile Security section + await systemConsolePage.sidebar.mobileSecurity.click(); + await systemConsolePage.mobileSecurity.toBeVisible(); + + // * Verify offline persistence timer is enabled + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.input.isDisabled()).toBe(false); + + // # Set auto cache cleanup to 0 + await systemConsolePage.mobileSecurity.autoCacheCleanup.clear(); + await systemConsolePage.mobileSecurity.autoCacheCleanup.fill('0'); + + // * Verify offline persistence timer is now disabled + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.input.isDisabled()).toBe(true); + + // # Set auto cache cleanup back to 7 + await systemConsolePage.mobileSecurity.autoCacheCleanup.clear(); + await systemConsolePage.mobileSecurity.autoCacheCleanup.fill('7'); + + // * Verify offline persistence timer is enabled again + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.input.isDisabled()).toBe(false); + }, +); + +/** + * @objective Verify Mobile Ephemeral Mode settings show correct defaults on first enable + */ +test( + 'should show correct default values when Mobile Ephemeral Mode is first enabled', + {tag: '@mobile_ephemeral_mode'}, + async ({pw}) => { + const {adminUser, adminClient} = await pw.initSetup(); + + const license = await adminClient.getClientLicenseOld(); + + test.skip( + license.SkuShortName !== 'advanced', + 'Skipping test - server does not have enterprise advanced license', + ); + + const config = await adminClient.getConfig(); + test.skip( + config.FeatureFlags.MobileEphemeralMode !== true && config.FeatureFlags.MobileEphemeralMode !== 'true', + 'Skipping test - MobileEphemeralMode feature flag is not enabled on the server', + ); + + if (!adminUser) { + throw new Error('Failed to create admin user'); + } + + // # Log in as admin + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + + // # Visit system console + await systemConsolePage.goto(); + await systemConsolePage.toBeVisible(); + + // # Go to Mobile Security section + await systemConsolePage.sidebar.mobileSecurity.click(); + await systemConsolePage.mobileSecurity.toBeVisible(); + + // # Enable Mobile Ephemeral Mode + await systemConsolePage.mobileSecurity.enableMobileEphemeralMode.selectTrue(); + + // * Verify default values + expect(await systemConsolePage.mobileSecurity.disconnectionTimeout.getValue()).toBe('60'); + expect(await systemConsolePage.mobileSecurity.offlinePersistenceTimer.getValue()).toBe('24'); + expect(await systemConsolePage.mobileSecurity.autoCacheCleanup.getValue()).toBe('7'); + }, +); diff --git a/server/config/client.go b/server/config/client.go index d56c9d7e70e..5b255f1af97 100644 --- a/server/config/client.go +++ b/server/config/client.go @@ -255,6 +255,20 @@ func GenerateClientConfig(c *model.Config, telemetryID string, license *model.Li props["AutoTranslationLanguages"] = "" } props["RestrictDMAndGMAutotranslation"] = strconv.FormatBool(*c.AutoTranslationSettings.RestrictDMAndGM) + + if c.FeatureFlags.MobileEphemeralMode { + ephemeralEnabled := c.MobileEphemeralModeSettings.Enable != nil && *c.MobileEphemeralModeSettings.Enable + props["MobileEphemeralModeEnabled"] = strconv.FormatBool(ephemeralEnabled) + if c.MobileEphemeralModeSettings.DisconnectionTimeoutSeconds != nil { + props["MobileEphemeralModeDisconnectionTimeoutSeconds"] = strconv.Itoa(*c.MobileEphemeralModeSettings.DisconnectionTimeoutSeconds) + } + if c.MobileEphemeralModeSettings.OfflinePersistenceTimerHours != nil { + props["MobileEphemeralModeOfflinePersistenceTimerHours"] = strconv.Itoa(*c.MobileEphemeralModeSettings.OfflinePersistenceTimerHours) + } + if c.MobileEphemeralModeSettings.AutoCacheCleanupDays != nil { + props["MobileEphemeralModeAutoCacheCleanupDays"] = strconv.Itoa(*c.MobileEphemeralModeSettings.AutoCacheCleanupDays) + } + } } } diff --git a/server/config/client_test.go b/server/config/client_test.go index 159d80a8a03..53a9511fd86 100644 --- a/server/config/client_test.go +++ b/server/config/client_test.go @@ -20,6 +20,7 @@ func TestGetClientConfig(t *testing.T) { telemetryID string license *model.License expectedFields map[string]string + absentFields []string }{ { "unlicensed", @@ -48,6 +49,7 @@ func TestGetClientConfig(t *testing.T) { "WebsocketPort": "80", "WebsocketSecurePort": "443", }, + nil, }, { "licensed, but not for theme management", @@ -71,6 +73,7 @@ func TestGetClientConfig(t *testing.T) { "EmailNotificationContentsType": "full", "AllowCustomThemes": "true", }, + nil, }, { "licensed for theme management", @@ -93,6 +96,7 @@ func TestGetClientConfig(t *testing.T) { "EmailNotificationContentsType": "full", "AllowCustomThemes": "false", }, + nil, }, { "licensed for enforcement", @@ -110,6 +114,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnforceMultifactorAuthentication": "true", }, + nil, }, { "default marketplace", @@ -123,6 +128,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "IsDefaultMarketplace": "true", }, + nil, }, { "non-default marketplace", @@ -136,6 +142,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "IsDefaultMarketplace": "false", }, + nil, }, { "enable ShowFullName prop", @@ -149,6 +156,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "ShowFullName": "true", }, + nil, }, { "enable UseAnonymousURLs prop", @@ -162,6 +170,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "UseAnonymousURLs": "true", }, + nil, }, { "Custom groups professional license", @@ -174,6 +183,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnableCustomGroups": "true", }, + nil, }, { "Custom groups enterprise license", @@ -186,6 +196,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnableCustomGroups": "true", }, + nil, }, { "Custom groups other license", @@ -198,6 +209,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnableCustomGroups": "false", }, + nil, }, { "Shared channels other license", @@ -216,6 +228,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "ExperimentalSharedChannels": "false", }, + nil, }, { "licensed for shared channels", @@ -234,6 +247,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "ExperimentalSharedChannels": "true", }, + nil, }, { "Shared channels professional license", @@ -252,6 +266,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "ExperimentalSharedChannels": "true", }, + nil, }, { "disable EnableUserStatuses", @@ -265,6 +280,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnableUserStatuses": "false", }, + nil, }, { "Shared channels enterprise license", @@ -283,6 +299,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "ExperimentalSharedChannels": "true", }, + nil, }, { "Disable App Bar", @@ -296,6 +313,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "DisableAppBar": "true", }, + nil, }, { "default EnableJoinLeaveMessage", @@ -305,6 +323,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnableJoinLeaveMessageByDefault": "true", }, + nil, }, { "disable EnableJoinLeaveMessage", @@ -318,6 +337,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "EnableJoinLeaveMessageByDefault": "false", }, + nil, }, { "test key for GiphySdkKey", @@ -331,6 +351,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "GiphySdkKey": model.ServiceSettingsDefaultGiphySdkKeyTest, }, + nil, }, { "report a problem values", @@ -350,6 +371,7 @@ func TestGetClientConfig(t *testing.T) { "ReportAProblemMail": "mail", "AllowDownloadLogs": "true", }, + nil, }, { "access control settings enabled", @@ -365,6 +387,7 @@ func TestGetClientConfig(t *testing.T) { "EnableAttributeBasedAccessControl": "true", "EnableUserManagedAttributes": "true", }, + nil, }, { "access control settings disabled", @@ -380,6 +403,7 @@ func TestGetClientConfig(t *testing.T) { "EnableAttributeBasedAccessControl": "false", "EnableUserManagedAttributes": "false", }, + nil, }, { "access control settings default", @@ -390,6 +414,7 @@ func TestGetClientConfig(t *testing.T) { "EnableAttributeBasedAccessControl": "false", "EnableUserManagedAttributes": "false", }, + nil, }, { "burn on read enabled", @@ -405,6 +430,7 @@ func TestGetClientConfig(t *testing.T) { "EnableBurnOnRead": "true", "BurnOnReadDurationSeconds": "1800", }, + nil, }, { "burn on read disabled", @@ -420,6 +446,7 @@ func TestGetClientConfig(t *testing.T) { "EnableBurnOnRead": "false", "BurnOnReadDurationSeconds": "600", }, + nil, }, { "burn on read default", @@ -430,6 +457,7 @@ func TestGetClientConfig(t *testing.T) { "EnableBurnOnRead": "true", "BurnOnReadDurationSeconds": "600", // 10 minutes in seconds }, + nil, }, { "mobile watermark uses experimental settings", @@ -446,6 +474,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "ExperimentalEnableWatermark": "true", }, + nil, }, { "Intune MAM enabled with Enterprise Advanced license and Office365 AuthService", @@ -466,6 +495,7 @@ func TestGetClientConfig(t *testing.T) { "IntuneMAMEnabled": "true", "IntuneScope": "api://87654321-4321-4321-4321-210987654321/login.mattermost", }, + nil, }, { "Intune MAM disabled when not enabled", @@ -485,6 +515,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "IntuneMAMEnabled": "false", }, + nil, }, { "Intune MAM disabled when TenantId is missing", @@ -504,6 +535,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "IntuneMAMEnabled": "false", }, + nil, }, { "Intune MAM disabled when ClientId is missing", @@ -523,6 +555,7 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "IntuneMAMEnabled": "false", }, + nil, }, { "Intune MAM not exposed with lower license tier", @@ -540,6 +573,7 @@ func TestGetClientConfig(t *testing.T) { SkuShortName: model.LicenseShortSkuProfessional, }, map[string]string{}, + []string{"IntuneMAMEnabled", "IntuneScope"}, }, { "Intune MAM not exposed without license", @@ -554,6 +588,7 @@ func TestGetClientConfig(t *testing.T) { "", nil, map[string]string{}, + []string{"IntuneMAMEnabled", "IntuneScope"}, }, { "Intune MAM enabled with Enterprise Advanced license and SAML AuthService", @@ -578,6 +613,7 @@ func TestGetClientConfig(t *testing.T) { "IntuneScope": "api://87654321-4321-4321-4321-210987654321/login.mattermost", "IntuneAuthService": "saml", }, + nil, }, { "Intune MAM disabled when AuthService is missing", @@ -597,6 +633,100 @@ func TestGetClientConfig(t *testing.T) { map[string]string{ "IntuneMAMEnabled": "false", }, + nil, + }, + { + "Mobile Ephemeral Mode enabled with custom values", + &model.Config{ + FeatureFlags: &model.FeatureFlags{MobileEphemeralMode: true}, + MobileEphemeralModeSettings: model.MobileEphemeralModeSettings{ + Enable: model.NewPointer(true), + DisconnectionTimeoutSeconds: model.NewPointer(120), + OfflinePersistenceTimerHours: model.NewPointer(48), + AutoCacheCleanupDays: model.NewPointer(14), + }, + }, + "", + &model.License{ + Features: &model.Features{}, + SkuShortName: model.LicenseShortSkuEnterpriseAdvanced, + }, + map[string]string{ + "MobileEphemeralModeEnabled": "true", + "MobileEphemeralModeDisconnectionTimeoutSeconds": "120", + "MobileEphemeralModeOfflinePersistenceTimerHours": "48", + "MobileEphemeralModeAutoCacheCleanupDays": "14", + }, + nil, + }, + { + "Mobile Ephemeral Mode disabled still exposes parameters", + &model.Config{ + FeatureFlags: &model.FeatureFlags{MobileEphemeralMode: true}, + MobileEphemeralModeSettings: model.MobileEphemeralModeSettings{ + Enable: model.NewPointer(false), + DisconnectionTimeoutSeconds: model.NewPointer(60), + OfflinePersistenceTimerHours: model.NewPointer(24), + AutoCacheCleanupDays: model.NewPointer(7), + }, + }, + "", + &model.License{ + Features: &model.Features{}, + SkuShortName: model.LicenseShortSkuEnterpriseAdvanced, + }, + map[string]string{ + "MobileEphemeralModeEnabled": "false", + "MobileEphemeralModeDisconnectionTimeoutSeconds": "60", + "MobileEphemeralModeOfflinePersistenceTimerHours": "24", + "MobileEphemeralModeAutoCacheCleanupDays": "7", + }, + nil, + }, + { + "Mobile Ephemeral Mode not exposed when feature flag is off", + &model.Config{ + FeatureFlags: &model.FeatureFlags{MobileEphemeralMode: false}, + MobileEphemeralModeSettings: model.MobileEphemeralModeSettings{ + Enable: model.NewPointer(true), + }, + }, + "", + &model.License{ + Features: &model.Features{}, + SkuShortName: model.LicenseShortSkuEnterpriseAdvanced, + }, + map[string]string{}, + []string{"MobileEphemeralModeEnabled", "MobileEphemeralModeDisconnectionTimeoutSeconds", "MobileEphemeralModeOfflinePersistenceTimerHours", "MobileEphemeralModeAutoCacheCleanupDays"}, + }, + { + "Mobile Ephemeral Mode not exposed without license", + &model.Config{ + FeatureFlags: &model.FeatureFlags{MobileEphemeralMode: true}, + MobileEphemeralModeSettings: model.MobileEphemeralModeSettings{ + Enable: model.NewPointer(true), + }, + }, + "", + nil, + map[string]string{}, + []string{"MobileEphemeralModeEnabled", "MobileEphemeralModeDisconnectionTimeoutSeconds", "MobileEphemeralModeOfflinePersistenceTimerHours", "MobileEphemeralModeAutoCacheCleanupDays"}, + }, + { + "Mobile Ephemeral Mode not exposed with lower license tier", + &model.Config{ + FeatureFlags: &model.FeatureFlags{MobileEphemeralMode: true}, + MobileEphemeralModeSettings: model.MobileEphemeralModeSettings{ + Enable: model.NewPointer(true), + }, + }, + "", + &model.License{ + Features: &model.Features{}, + SkuShortName: model.LicenseShortSkuProfessional, + }, + map[string]string{}, + []string{"MobileEphemeralModeEnabled", "MobileEphemeralModeDisconnectionTimeoutSeconds", "MobileEphemeralModeOfflinePersistenceTimerHours", "MobileEphemeralModeAutoCacheCleanupDays"}, }, } @@ -616,6 +746,10 @@ func TestGetClientConfig(t *testing.T) { assert.Equal(t, expectedValue, actualValue) } } + for _, absentField := range testCase.absentFields { + _, ok := configMap[absentField] + assert.False(t, ok, fmt.Sprintf("config should not contain %v", absentField)) + } }) } } diff --git a/server/i18n/en.json b/server/i18n/en.json index 01bb249870f..1f58855805c 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -11542,6 +11542,18 @@ "id": "model.config.is_valid.minimum_desktop_app_version.app_error", "translation": "Invalid version number. Must be a valid semantic version (e.g. 5.0.0)." }, + { + "id": "model.config.is_valid.mobile_ephemeral_mode.auto_cache_cleanup.app_error", + "translation": "Invalid Auto Cache Cleanup value. Must be between {{.Min}} and {{.Max}} days." + }, + { + "id": "model.config.is_valid.mobile_ephemeral_mode.disconnection_timeout.app_error", + "translation": "Invalid Disconnection Timeout value. Must be between {{.Min}} and {{.Max}} seconds." + }, + { + "id": "model.config.is_valid.mobile_ephemeral_mode.offline_persistence.app_error", + "translation": "Invalid Offline Persistence Timer value. Must be between {{.Min}} and {{.Max}} hours." + }, { "id": "model.config.is_valid.move_thread.domain_invalid.app_error", "translation": "Invalid domain for move thread settings" diff --git a/server/public/model/config.go b/server/public/model/config.go index e04fe346529..f001937219e 100644 --- a/server/public/model/config.go +++ b/server/public/model/config.go @@ -3442,6 +3442,61 @@ func (s *DataRetentionSettings) GetFileRetentionHours() int { return DataRetentionSettingsDefaultFileRetentionDays * 24 } +const ( + MobileEphemeralModeDefaultDisconnectionTimeoutSeconds = 60 + MobileEphemeralModeDefaultOfflinePersistenceTimerHours = 24 + MobileEphemeralModeDefaultAutoCacheCleanupDays = 7 + + MobileEphemeralModeMaxDisconnectionTimeoutSeconds = 600 + MobileEphemeralModeMaxOfflinePersistenceTimerHours = 72 + MobileEphemeralModeMaxAutoCacheCleanupDays = 60 +) + +type MobileEphemeralModeSettings struct { + Enable *bool `access:"environment_mobile_security"` + DisconnectionTimeoutSeconds *int `access:"environment_mobile_security"` + OfflinePersistenceTimerHours *int `access:"environment_mobile_security"` + AutoCacheCleanupDays *int `access:"environment_mobile_security"` +} + +func (s *MobileEphemeralModeSettings) SetDefaults() { + if s.Enable == nil { + s.Enable = NewPointer(false) + } + if s.DisconnectionTimeoutSeconds == nil { + s.DisconnectionTimeoutSeconds = NewPointer(MobileEphemeralModeDefaultDisconnectionTimeoutSeconds) + } + if s.OfflinePersistenceTimerHours == nil { + s.OfflinePersistenceTimerHours = NewPointer(MobileEphemeralModeDefaultOfflinePersistenceTimerHours) + } + if s.AutoCacheCleanupDays == nil { + s.AutoCacheCleanupDays = NewPointer(MobileEphemeralModeDefaultAutoCacheCleanupDays) + } +} + +func (s *MobileEphemeralModeSettings) isValid() *AppError { + if s.Enable == nil || !*s.Enable { + return nil + } + + if s.DisconnectionTimeoutSeconds == nil || *s.DisconnectionTimeoutSeconds < 0 || *s.DisconnectionTimeoutSeconds > MobileEphemeralModeMaxDisconnectionTimeoutSeconds { + return NewAppError("Config.IsValid", "model.config.is_valid.mobile_ephemeral_mode.disconnection_timeout.app_error", + map[string]any{"Min": 0, "Max": MobileEphemeralModeMaxDisconnectionTimeoutSeconds}, "", http.StatusBadRequest) + } + + if s.OfflinePersistenceTimerHours == nil || *s.OfflinePersistenceTimerHours < 0 || *s.OfflinePersistenceTimerHours > MobileEphemeralModeMaxOfflinePersistenceTimerHours { + return NewAppError("Config.IsValid", "model.config.is_valid.mobile_ephemeral_mode.offline_persistence.app_error", + map[string]any{"Min": 0, "Max": MobileEphemeralModeMaxOfflinePersistenceTimerHours}, "", http.StatusBadRequest) + } + + if s.AutoCacheCleanupDays == nil || *s.AutoCacheCleanupDays < 0 || *s.AutoCacheCleanupDays > MobileEphemeralModeMaxAutoCacheCleanupDays { + return NewAppError("Config.IsValid", "model.config.is_valid.mobile_ephemeral_mode.auto_cache_cleanup.app_error", + map[string]any{"Min": 0, "Max": MobileEphemeralModeMaxAutoCacheCleanupDays}, "", http.StatusBadRequest) + } + + return nil +} + type JobSettings struct { RunJobs *bool `access:"write_restrictable,cloud_restrictable"` // telemetry: none RunScheduler *bool `access:"write_restrictable,cloud_restrictable"` // telemetry: none @@ -4079,6 +4134,7 @@ type Config struct { AnalyticsSettings AnalyticsSettings ElasticsearchSettings ElasticsearchSettings DataRetentionSettings DataRetentionSettings + MobileEphemeralModeSettings MobileEphemeralModeSettings MessageExportSettings MessageExportSettings JobSettings JobSettings PluginSettings PluginSettings @@ -4194,6 +4250,7 @@ func (o *Config) SetDefaults() { o.NativeAppSettings.SetDefaults() o.IntuneSettings.SetDefaults() o.DataRetentionSettings.SetDefaults() + o.MobileEphemeralModeSettings.SetDefaults() o.RateLimitSettings.SetDefaults() o.LogSettings.SetDefaults() o.ExperimentalAuditSettings.SetDefaults() @@ -4373,6 +4430,10 @@ func (o *Config) IsValid() *AppError { return appErr } + if appErr := o.MobileEphemeralModeSettings.isValid(); appErr != nil { + return appErr + } + if appErr := o.GuestAccountsSettings.IsValid(); appErr != nil { return appErr } diff --git a/server/public/model/config_test.go b/server/public/model/config_test.go index 41cf045f181..96859087064 100644 --- a/server/public/model/config_test.go +++ b/server/public/model/config_test.go @@ -2975,6 +2975,121 @@ func TestConfigAccessTagsMapToValidPermissions(t *testing.T) { checkStruct(t, reflect.TypeFor[Config](), "Config") } +func TestMobileEphemeralModeSettingsDefaults(t *testing.T) { + c := Config{} + c.SetDefaults() + + require.False(t, *c.MobileEphemeralModeSettings.Enable) + require.Equal(t, MobileEphemeralModeDefaultDisconnectionTimeoutSeconds, *c.MobileEphemeralModeSettings.DisconnectionTimeoutSeconds) + require.Equal(t, MobileEphemeralModeDefaultOfflinePersistenceTimerHours, *c.MobileEphemeralModeSettings.OfflinePersistenceTimerHours) + require.Equal(t, MobileEphemeralModeDefaultAutoCacheCleanupDays, *c.MobileEphemeralModeSettings.AutoCacheCleanupDays) +} + +func TestMobileEphemeralModeSettingsIsValid(t *testing.T) { + testCases := []struct { + name string + settings MobileEphemeralModeSettings + expectError bool + errorId string + }{ + { + name: "disabled settings should be valid", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(false), + }, + expectError: false, + }, + { + name: "enabled with valid values", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(120), + OfflinePersistenceTimerHours: NewPointer(24), + AutoCacheCleanupDays: NewPointer(7), + }, + expectError: false, + }, + { + name: "invalid disconnection timeout above max", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(MobileEphemeralModeMaxDisconnectionTimeoutSeconds + 1), + OfflinePersistenceTimerHours: NewPointer(0), + AutoCacheCleanupDays: NewPointer(0), + }, + expectError: true, + errorId: "model.config.is_valid.mobile_ephemeral_mode.disconnection_timeout.app_error", + }, + { + name: "invalid offline persistence above max", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(60), + OfflinePersistenceTimerHours: NewPointer(MobileEphemeralModeMaxOfflinePersistenceTimerHours + 1), + AutoCacheCleanupDays: NewPointer(0), + }, + expectError: true, + errorId: "model.config.is_valid.mobile_ephemeral_mode.offline_persistence.app_error", + }, + { + name: "invalid auto cache cleanup above max", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(60), + OfflinePersistenceTimerHours: NewPointer(0), + AutoCacheCleanupDays: NewPointer(MobileEphemeralModeMaxAutoCacheCleanupDays + 1), + }, + expectError: true, + errorId: "model.config.is_valid.mobile_ephemeral_mode.auto_cache_cleanup.app_error", + }, + { + name: "invalid negative disconnection timeout", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(-1), + OfflinePersistenceTimerHours: NewPointer(0), + AutoCacheCleanupDays: NewPointer(0), + }, + expectError: true, + errorId: "model.config.is_valid.mobile_ephemeral_mode.disconnection_timeout.app_error", + }, + { + name: "invalid negative offline persistence", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(60), + OfflinePersistenceTimerHours: NewPointer(-1), + AutoCacheCleanupDays: NewPointer(0), + }, + expectError: true, + errorId: "model.config.is_valid.mobile_ephemeral_mode.offline_persistence.app_error", + }, + { + name: "invalid negative auto cache cleanup", + settings: MobileEphemeralModeSettings{ + Enable: NewPointer(true), + DisconnectionTimeoutSeconds: NewPointer(60), + OfflinePersistenceTimerHours: NewPointer(0), + AutoCacheCleanupDays: NewPointer(-1), + }, + expectError: true, + errorId: "model.config.is_valid.mobile_ephemeral_mode.auto_cache_cleanup.app_error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.settings.isValid() + if tc.expectError { + require.NotNil(t, err) + require.Equal(t, tc.errorId, err.Id) + } else { + require.Nil(t, err) + } + }) + } +} + func TestNativeAppSettingsIsValid(t *testing.T) { t.Run("defaults are valid", func(t *testing.T) { cfg := Config{} diff --git a/server/public/model/feature_flags.go b/server/public/model/feature_flags.go index 68465f32025..3b0e698d168 100644 --- a/server/public/model/feature_flags.go +++ b/server/public/model/feature_flags.go @@ -125,6 +125,9 @@ type FeatureFlags struct { // Gates the per-channel Discoverable toggle and the channel-join-request flow that lets // non-members find a private channel in Browse Channels and request to join it. DiscoverableChannels bool + + // Enable Mobile Ephemeral Mode for controlling data persistence on mobile devices + MobileEphemeralMode bool } func (f *FeatureFlags) SetDefaults() { @@ -183,6 +186,8 @@ func (f *FeatureFlags) SetDefaults() { f.ManagedChannelCategories = false f.DiscoverableChannels = false + + f.MobileEphemeralMode = false } // ToMap returns the feature flags as a map[string]string diff --git a/webapp/channels/src/components/admin_console/admin_definition.tsx b/webapp/channels/src/components/admin_console/admin_definition.tsx index a62ea3cd26c..bcf95f93107 100644 --- a/webapp/channels/src/components/admin_console/admin_definition.tsx +++ b/webapp/channels/src/components/admin_console/admin_definition.tsx @@ -2373,6 +2373,76 @@ const AdminDefinition: AdminDefinitionType = { }, ], }, + { + key: 'MobileSecuritySettings.EphemeralMode', + title: 'Mobile Ephemeral Mode', + description: defineMessage({id: 'admin.mobileSecurity.sections.ephemeralMode.description', defaultMessage: 'Configure data persistence and cache management policies for mobile devices.'}), + license_sku: LicenseSkus.EnterpriseAdvanced, + component: LicensedSectionContainer, + componentProps: { + requiredSku: LicenseSkus.EnterpriseAdvanced, + featureDiscoveryConfig: { + featureName: 'mobile_ephemeral_mode', + title: defineMessage({id: 'admin.mobileSecurity.ephemeralMode_feature_discovery.title', defaultMessage: 'Control mobile data persistence with Mobile Ephemeral Mode'}), + description: defineMessage({id: 'admin.mobileSecurity.ephemeralMode_feature_discovery.description', defaultMessage: 'With Mattermost Enterprise Advanced, you can enable Mobile Ephemeral Mode to enforce data persistence policies on mobile devices. Configure disconnection timeouts, offline data retention, and automatic cache cleanup.'}), + learnMoreURL: 'https://docs.mattermost.com', + }, + }, + isHidden: it.configIsFalse('FeatureFlags', 'MobileEphemeralMode'), + settings: [ + { + type: 'banner', + label: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.banner', defaultMessage: 'Changes to these settings are delivered to connected devices in real time. Offline devices will continue operating under their last-known settings until they re-establish a server connection. Timer state persists across app and device restarts.'}), + banner_type: 'info', + }, + { + type: 'bool', + key: 'MobileEphemeralModeSettings.Enable', + label: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.enableTitle', defaultMessage: 'Enable Mobile Ephemeral Mode:'}), + help_text: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.enableDescription', defaultMessage: 'When enabled, mobile clients will follow the server-configured ephemeral data policies. Disconnected devices will clean up cached data based on the configured timers.'}), + }, + { + type: 'number', + key: 'MobileEphemeralModeSettings.DisconnectionTimeoutSeconds', + label: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.disconnectionTimeoutTitle', defaultMessage: 'Disconnection Timeout (seconds):'}), + help_text: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.disconnectionTimeoutDescription', defaultMessage: 'Grace period after losing server connection before the device is considered offline. Helps avoid false triggers from brief network interruptions. Values below 5 are not recommended.'}), + placeholder: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.disconnectionTimeout.placeholder', defaultMessage: 'E.g.: 60'}), + isDisabled: it.stateIsFalse('MobileEphemeralModeSettings.Enable'), + validate: validators.numberInRange(0, 600, defineMessage({ + id: 'admin.mobileSecurity.ephemeralMode.disconnectionTimeout.range', + defaultMessage: 'Must be a number between 0 and 600 seconds (10 minutes).', + })), + }, + { + type: 'number', + key: 'MobileEphemeralModeSettings.OfflinePersistenceTimerHours', + label: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.offlinePersistenceTitle', defaultMessage: 'Offline Persistence Timer (hours):'}), + help_text: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.offlinePersistenceDescription', defaultMessage: 'How long cached content is kept after the device goes offline. When the timer expires, cached content is deleted but session credentials are preserved. Set to 0 for immediate cleanup.'}), + disabled_help_text: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.offlinePersistence.disabled', defaultMessage: 'How long cached content is kept after the device goes offline. When the timer expires, cached content is deleted but session credentials are preserved. Set to 0 for immediate cleanup. Requires Mobile Ephemeral Mode to be enabled and Auto Cache Cleanup to be greater than 0.'}), + placeholder: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.offlinePersistence.placeholder', defaultMessage: 'E.g.: 24'}), + isDisabled: it.any( + it.stateIsFalse('MobileEphemeralModeSettings.Enable'), + it.stateEquals('MobileEphemeralModeSettings.AutoCacheCleanupDays', 0), + ), + validate: validators.numberInRange(0, 72, defineMessage({ + id: 'admin.mobileSecurity.ephemeralMode.offlinePersistence.range', + defaultMessage: 'Must be a number between 0 and 72 hours (3 days).', + })), + }, + { + type: 'number', + key: 'MobileEphemeralModeSettings.AutoCacheCleanupDays', + label: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.autoCacheCleanupTitle', defaultMessage: 'Auto Cache Cleanup (days):'}), + help_text: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.autoCacheCleanupDescription', defaultMessage: 'Controls the maximum age of any content cached on the device, regardless of connection status. Prevents unbounded accumulation of sensitive data. Set to 0 for zero-persistence mode where content is never persisted to disk.'}), + placeholder: defineMessage({id: 'admin.mobileSecurity.ephemeralMode.autoCacheCleanup.placeholder', defaultMessage: 'E.g.: 7'}), + isDisabled: it.stateIsFalse('MobileEphemeralModeSettings.Enable'), + validate: validators.numberInRange(0, 60, defineMessage({ + id: 'admin.mobileSecurity.ephemeralMode.autoCacheCleanup.range', + defaultMessage: 'Must be a number between 0 and 60 days.', + })), + }, + ], + }, ], }, }, diff --git a/webapp/channels/src/components/admin_console/admin_definition_helpers.test.tsx b/webapp/channels/src/components/admin_console/admin_definition_helpers.test.tsx index 963b2279b14..91a51483d0d 100644 --- a/webapp/channels/src/components/admin_console/admin_definition_helpers.test.tsx +++ b/webapp/channels/src/components/admin_console/admin_definition_helpers.test.tsx @@ -1,7 +1,7 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import {it} from './admin_definition_helpers'; +import {it, validators} from './admin_definition_helpers'; describe('AdminDefinitionHelpers - stateEqualsOrDefault', () => { test('should return true when state value equals expected value', () => { @@ -45,3 +45,22 @@ describe('AdminDefinitionHelpers - stateEqualsOrDefault', () => { expect(checker({}, undefinedStateWithDifferentExpected)).toBe(false); }); }); + +describe('AdminDefinitionHelpers - validators.numberInRange', () => { + const validate = validators.numberInRange(0, 60, 'out of range'); + + test('should return valid for in-range numbers', () => { + expect(validate(0).isValid()).toBe(true); + expect(validate(30).isValid()).toBe(true); + expect(validate(60).isValid()).toBe(true); + }); + + test('should return invalid for out-of-range numbers', () => { + expect(validate(-1).isValid()).toBe(false); + expect(validate(61).isValid()).toBe(false); + }); + + test('should return valid for NaN since the server backfills empty inputs with defaults', () => { + expect(validate(NaN).isValid()).toBe(true); + }); +}); diff --git a/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx b/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx index e4c15053bd5..ca914b01e5d 100644 --- a/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx +++ b/webapp/channels/src/components/admin_console/admin_definition_helpers.tsx @@ -74,6 +74,7 @@ export const validators = { isRequired: (text: MessageDescriptor | string) => (value: string) => new ValidationResult(Boolean(value), text), minValue: (min: number, text: MessageDescriptor | string) => (value: number) => new ValidationResult((value >= min), text), maxValue: (max: number, text: MessageDescriptor | string) => (value: number) => new ValidationResult((value <= max), text), + numberInRange: (min: number, max: number, text: MessageDescriptor | string) => (value: number) => new ValidationResult(Number.isNaN(value) || (value >= min && value <= max), text), }; export const usesLegacyOauth = (config: Partial, state: any, license?: ClientLicense, enterpriseReady?: boolean, consoleAccess?: ConsoleAccess, cloud?: CloudState) => { diff --git a/webapp/channels/src/components/admin_console/admin_definition_mobile_ephemeral_mode.test.tsx b/webapp/channels/src/components/admin_console/admin_definition_mobile_ephemeral_mode.test.tsx new file mode 100644 index 00000000000..13691dfff88 --- /dev/null +++ b/webapp/channels/src/components/admin_console/admin_definition_mobile_ephemeral_mode.test.tsx @@ -0,0 +1,129 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import type {AdminConfig} from '@mattermost/types/config'; + +import {LicenseSkus} from 'utils/constants'; + +import AdminDefinition from './admin_definition'; +import type {AdminDefinitionSetting, AdminDefinitionConfigSchemaSection} from './types'; + +describe('AdminDefinition - Mobile Ephemeral Mode Settings', () => { + const getEphemeralModeSections = () => { + const mobileSecuritySection = AdminDefinition.environment.subsections.mobile_security; + const sections = 'sections' in mobileSecuritySection.schema! ? mobileSecuritySection.schema.sections : undefined; + return sections; + }; + + const getEphemeralModeSection = () => { + const sections = getEphemeralModeSections(); + return sections?.find((section: AdminDefinitionConfigSchemaSection) => section.key === 'MobileSecuritySettings.EphemeralMode'); + }; + + const getEphemeralModeSettings = () => { + const section = getEphemeralModeSection(); + return section?.settings || []; + }; + + test('should include Mobile Ephemeral Mode section in mobile_security', () => { + const section = getEphemeralModeSection(); + expect(section).toBeDefined(); + }); + + test('should include Enable setting', () => { + const settings = getEphemeralModeSettings(); + const enableSetting = settings.find((s: AdminDefinitionSetting) => s.key === 'MobileEphemeralModeSettings.Enable'); + + expect(enableSetting).toBeDefined(); + expect(enableSetting?.type).toBe('bool'); + expect(enableSetting?.label).toBeDefined(); + expect(enableSetting?.help_text).toBeDefined(); + }); + + test('should include info banner', () => { + const settings = getEphemeralModeSettings(); + const bannerSetting = settings.find((s: AdminDefinitionSetting) => s.type === 'banner'); + + expect(bannerSetting).toBeDefined(); + }); + + test('settings should have proper translation message descriptors', () => { + const settings = getEphemeralModeSettings(); + const settingsWithLabels = settings.filter((s: AdminDefinitionSetting) => s.key?.includes('MobileEphemeralMode')); + + settingsWithLabels.forEach((setting: AdminDefinitionSetting) => { + if (setting.label && typeof setting.label === 'object') { + expect('id' in setting.label).toBe(true); + expect('defaultMessage' in setting.label).toBe(true); + } + + if (setting.help_text && typeof setting.help_text === 'object' && !('$$typeof' in setting.help_text)) { + expect('id' in setting.help_text).toBe(true); + expect('defaultMessage' in setting.help_text).toBe(true); + } + }); + }); + + test('should use LicensedSectionContainer with Enterprise Advanced', () => { + const section = getEphemeralModeSection(); + + expect(section?.component).toBeDefined(); + expect(section?.license_sku).toBe(LicenseSkus.EnterpriseAdvanced); + expect(section?.componentProps).toBeDefined(); + expect(section?.componentProps?.requiredSku).toBe(LicenseSkus.EnterpriseAdvanced); + expect(section?.componentProps?.featureDiscoveryConfig).toBeDefined(); + expect(section?.componentProps?.featureDiscoveryConfig?.featureName).toBe('mobile_ephemeral_mode'); + }); + + test('isHidden should return true when feature flag is disabled', () => { + const section = getEphemeralModeSection(); + expect(section?.isHidden).toBeDefined(); + expect(typeof section?.isHidden).toBe('function'); + + const mockConfig: Partial = {FeatureFlags: {MobileEphemeralMode: false}}; + const isHiddenFn = section!.isHidden as (config: Partial) => boolean; + expect(isHiddenFn(mockConfig)).toBe(true); + }); + + test('isHidden should return false when feature flag is enabled', () => { + const section = getEphemeralModeSection(); + + const mockConfig: Partial = {FeatureFlags: {MobileEphemeralMode: true}}; + const isHiddenFn = section!.isHidden as (config: Partial) => boolean; + expect(isHiddenFn(mockConfig)).toBe(false); + }); + + test('should include DisconnectionTimeoutSeconds number setting', () => { + const settings = getEphemeralModeSettings(); + const setting = settings.find((s: AdminDefinitionSetting) => s.key === 'MobileEphemeralModeSettings.DisconnectionTimeoutSeconds'); + + expect(setting).toBeDefined(); + expect(setting?.type).toBe('number'); + expect(setting?.isDisabled).toBeDefined(); + }); + + test('should include OfflinePersistenceTimerHours number setting', () => { + const settings = getEphemeralModeSettings(); + const setting = settings.find((s: AdminDefinitionSetting) => s.key === 'MobileEphemeralModeSettings.OfflinePersistenceTimerHours'); + + expect(setting).toBeDefined(); + expect(setting?.type).toBe('number'); + expect(setting?.isDisabled).toBeDefined(); + }); + + test('should include AutoCacheCleanupDays number setting', () => { + const settings = getEphemeralModeSettings(); + const setting = settings.find((s: AdminDefinitionSetting) => s.key === 'MobileEphemeralModeSettings.AutoCacheCleanupDays'); + + expect(setting).toBeDefined(); + expect(setting?.type).toBe('number'); + expect(setting?.isDisabled).toBeDefined(); + }); + + test('OfflinePersistenceTimerHours should have disabled_help_text for zero-persistence mode', () => { + const settings = getEphemeralModeSettings(); + const setting = settings.find((s: AdminDefinitionSetting) => s.key === 'MobileEphemeralModeSettings.OfflinePersistenceTimerHours'); + + expect(setting?.disabled_help_text).toBeDefined(); + }); +}); diff --git a/webapp/channels/src/i18n/en.json b/webapp/channels/src/i18n/en.json index 65f3402ac17..cc13c1b910c 100644 --- a/webapp/channels/src/i18n/en.json +++ b/webapp/channels/src/i18n/en.json @@ -1870,11 +1870,30 @@ "admin.mobileSecurity.allowPdfLinkNavigationTitle": "Allow Link Navigation in Secure PDFs:", "admin.mobileSecurity.biometricsDescription": "Enforces biometric authentication (with PIN/passcode fallback) before accessing the app. Users will be prompted based on session activity and server switching rules.", "admin.mobileSecurity.biometricsTitle": "Enable Biometric Authentication:", + "admin.mobileSecurity.ephemeralMode_feature_discovery.description": "With Mattermost Enterprise Advanced, you can enable Mobile Ephemeral Mode to enforce data persistence policies on mobile devices. Configure disconnection timeouts, offline data retention, and automatic cache cleanup.", + "admin.mobileSecurity.ephemeralMode_feature_discovery.title": "Control mobile data persistence with Mobile Ephemeral Mode", + "admin.mobileSecurity.ephemeralMode.autoCacheCleanup.placeholder": "E.g.: 7", + "admin.mobileSecurity.ephemeralMode.autoCacheCleanup.range": "Must be a number between 0 and 60 days.", + "admin.mobileSecurity.ephemeralMode.autoCacheCleanupDescription": "Controls the maximum age of any content cached on the device, regardless of connection status. Prevents unbounded accumulation of sensitive data. Set to 0 for zero-persistence mode where content is never persisted to disk.", + "admin.mobileSecurity.ephemeralMode.autoCacheCleanupTitle": "Auto Cache Cleanup (days):", + "admin.mobileSecurity.ephemeralMode.banner": "Changes to these settings are delivered to connected devices in real time. Offline devices will continue operating under their last-known settings until they re-establish a server connection. Timer state persists across app and device restarts.", + "admin.mobileSecurity.ephemeralMode.disconnectionTimeout.placeholder": "E.g.: 60", + "admin.mobileSecurity.ephemeralMode.disconnectionTimeout.range": "Must be a number between 0 and 600 seconds (10 minutes).", + "admin.mobileSecurity.ephemeralMode.disconnectionTimeoutDescription": "Grace period after losing server connection before the device is considered offline. Helps avoid false triggers from brief network interruptions. Values below 5 are not recommended.", + "admin.mobileSecurity.ephemeralMode.disconnectionTimeoutTitle": "Disconnection Timeout (seconds):", + "admin.mobileSecurity.ephemeralMode.enableDescription": "When enabled, mobile clients will follow the server-configured ephemeral data policies. Disconnected devices will clean up cached data based on the configured timers.", + "admin.mobileSecurity.ephemeralMode.enableTitle": "Enable Mobile Ephemeral Mode:", + "admin.mobileSecurity.ephemeralMode.offlinePersistence.disabled": "How long cached content is kept after the device goes offline. When the timer expires, cached content is deleted but session credentials are preserved. Set to 0 for immediate cleanup. Requires Mobile Ephemeral Mode to be enabled and Auto Cache Cleanup to be greater than 0.", + "admin.mobileSecurity.ephemeralMode.offlinePersistence.placeholder": "E.g.: 24", + "admin.mobileSecurity.ephemeralMode.offlinePersistence.range": "Must be a number between 0 and 72 hours (3 days).", + "admin.mobileSecurity.ephemeralMode.offlinePersistenceDescription": "How long cached content is kept after the device goes offline. When the timer expires, cached content is deleted but session credentials are preserved. Set to 0 for immediate cleanup.", + "admin.mobileSecurity.ephemeralMode.offlinePersistenceTitle": "Offline Persistence Timer (hours):", "admin.mobileSecurity.jailbreakDescription": "Prevents access to the app on devices detected as jailbroken or rooted. If a device fails the security check, users will be denied access or prompted to switch to a compliant server.", "admin.mobileSecurity.jailbreakTitle": "Enable Jailbreak/Root Protection:", "admin.mobileSecurity.mobileAllowDownloads": "Site Configuration > File Sharing and Downloads > Allow File Downloads on Mobile", "admin.mobileSecurity.screenCaptureDescription": "Blocks screenshots and screen recordings when using the mobile app. Screenshots will appear blank, and screen recordings will blur (iOS) or show a black screen (Android). Also applies when switching apps.", "admin.mobileSecurity.screenCaptureTitle": "Prevent Screen Capture:", + "admin.mobileSecurity.sections.ephemeralMode.description": "Configure data persistence and cache management policies for mobile devices.", "admin.mobileSecurity.sections.general.description": "Configure device security features for the mobile app.", "admin.mobileSecurity.sections.intune.description": "Configure Microsoft Intune Mobile Application Management (MAM) for App Protection Policies.", "admin.mobileSecurity.secureFilePreviewDescription": "Prevents file downloads, previews, and sharing for most file types, even if {mobileAllowDownloads} is enabled. Allows in-app previews for PDFs, videos, and images only. Files are stored temporarily in the app's cache and cannot be exported or shared.", diff --git a/webapp/channels/src/utils/admin_console_index.test.tsx b/webapp/channels/src/utils/admin_console_index.test.tsx index 2da36d664e8..3c83b111f77 100644 --- a/webapp/channels/src/utils/admin_console_index.test.tsx +++ b/webapp/channels/src/utils/admin_console_index.test.tsx @@ -29,8 +29,8 @@ describe('AdminConsoleIndex.generateIndex', () => { expect(idx.search('saml')).toEqual([ 'authentication/saml', 'environment/session_lengths', - 'authentication/email', 'environment/mobile_security', + 'authentication/email', 'experimental/features', ]); expect(idx.search('nginx')).toEqual([ diff --git a/webapp/platform/types/src/config.ts b/webapp/platform/types/src/config.ts index b77216eb20a..c5e6fc8899e 100644 --- a/webapp/platform/types/src/config.ts +++ b/webapp/platform/types/src/config.ts @@ -837,6 +837,13 @@ export type IntuneSettings = { AuthService?: string; }; +export type MobileEphemeralModeSettings = { + Enable: boolean; + DisconnectionTimeoutSeconds: number; + OfflinePersistenceTimerHours: number; + AutoCacheCleanupDays: number; +}; + export type ClusterSettings = { Enable: boolean; ClusterName: string; @@ -1105,6 +1112,7 @@ export type AdminConfig = { AccessControlSettings: AccessControlSettings; ContentFlaggingSettings: ContentFlaggingSettings; AutoTranslationSettings: AutoTranslationSettings; + MobileEphemeralModeSettings: MobileEphemeralModeSettings; }; export type ReplicaLagSetting = { From 23b4d8275bb2d8d8649e67cbd07e8bc564aa58d0 Mon Sep 17 00:00:00 2001 From: Andre Vasconcelos Date: Tue, 19 May 2026 00:05:26 +0300 Subject: [PATCH 27/80] MM-68197 Show classification banners in web and desktop apps (#36490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Classification Markings admin console page Adds a new admin console page under Site Configuration for managing classification markings. This allows system administrators to define classification levels (e.g., UNCLASSIFIED, SECRET, TOP SECRET) with associated colors and rank ordering, which will be used for system-wide and per-channel classification banners. The page includes: - Enable/disable toggle backed by the property field system (field existence = enabled) - Country preset dropdown (US DoD, NATO, UK GSCP, Canada, Australia PSPF) that auto-fills standard classification levels - Editable classification levels table with drag-and-drop reorder, inline text editing, color picker, and delete - Auto-switch to "Custom" preset when levels are manually modified - Confirmation dialog when switching presets would overwrite custom data Also adds: - ClassificationMarkings feature flag (default off) - Generic property field client methods (get/create/patch/delete) for the /api/v4/properties/ endpoints - Enterprise license + feature flag gating on the admin page Co-Authored-By: Claude Opus 4.6 (1M context) * Fix classification markings: add validation, error handling, and system object type - Add "system" as a valid property field object type so the classification markings API calls succeed - Surface load errors instead of silently swallowing them (only suppress 404 for unconfigured state) - Validate before save: require at least one level, non-empty names, and no duplicates - Default to custom preset with empty levels on first open - Add section strings to searchableStrings for admin console search Co-Authored-By: Claude Opus 4.6 (1M context) * Move classification field to CPA group targeting users Store the classification markings property field in the custom_profile_attributes group with object_type 'user' instead of the attributes group with object_type 'system'. Clear target_id for PSAv2 system target compliance and mark the field as admin-managed. Co-Authored-By: Claude Opus 4.6 (1M context) * Stabilize preset option IDs and add danger warning on preset switch Hardcode deterministic IDs for all preset classification levels so switching away and back preserves option IDs, preventing orphaned property values. Compare only level data (not preset label) for change detection so cosmetic preset switches don't trigger false save states. Show a danger modal with red confirm button when changing presets on an existing field, warning about system-wide impact on classified resources. The warning appears once per session then allows frictionless switching. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove system object type from property fields Not needed yet — will be added when system/channel banners are implemented. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix ESLint errors in classification markings admin page Fix import ordering and remove unused generateId import. Co-Authored-By: Claude Opus 4.6 (1M context) * Address CodeRabbit review feedback for classification markings - Register property field API endpoints when ClassificationMarkings flag is enabled (not just IntegratedBoards) to prevent 404s - Preserve preset option IDs when creating a new classification field instead of blanking them with empty strings - Add sysconsole read/write permission constants for classification markings across server and webapp, and wire up resource-level permission checks in the admin definition Co-Authored-By: Claude Opus 4.6 (1M context) * Add rank attribute to classification marking options Co-Authored-By: Claude Opus 4.6 (1M context) * Add classification markings permissions migration and read-only support Add a permissions migration to grant classification markings sysconsole permissions to existing roles on upgrade. Wire up the disabled prop so read-only users can view but not edit classification settings. Register the permission in the Delegated Granular Administration UI. Co-Authored-By: Claude Opus 4.6 (1M context) * Paginate loadField to find classification field beyond first page Co-Authored-By: Claude Opus 4.6 (1M context) * Fix lint errors and warnings in classification markings Co-Authored-By: Claude Opus 4.6 (1M context) * Remove classification markings sysconsole permissions; gate on sysadmin instead Classification markings admin page no longer uses feature-specific read/write permissions. Visibility is gated on license + feature flag, editing is gated on system admin role. This avoids coupling feature-specific permissions to the generic property service. Co-Authored-By: Claude Opus 4.6 (1M context) * Set sysadmin-level permissions on classification markings field creation Co-Authored-By: Claude Opus 4.6 (1M context) * Use stable IDs instead of array indices for classification level operations Switch updateLevel/deleteLevel to identify levels by ID rather than index, sort levels by rank on load, and extract i18n strings. Co-Authored-By: Claude Opus 4.6 (1M context) * Refactor classification markings into extracted helper functions Co-Authored-By: Claude Opus 4.6 (1M context) * Add tests for classification markings admin console feature Add unit and component tests covering: - Pure function tests for detectPreset, optionsToLevels, levelsToOptions, processClassificationField, and fetchClassificationField pagination logic - React component tests for rendering states, validation, and user interactions - Client4 property field method tests for URL construction and HTTP verbs - Server routing test verifying routes register with ClassificationMarkings flag - Feature flag default and serialization test Export pure functions from classification_markings.tsx to enable direct testing. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix lint errors in classification markings tests Co-Authored-By: Claude Opus 4.6 (1M context) * Fix test compilation error * Fix color input auto-filling after 3 hex characters in classification markings Buffer ColorInput onChange in a LevelColorCell wrapper so the table doesn't re-render mid-typing, preventing the input from losing its focus-guarded local state. Co-Authored-By: Claude Opus 4.6 (1M context) * Fixing style issues with color picker z-index * Added fix to prevent immediate dismissal when clicking inside color picker * Adding E2E test suite for configuration * Removing duplicates * Fixing unrelated linter error * Fixing test linting issues * Updating tests to skip appropriately * Matching configuration to UX specs * Fixing style lint * Added informational banner for presentational nature of markings * Enabling the markings flag on playwright server * Added missing feature flag to e2e test environment in ci * Reverting changes to color_input - Not needed as we're using a custom component * Added and polished global banner configuration * Refactoring webapp for readability - Separating components - Adding unit tests - Isolating helper methods into utilities * Fixing linter errors * linter fix * Manually fixing linter issues * Separating global classification component * Added persistence of classification marking configuration * Changing LevelID with LevelName * Making changes for PR reviews * Changing property object of classification field to template * syncing i18n file * Removing inaccurate note from comments * PR fixes for UX review * Cleaning up unused value * Added GlobalClassificationBanner component - Made sure it syncs on change by using normal configuration values on it - Works with "top" and "top_and_bottom" - Renders on both root and admin_console * Adding E2E test cases for global classification * Linter fixes, i18n extract * PR Fixes * Linter fix * Matching default messages * Fixing type errors * Fixing pipeline and runtime errors * Fixing announcementbar rendering on top of global classifications * Increasing banner & font sizes * Fixing font size to 12px instead of 16px - I read it wrong * Replacing config values with property * Test linter fixes * Fixing type errors and go format error * Making changes needed to align with specs - Ensuring system_classification is a separate linked property that differs from the template - Saving the global classification banner values as a propertyvalue * Added missing arguments in e2e tests * Added missing conditions for useEffect - Also fixing E2E error in pipeline * Fixing issues with V1 and V2 group mismatch * Fixes for linter errors and coderabbit review * Addressing more issues found by coderabbit * Fixing issues found by coderabbit * Migrating to use system properties * Ran all linters and prettier - Resolving coding style drift that happened from not running prettier on the webapp (even though CI doesn't check for this) * Undoing the prettier changes in webapp * Cleaning up unwanted autoformatted changes * Reverting prettier changes to clean diff * Fixing E2E test * Import fixes in test * Applying changes for PR feedback * Fixing issues with failing e2e tests * Changing key of selection from name to id * Replacing field setup in E2E tests to use levelId instead of levelName * Added classification setup per channel on channel creation * WIP: Adding classification banner integrated with channel banners - Using a hook to resolve which values should be evaluated when displaying the banner * Fixing style of dropdown input for classifications * Fixing visual issues with dropdown inputs * Adding E2E Tests and linter fixes * General fixes and improvements * Applying linter fixes * Resolving lingering linter issues * Updated snapshot and extracted i18n * Adding test cleanup to prevent failures due to duplicates * Addressing nitpick comment for test mapping of values * Applying more fixes to E2E tests * Improving test coverage and e2e test cleanup * Resolving type issues * Refactoring classification constant names an documentation * Ensuring propertyvalue only stores single id, storing banner text in banner_info * Fixing issues with linter alongside style issues on header * Updating test assertion to account for fallback * Fixing issues found during testing - Removing custom selection from being an option and turned it into a state - Ensuring only system administrators can set channel classification levels * Fixing z-index issue with color input popover * Setting classification level to lowest available value when switching it on * Updating unit tests to match new spec for preselection --------- Co-authored-by: David Krauser Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: David Krauser Co-authored-by: Mattermost Build --- .../channel_classification.spec.ts | 381 ++++++++++++++ .../channel_classification/helpers.ts | 126 +++++ .../sidebar_icon_realtime_update.spec.ts | 30 +- .../classification_markings.spec.ts | 67 +++ .../classification_markings_helpers.ts | 10 + .../channels/src/actions/websocket_actions.ts | 28 +- .../__snapshots__/color_setting.test.tsx.snap | 12 + .../classification_markings.test.tsx | 469 ++++++++++++++++-- .../classification_markings.tsx | 39 +- .../classification_markings/utils/index.ts | 148 ++++-- .../utils/preset_dropdown_styles.ts | 2 +- .../channel_banner/channel_banner.tsx | 31 +- .../channel_settings_configuration_tab.scss | 41 +- ...hannel_settings_configuration_tab.test.tsx | 365 ++++++++++++++ .../channel_settings_configuration_tab.tsx | 255 +++++++++- .../channel_settings_modal.scss | 9 +- .../channels/src/components/color_input.tsx | 32 +- .../useChannelClassificationBanner.test.ts | 300 +++++++++++ .../hooks/useChannelClassificationBanner.ts | 113 +++++ .../hooks/useClassificationMarkings.test.ts | 298 +++++++++++ .../common/hooks/useClassificationMarkings.ts | 110 ++++ .../global_classification_banner.test.tsx | 40 +- .../global_classification_banner.tsx | 47 +- .../new_channel_modal/new_channel_modal.scss | 116 +++++ .../new_channel_modal/new_channel_modal.tsx | 163 +++++- webapp/channels/src/i18n/en.json | 8 + .../src/sass/components/_color-input.scss | 12 +- webapp/platform/client/src/client4.ts | 7 + 28 files changed, 3068 insertions(+), 191 deletions(-) create mode 100644 e2e-tests/playwright/specs/functional/channels/channel_classification/channel_classification.spec.ts create mode 100644 e2e-tests/playwright/specs/functional/channels/channel_classification/helpers.ts create mode 100644 webapp/channels/src/components/common/hooks/useChannelClassificationBanner.test.ts create mode 100644 webapp/channels/src/components/common/hooks/useChannelClassificationBanner.ts create mode 100644 webapp/channels/src/components/common/hooks/useClassificationMarkings.test.ts create mode 100644 webapp/channels/src/components/common/hooks/useClassificationMarkings.ts diff --git a/e2e-tests/playwright/specs/functional/channels/channel_classification/channel_classification.spec.ts b/e2e-tests/playwright/specs/functional/channels/channel_classification/channel_classification.spec.ts new file mode 100644 index 00000000000..9905e302950 --- /dev/null +++ b/e2e-tests/playwright/specs/functional/channels/channel_classification/channel_classification.spec.ts @@ -0,0 +1,381 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +/** + * Channel Classification E2E tests. + * Tests the classification level assignment feature on both new and existing channels. + * + * Prerequisites: Enterprise-tier license + ClassificationMarkings feature flag enabled. + */ + +import {expect, test, getAdminClient, licenseTier} from '@mattermost/playwright-lib'; +import type {PlaywrightExtended} from '@mattermost/playwright-lib'; + +import { + TEST_LEVELS, + setClassificationMarkingsFeatureFlag, + setupClassificationWithChannelField, + deleteClassificationFieldsIfExist, +} from './helpers'; +import type {ClassificationLevel} from './helpers'; + +let classificationLevels: ClassificationLevel[] = []; +let setupComplete = false; + +// Teams created by pw.initSetup() in each test are tracked here and deleted in +// afterEach so local environments don't accumulate stale teams across runs. +const createdTeamIds: string[] = []; + +async function initSetupTracked(pw: PlaywrightExtended) { + const setup = await pw.initSetup(); + createdTeamIds.push(setup.team.id); + return setup; +} + +test.beforeAll(async () => { + const {adminClient} = await getAdminClient(); + const license = await adminClient.getClientLicenseOld(); + if (licenseTier(license.SkuShortName) < 20) { + return; + } + + await setClassificationMarkingsFeatureFlag(adminClient, true); + const setup = await setupClassificationWithChannelField(adminClient); + classificationLevels = setup.levels; + setupComplete = true; +}); + +test.afterAll(async () => { + if (!setupComplete) { + return; + } + const {adminClient} = await getAdminClient(); + try { + await deleteClassificationFieldsIfExist(adminClient); + } catch { + // Best-effort cleanup + } +}); + +test.beforeEach(async () => { + const {adminClient} = await getAdminClient(); + const license = await adminClient.getClientLicenseOld(); + test.skip(licenseTier(license.SkuShortName) < 20, 'Channel classification requires Enterprise-tier license'); + test.skip(!setupComplete, 'Classification levels were not set up'); + + const config = await adminClient.getConfig(); + test.skip( + config.FeatureFlags.ClassificationMarkings !== true, + 'ClassificationMarkings feature flag could not be enabled', + ); +}); + +test.afterEach(async () => { + if (createdTeamIds.length === 0) { + return; + } + const ids = createdTeamIds.splice(0); + try { + const {adminClient} = await getAdminClient({skipLog: true}); + await Promise.allSettled(ids.map((id) => adminClient.deleteTeam(id))); + } catch { + // Best-effort cleanup + } +}); + +test.describe('Channel Classification - New channel creation', () => { + test('Enabling classification toggle without selecting values prevents channel creation', async ({pw}) => { + const {adminUser, team} = await initSetupTracked(pw); + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const newChannelModal = await channelsPage.openNewChannelModal(); + await newChannelModal.fillDisplayName(`test-${pw.random.id()}`); + await newChannelModal.publicTypeButton.click(); + + // Create button should be enabled before toggling classification + await expect(newChannelModal.createButton).toBeEnabled(); + + // Enable classification toggle + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + // Create button should be disabled (no classification level selected, no banner text) + await expect(newChannelModal.createButton).toBeDisabled(); + }); + + test('Classification dropdown displays the correct levels from the template', async ({pw}) => { + const {adminUser, team} = await initSetupTracked(pw); + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const newChannelModal = await channelsPage.openNewChannelModal(); + await newChannelModal.fillDisplayName(`test-${pw.random.id()}`); + + // Enable classification toggle + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + // Open the classification dropdown + const dropdownContainer = channelsPage.page.getByTestId('channelClassificationLevel'); + await dropdownContainer.click(); + + // Verify all test levels are present in the dropdown menu + const menu = channelsPage.page.locator('.DropDown__menu'); + await expect(menu).toBeVisible(); + for (const level of TEST_LEVELS) { + await expect(menu.getByText(level.name, {exact: true})).toBeVisible(); + } + }); + + test('User can append text to the Banner Text field after selecting a classification', async ({pw}) => { + const {adminUser, team} = await initSetupTracked(pw); + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const selectedLevel = classificationLevels.find((l) => l.name === 'SECRET'); + expect(selectedLevel).toBeDefined(); + + const newChannelModal = await channelsPage.openNewChannelModal(); + await newChannelModal.fillDisplayName(`test-${pw.random.id()}`); + + // Enable classification toggle + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + // Select a classification level + const dropdownContainer = channelsPage.page.getByTestId('channelClassificationLevel'); + await dropdownContainer.click(); + const menu = channelsPage.page.locator('.DropDown__menu'); + await menu.getByText(selectedLevel!.name, {exact: true}).click(); + + // Banner text should be auto-populated with the bold level name + const bannerTextbox = channelsPage.page.locator('#channel_classification_banner_text'); + await expect(bannerTextbox).toBeVisible(); + const currentValue = await bannerTextbox.inputValue(); + expect(currentValue).toContain(selectedLevel!.name); + + // Append custom text to the banner + await bannerTextbox.click(); + await bannerTextbox.press('End'); + await bannerTextbox.pressSequentially(' - Custom suffix'); + + const updatedValue = await bannerTextbox.inputValue(); + expect(updatedValue).toContain('Custom suffix'); + }); + + test('Creating channel with classification shows banner with correct color', async ({pw}) => { + const {adminUser, team} = await initSetupTracked(pw); + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const selectedLevel = classificationLevels.find((l) => l.name === 'SECRET'); + expect(selectedLevel).toBeDefined(); + + const newChannelModal = await channelsPage.openNewChannelModal(); + await newChannelModal.fillDisplayName(`classified-${pw.random.id()}`); + await newChannelModal.publicTypeButton.click(); + + // Enable classification toggle + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + // Select the classification level + const dropdownContainer = channelsPage.page.getByTestId('channelClassificationLevel'); + await dropdownContainer.click(); + const menu = channelsPage.page.locator('.DropDown__menu'); + await menu.getByText(selectedLevel!.name, {exact: true}).click(); + + // Wait for banner text to auto-populate, then create the channel + const bannerTextbox = channelsPage.page.locator('#channel_classification_banner_text'); + await expect(bannerTextbox).toBeVisible(); + await expect(bannerTextbox).not.toHaveValue(''); + + await newChannelModal.create(); + + // Should be redirected to the new channel and center view loads + await expect(channelsPage.page).toHaveURL(/\/channels\//, {timeout: 30000}); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 30000}); + + // Channel banner should be visible (allow extra time for property value fetch) + const banner = channelsPage.page.getByTestId('channel_banner_container'); + await expect(banner).toBeVisible({timeout: 30000}); + + // Verify the banner has the correct background color + const actualBackgroundColor = await banner.evaluate((el) => { + return window.getComputedStyle(el).getPropertyValue('background-color'); + }); + const expectedRgb = hexToRgb(selectedLevel!.color); + expect(actualBackgroundColor).toBe(expectedRgb); + + // Verify the banner contains the classification level name (rendered from **SECRET** markdown) + await expect(banner).toContainText(selectedLevel!.name); + }); +}); + +test.describe('Channel Classification - Existing channel settings', () => { + test('Classification toggle can be enabled from channel settings', async ({pw}) => { + const {adminUser, team, adminClient} = await initSetupTracked(pw); + + const channel = await adminClient.createChannel( + pw.random.channel({teamId: team.id, name: `cls-${pw.random.id()}`, displayName: `Cls ${pw.random.id()}`}), + ); + await adminClient.addToChannel(adminUser.id, channel.id); + + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name, channel.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const channelSettingsModal = await channelsPage.openChannelSettings(); + await channelSettingsModal.openConfigurationTab(); + + // The classification toggle should be available + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await expect(classificationToggle).toBeVisible(); + + // Toggle it on + const classes = await classificationToggle.getAttribute('class'); + if (!classes?.includes('active')) { + await classificationToggle.click(); + } + + // Toggle should now be active + await expect(classificationToggle).toHaveClass(/active/); + }); + + test('Classification level can be set once toggle is enabled', async ({pw}) => { + const {adminUser, team, adminClient} = await initSetupTracked(pw); + + const channel = await adminClient.createChannel( + pw.random.channel({teamId: team.id, name: `cls-${pw.random.id()}`, displayName: `Cls ${pw.random.id()}`}), + ); + await adminClient.addToChannel(adminUser.id, channel.id); + + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name, channel.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const channelSettingsModal = await channelsPage.openChannelSettings(); + await channelSettingsModal.openConfigurationTab(); + + // Enable classification toggle + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + // Classification level dropdown should be visible + const dropdownContainer = channelsPage.page.getByTestId('channelClassificationLevel'); + await expect(dropdownContainer).toBeVisible(); + + // Open dropdown and select a level + await dropdownContainer.click(); + const menu = channelsPage.page.locator('.DropDown__menu'); + await expect(menu).toBeVisible(); + + const selectedLevel = classificationLevels.find((l) => l.name === 'CONFIDENTIAL'); + expect(selectedLevel).toBeDefined(); + await menu.getByText(selectedLevel!.name, {exact: true}).click(); + + // The dropdown should now show the selected value + await expect(dropdownContainer.getByText(selectedLevel!.name, {exact: true})).toBeVisible(); + }); + + test('Selecting classification locks banner toggle active and disabled, with matching color', async ({pw}) => { + const {adminUser, team, adminClient} = await initSetupTracked(pw); + + const channel = await adminClient.createChannel( + pw.random.channel({teamId: team.id, name: `cls-${pw.random.id()}`, displayName: `Cls ${pw.random.id()}`}), + ); + await adminClient.addToChannel(adminUser.id, channel.id); + + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name, channel.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const channelSettingsModal = await channelsPage.openChannelSettings(); + await channelSettingsModal.openConfigurationTab(); + + // Enable classification and select a level + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + const dropdownContainer = channelsPage.page.getByTestId('channelClassificationLevel'); + await dropdownContainer.click(); + + const selectedLevel = classificationLevels.find((l) => l.name === 'SECRET'); + expect(selectedLevel).toBeDefined(); + const menu = channelsPage.page.locator('.DropDown__menu'); + await menu.getByText(selectedLevel!.name, {exact: true}).click(); + + // The channel banner toggle should now be forced active and disabled + const bannerToggle = channelsPage.page.getByTestId('channelBannerToggle-button'); + await expect(bannerToggle).toBeVisible(); + await expect(bannerToggle).toHaveClass(/active/); + await expect(bannerToggle).toBeDisabled(); + + // Banner color input should be locked to the classification color + const colorInput = channelsPage.page.locator('#channel_banner_banner_background_color_picker-inputColorValue'); + await expect(colorInput).toBeVisible(); + const colorValue = await colorInput.inputValue(); + expect(colorValue.toLowerCase().replace('#', '')).toBe(selectedLevel!.color.toLowerCase().replace('#', '')); + }); + + test('Editing banner text and saving updates the banner in real time', async ({pw}) => { + const {adminUser, team, adminClient} = await initSetupTracked(pw); + + const channel = await adminClient.createChannel( + pw.random.channel({teamId: team.id, name: `cls-${pw.random.id()}`, displayName: `Cls ${pw.random.id()}`}), + ); + await adminClient.addToChannel(adminUser.id, channel.id); + + const {channelsPage} = await pw.testBrowser.login(adminUser); + await channelsPage.goto(team.name, channel.name); + await expect(channelsPage.page.getByTestId('channel_view')).toBeVisible({timeout: 60000}); + + const channelSettingsModal = await channelsPage.openChannelSettings(); + const configurationTab = await channelSettingsModal.openConfigurationTab(); + + // Enable classification and select a level + const classificationToggle = channelsPage.page.getByTestId('channelClassificationToggle-button'); + await classificationToggle.click(); + + const dropdownContainer = channelsPage.page.getByTestId('channelClassificationLevel'); + await dropdownContainer.click(); + + const selectedLevel = classificationLevels.find((l) => l.name === 'TOP SECRET'); + expect(selectedLevel).toBeDefined(); + const menu = channelsPage.page.locator('.DropDown__menu'); + await menu.getByText(selectedLevel!.name, {exact: true}).click(); + + // Edit the banner text to a custom value + const customBannerText = 'TOP SECRET - Handle via COMINT channels only'; + const bannerTextbox = channelsPage.page.locator('#channel_banner_banner_text_textbox'); + await expect(bannerTextbox).toBeVisible(); + await bannerTextbox.fill(customBannerText); + + // Save the changes + await configurationTab.save(); + await channelSettingsModal.close(); + + // The channel banner should now show the custom text with the classification color + const banner = channelsPage.page.getByTestId('channel_banner_container'); + await expect(banner).toBeVisible({timeout: 30000}); + await expect(banner).toContainText(customBannerText); + + const actualBackgroundColor = await banner.evaluate((el) => { + return window.getComputedStyle(el).getPropertyValue('background-color'); + }); + expect(actualBackgroundColor).toBe(hexToRgb(selectedLevel!.color)); + }); +}); + +function hexToRgb(hex: string): string { + const result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); + if (!result) { + return hex; + } + return `rgb(${parseInt(result[1], 16)}, ${parseInt(result[2], 16)}, ${parseInt(result[3], 16)})`; +} diff --git a/e2e-tests/playwright/specs/functional/channels/channel_classification/helpers.ts b/e2e-tests/playwright/specs/functional/channels/channel_classification/helpers.ts new file mode 100644 index 00000000000..17814a7a9fc --- /dev/null +++ b/e2e-tests/playwright/specs/functional/channels/channel_classification/helpers.ts @@ -0,0 +1,126 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import type {Client4} from '@mattermost/client'; + +const PROPERTY_GROUP = 'classification_markings'; +const TEMPLATE_OBJECT_TYPE = 'template'; +const CHANNEL_OBJECT_TYPE = 'channel'; +const TARGET_TYPE = 'system'; +const CLASSIFICATION_FIELD_NAME = 'classification'; +const CHANNEL_LINKED_FIELD_NAME = 'channel_classification'; + +export const TEST_LEVELS = [ + {name: 'UNCLASSIFIED', color: '#007A33', rank: 1}, + {name: 'CONFIDENTIAL', color: '#0033A0', rank: 2}, + {name: 'SECRET', color: '#C8102E', rank: 3}, + {name: 'TOP SECRET', color: '#FF8C00', rank: 4}, +]; + +/** + * Sets the ClassificationMarkings feature flag via server config. + */ +export async function setClassificationMarkingsFeatureFlag(adminClient: Client4, enabled: boolean) { + const config = await adminClient.getConfig(); + await adminClient.updateConfig({ + ...config, + FeatureFlags: { + ...config.FeatureFlags, + ClassificationMarkings: enabled, + }, + } as Awaited>); +} + +/** + * Deletes existing classification fields (channel linked, system linked, and template) + * to provide a clean slate. + */ +export async function deleteClassificationFieldsIfExist(adminClient: Client4) { + // Delete channel linked fields first + try { + const channelFields = await adminClient.getPropertyFields(PROPERTY_GROUP, CHANNEL_OBJECT_TYPE, TARGET_TYPE, ''); + for (const f of channelFields.filter((f) => f.name === CHANNEL_LINKED_FIELD_NAME && f.delete_at === 0)) { + await adminClient.deletePropertyField(PROPERTY_GROUP, CHANNEL_OBJECT_TYPE, f.id); + } + } catch { + // May not exist + } + + // Delete system linked fields + for (const objectType of ['system', 'user'] as const) { + try { + const linkedFields = await adminClient.getPropertyFields(PROPERTY_GROUP, objectType, TARGET_TYPE, ''); + for (const f of linkedFields.filter( + (f) => f.name === 'system_classification' && f.delete_at === 0 && f.linked_field_id, + )) { + await adminClient.deletePropertyField(PROPERTY_GROUP, objectType, f.id); + } + } catch { + // May not exist + } + } + + // Delete template fields + try { + const fields = await adminClient.getPropertyFields(PROPERTY_GROUP, TEMPLATE_OBJECT_TYPE, TARGET_TYPE); + for (const f of fields.filter((f) => f.name === CLASSIFICATION_FIELD_NAME && f.delete_at === 0)) { + await adminClient.deletePropertyField(PROPERTY_GROUP, TEMPLATE_OBJECT_TYPE, f.id); + } + } catch { + // May not exist + } +} + +export type ClassificationLevel = { + id: string; + name: string; + color: string; + rank: number; +}; + +export type SetupResult = { + templateFieldId: string; + channelFieldId: string; + levels: ClassificationLevel[]; +}; + +/** + * Creates the full classification setup: template field + channel linked field. + * Returns the created fields and the resolved levels (with server-assigned IDs). + */ +export async function setupClassificationWithChannelField( + adminClient: Client4, + levels: Array<{name: string; color: string; rank: number}> = TEST_LEVELS, +): Promise { + await deleteClassificationFieldsIfExist(adminClient); + + // Create template field + const templateField = await adminClient.createPropertyField(PROPERTY_GROUP, TEMPLATE_OBJECT_TYPE, { + name: CLASSIFICATION_FIELD_NAME, + type: 'select', + target_type: TARGET_TYPE, + target_id: '', + attrs: { + options: levels.map((l) => ({id: '', name: l.name, color: l.color, rank: l.rank})), + managed: 'admin', + }, + permission_field: 'sysadmin', + permission_values: 'sysadmin', + permission_options: 'sysadmin', + } as Parameters[2]); + + // Create channel linked field + const channelField = await adminClient.createPropertyField(PROPERTY_GROUP, CHANNEL_OBJECT_TYPE, { + name: CHANNEL_LINKED_FIELD_NAME, + type: 'select', + target_type: TARGET_TYPE, + target_id: '', + linked_field_id: templateField.id, + } as Parameters[2]); + + // Resolve levels with server-assigned IDs + const options = (templateField.attrs?.options ?? []) as ClassificationLevel[]; + const resolvedLevels = options.sort((a, b) => a.rank - b.rank); + + return {templateFieldId: templateField.id, channelFieldId: channelField.id, levels: resolvedLevels}; +} diff --git a/e2e-tests/playwright/specs/functional/channels/channel_privacy/sidebar_icon_realtime_update.spec.ts b/e2e-tests/playwright/specs/functional/channels/channel_privacy/sidebar_icon_realtime_update.spec.ts index fec6158ab2d..ea18fcbbf72 100644 --- a/e2e-tests/playwright/specs/functional/channels/channel_privacy/sidebar_icon_realtime_update.spec.ts +++ b/e2e-tests/playwright/specs/functional/channels/channel_privacy/sidebar_icon_realtime_update.spec.ts @@ -1,14 +1,38 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import {expect, test} from '@mattermost/playwright-lib'; +import {expect, test, getAdminClient} from '@mattermost/playwright-lib'; +import type {PlaywrightExtended} from '@mattermost/playwright-lib'; + +// Teams created by pw.initSetup() in each test are tracked here and deleted in +// afterEach so local environments don't accumulate stale teams across runs. +const createdTeamIds: string[] = []; + +async function initSetupTracked(pw: PlaywrightExtended) { + const setup = await pw.initSetup(); + createdTeamIds.push(setup.team.id); + return setup; +} + +test.afterEach(async () => { + if (createdTeamIds.length === 0) { + return; + } + const ids = createdTeamIds.splice(0); + try { + const {adminClient} = await getAdminClient({skipLog: true}); + await Promise.allSettled(ids.map((id) => adminClient.deleteTeam(id))); + } catch { + // Best-effort cleanup + } +}); test( 'sidebar icon updates from globe to lock when channel converted to private via API', {tag: ['@channels', '@channel_privacy']}, async ({pw}) => { // # Initialize setup - const {adminClient, user, team} = await pw.initSetup(); + const {adminClient, user, team} = await initSetupTracked(pw); // # Create a public channel const channel = await adminClient.createChannel( @@ -47,7 +71,7 @@ test( {tag: ['@channels', '@channel_privacy']}, async ({pw}) => { // # Initialize setup - const {adminClient, user, team} = await pw.initSetup(); + const {adminClient, user, team} = await initSetupTracked(pw); // # Create a private channel const channel = await adminClient.createChannel( diff --git a/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings.spec.ts b/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings.spec.ts index 57881df8610..0d4dd3e9d03 100644 --- a/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings.spec.ts +++ b/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings.spec.ts @@ -498,6 +498,73 @@ test.describe('System Console - Classification markings', () => { }, ); + /** + * @objective Verify that modifying a preset's levels (rename, delete, add) automatically + * switches the dropdown to "Custom classification levels", and selecting a real preset + * again removes the Custom option from the dropdown. + */ + test( + 'MM-T6212 classification markings: modifying a preset switches dropdown to Custom', + {tag: ['@system_console', '@classification_markings']}, + async ({pw}) => { + const {adminUser, adminClient} = await pw.initSetup(); + + await setClassificationMarkingsFeatureFlag(adminClient, true); + await deleteClassificationMarkingsFieldIfExists(adminClient); + + const {systemConsolePage} = await pw.testBrowser.login(adminUser); + const {page} = systemConsolePage; + await page.goto(CLASSIFICATION_MARKINGS_ADMIN_PATH); + await page.waitForLoadState('networkidle'); + + // # Enable classification markings and select NATO preset + await page.locator('input[name="classificationEnabled"][value="true"]').click(); + await selectClassificationPreset(page, 'NATO'); + + const presetControl = page.getByTestId('classificationPreset'); + await expect(presetControl).toContainText('NATO'); + + // # Rename the first level — this should switch to Custom + const firstLevelInput = page.getByLabel('Classification level name').first(); + await firstLevelInput.clear(); + await firstLevelInput.fill('MY CUSTOM LEVEL'); + + // * Preset dropdown should now show "Custom classification levels" + await expect(presetControl).toContainText('Custom classification levels'); + + // # Open the preset dropdown and verify "Custom classification levels" is listed + await presetControl.click(); + const menu = page.locator('.DropDown__menu'); + await expect(menu).toBeVisible(); + await expect(menu.getByText('Custom classification levels', {exact: true})).toBeVisible(); + + // # Select a real preset (Canada) — should show the confirmation modal + await menu.getByText('Canada', {exact: true}).click(); + await expect(page.getByText('Change classification preset?')).toBeVisible(); + + // # Confirm the preset change + await page.getByRole('button', {name: 'Change preset'}).click(); + + // * Dropdown now shows Canada, no longer Custom + await expect(presetControl).toContainText('Canada'); + + // # Open the dropdown again and verify Custom is no longer listed + await presetControl.click(); + const menuAfterSwitch = page.locator('.DropDown__menu'); + await expect(menuAfterSwitch).toBeVisible(); + await expect(menuAfterSwitch.getByText('Custom classification levels', {exact: true})).not.toBeVisible(); + + // # Close menu by pressing Escape + await page.keyboard.press('Escape'); + + // # Delete a level from the Canada preset + await page.getByRole('button', {name: 'Delete level'}).first().click(); + + // * Should switch back to Custom + await expect(presetControl).toContainText('Custom classification levels'); + }, + ); + /** * @objective Validate that saving with global banner enabled but no level selected shows an error. */ diff --git a/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings_helpers.ts b/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings_helpers.ts index 34cbec1b01f..32e1a8ceb03 100644 --- a/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings_helpers.ts +++ b/e2e-tests/playwright/specs/functional/system_console/site_configuration/classification_markings_helpers.ts @@ -35,6 +35,16 @@ export async function setClassificationMarkingsFeatureFlag(adminClient: Client4, * (clean slate for E2E). Linked field is deleted first to avoid deletion-protection errors. */ export async function deleteClassificationMarkingsFieldIfExists(adminClient: Client4) { + // Delete channel linked fields first (created by channel classification tests). + try { + const channelFields = await adminClient.getPropertyFields(PROPERTY_GROUP, 'channel', TARGET_TYPE, ''); + for (const f of channelFields.filter((f) => f.name === 'channel_classification' && f.delete_at === 0)) { + await adminClient.deletePropertyField(PROPERTY_GROUP, 'channel', f.id); + } + } catch { + // May not exist; ignore. + } + // Clean up both the current 'system' object type and the legacy 'user' object type // to handle stale data from earlier versions of the feature. for (const objectType of [LINKED_OBJECT_TYPE, 'user'] as const) { diff --git a/webapp/channels/src/actions/websocket_actions.ts b/webapp/channels/src/actions/websocket_actions.ts index 843edb78a89..d9cb60c6155 100644 --- a/webapp/channels/src/actions/websocket_actions.ts +++ b/webapp/channels/src/actions/websocket_actions.ts @@ -155,12 +155,11 @@ import {isThreadOpen, isThreadManuallyUnread} from 'selectors/views/threads'; import store from 'stores/redux_store'; import { - GROUP_NAME, - OBJECT_TYPE, - TARGET_TYPE, - TARGET_ID, - LINKED_OBJECT_TYPE, - SYSTEM_FIELD_TARGET_ID, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, } from 'components/admin_console/classification_markings/utils'; import {EntityType, invalidateAccessControlAttributesCache} from 'components/common/hooks/useAccessControlAttributes'; import DialogRouter from 'components/dialog_router'; @@ -345,17 +344,22 @@ export function reconnect() { // Refresh classification fields and values on reconnect when the feature flag is active if (getFeatureFlagValue(state, 'ClassificationMarkings') === 'true') { dispatch( - fetchPropertyFields(GROUP_NAME, OBJECT_TYPE, TARGET_TYPE, TARGET_ID), + fetchPropertyFields( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + ), ); dispatch( fetchPropertyFields( - GROUP_NAME, - LINKED_OBJECT_TYPE, - TARGET_TYPE, - SYSTEM_FIELD_TARGET_ID, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, ), ); - dispatch(fetchSystemPropertyValues(GROUP_NAME)); + dispatch(fetchSystemPropertyValues(CLASSIFICATIONS_GROUP_NAME)); } if (state.websocket.lastDisconnectAt) { diff --git a/webapp/channels/src/components/admin_console/__snapshots__/color_setting.test.tsx.snap b/webapp/channels/src/components/admin_console/__snapshots__/color_setting.test.tsx.snap index 3e35be13cfb..736ab5ba0c5 100644 --- a/webapp/channels/src/components/admin_console/__snapshots__/color_setting.test.tsx.snap +++ b/webapp/channels/src/components/admin_console/__snapshots__/color_setting.test.tsx.snap @@ -127,6 +127,18 @@ exports[`components/ColorSetting should match snapshot, disabled 1`] = ` type="text" value="#fff" /> + + +
= {}): PropertyField { return { id: 'field1', - group_id: GROUP_NAME, - name: FIELD_NAME, + group_id: CLASSIFICATIONS_GROUP_NAME, + name: CLASSIFICATIONS_TEMPLATE_FIELD_NAME, type: 'select', attrs: {options: []}, target_id: '', - target_type: TARGET_TYPE, - object_type: OBJECT_TYPE, + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + object_type: CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, create_at: 1000, update_at: 1000, delete_at: 0, @@ -58,13 +61,13 @@ function makePropertyField(overrides: Partial = {}): PropertyFiel function makeLinkedField(overrides: Partial = {}): PropertyField { return { id: 'linked_field1', - group_id: GROUP_NAME, - name: LINKED_FIELD_NAME, + group_id: CLASSIFICATIONS_GROUP_NAME, + name: CLASSIFICATIONS_SYSTEM_FIELD_NAME, type: 'select', attrs: {actions: []}, - target_id: SYSTEM_FIELD_TARGET_ID, - target_type: TARGET_TYPE, - object_type: LINKED_OBJECT_TYPE, + target_id: CLASSIFICATIONS_FIELD_TARGET_ID, + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + object_type: CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, linked_field_id: 'field1', create_at: 2000, update_at: 2000, @@ -75,12 +78,32 @@ function makeLinkedField(overrides: Partial = {}): PropertyField }; } +function makeChannelLinkedField(overrides: Partial = {}): PropertyField { + return { + id: 'channel_field1', + group_id: CLASSIFICATIONS_GROUP_NAME, + name: CLASSIFICATIONS_CHANNEL_FIELD_NAME, + type: 'select', + attrs: {}, + target_id: '', + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + object_type: CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + linked_field_id: 'field1', + create_at: 4000, + update_at: 4000, + delete_at: 0, + created_by: 'user1', + updated_by: 'user1', + ...overrides, + }; +} + function makeSystemValue(fieldId: string, optionId: string): PropertyValue { return { id: 'value1', - target_id: SYSTEM_VALUE_TARGET_ID, - target_type: LINKED_OBJECT_TYPE, - group_id: GROUP_NAME, + target_id: CLASSIFICATIONS_SYSTEM_VALUE_TARGET_ID, + target_type: CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + group_id: CLASSIFICATIONS_GROUP_NAME, field_id: fieldId, value: optionId, create_at: 3000, @@ -311,6 +334,105 @@ describe('fetchClassificationField', () => { }); }); +describe('fetchChannelClassificationField', () => { + beforeEach(() => { + jest.clearAllMocks(); + + // Reset mockResolvedValueOnce queues that may carry over from the + // fetchClassificationField "stop after 500 items" test. + (Client4.getPropertyFields as jest.Mock).mockReset?.(); + }); + + test('should return the matching channel-linked field from first page', async () => { + const expected = makeChannelLinkedField(); + jest.spyOn(Client4, 'getPropertyFields').mockResolvedValueOnce([ + makeChannelLinkedField({id: 'other', name: 'other_field'}), + expected, + ]); + + const result = await fetchChannelClassificationField(); + expect(result).toEqual(expected); + expect(Client4.getPropertyFields).toHaveBeenCalledTimes(1); + expect(Client4.getPropertyFields).toHaveBeenCalledWith( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + '', + expect.any(Object), + ); + }); + + test('should skip channel fields without linked_field_id', async () => { + const orphan = makeChannelLinkedField({id: 'orphan', linked_field_id: ''}); + const linked = makeChannelLinkedField({id: 'linked'}); + jest.spyOn(Client4, 'getPropertyFields').mockResolvedValueOnce([orphan, linked]); + + const result = await fetchChannelClassificationField(); + expect(result).toEqual(linked); + }); + + test('should skip soft-deleted channel-linked fields', async () => { + const deleted = makeChannelLinkedField({id: 'deleted', delete_at: 999}); + const active = makeChannelLinkedField({id: 'active'}); + jest.spyOn(Client4, 'getPropertyFields').mockResolvedValueOnce([deleted, active]); + + const result = await fetchChannelClassificationField(); + expect(result).toEqual(active); + }); + + test('should paginate using cursor when field not found on first page', async () => { + const page1 = [ + makeChannelLinkedField({id: 'p1', name: 'other1', create_at: 100}), + makeChannelLinkedField({id: 'p2', name: 'other2', create_at: 200}), + ]; + const expected = makeChannelLinkedField({id: 'found'}); + const page2 = [expected]; + + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce(page1). + mockResolvedValueOnce(page2); + + const result = await fetchChannelClassificationField(); + expect(result).toEqual(expected); + expect(Client4.getPropertyFields).toHaveBeenCalledTimes(2); + + const secondCallArgs = (Client4.getPropertyFields as jest.Mock).mock.calls[1]; + expect(secondCallArgs[4]).toEqual({cursorId: 'p2', cursorCreateAt: 200}); + }); + + test('should return undefined when field list is empty', async () => { + jest.spyOn(Client4, 'getPropertyFields').mockResolvedValueOnce([]); + + const result = await fetchChannelClassificationField(); + expect(result).toBeUndefined(); + }); + + test('should return undefined when no pages contain a valid channel-linked field', async () => { + jest.spyOn(Client4, 'getPropertyFields').mockResolvedValueOnce([ + makeChannelLinkedField({id: 'irrelevant', name: 'other'}), + ]).mockResolvedValueOnce([]); + + const result = await fetchChannelClassificationField(); + expect(result).toBeUndefined(); + }); + + test('should stop after 500 items to avoid infinite loop', async () => { + const makePage = (startId: number) => + Array.from({length: 100}, (_, i) => + makeChannelLinkedField({id: `id_${startId + i}`, name: `other_${startId + i}`, create_at: startId + i}), + ); + + const spy = jest.spyOn(Client4, 'getPropertyFields'); + for (let i = 0; i < 6; i++) { + spy.mockResolvedValueOnce(makePage(i * 100)); + } + + const result = await fetchChannelClassificationField(); + expect(result).toBeUndefined(); + expect(Client4.getPropertyFields).toHaveBeenCalledTimes(5); + }); +}); + describe('ClassificationMarkings component', () => { beforeEach(() => { jest.clearAllMocks(); @@ -405,6 +527,65 @@ describe('ClassificationMarkings component', () => { expect(screen.getByText('Classification levels')).toBeInTheDocument(); }); + test('should not show Custom option in preset dropdown when a named preset is active', async () => { + const usPreset = presets.find((p) => p.id === 'us')!; + const field = makePropertyField({ + attrs: { + options: usPreset.levels.map((l) => ({ + id: l.id, + name: l.name, + color: l.color, + rank: l.rank, + })), + }, + }); + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce([field]). + mockResolvedValueOnce([]); // linked field + + renderWithContext(, BASE_STATE); + + await screen.findByText('Classification levels'); + + // The selected value in the dropdown should be US, not Custom + expect(screen.queryByText('Custom classification levels')).not.toBeInTheDocument(); + }); + + test('should show Custom indicator after editing a level', async () => { + const usPreset = presets.find((p) => p.id === 'us')!; + const field = makePropertyField({ + attrs: { + options: usPreset.levels.map((l) => ({ + id: l.id, + name: l.name, + color: l.color, + rank: l.rank, + })), + }, + }); + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce([field]). + mockResolvedValueOnce([]); // linked field + + renderWithContext(, BASE_STATE); + + await screen.findByText('Classification levels'); + + const user = userEvent.setup(); + + // Initially shows US preset, not Custom + expect(screen.queryByText('Custom classification levels')).not.toBeInTheDocument(); + + // Edit the first level name to trigger switchToCustom + const nameInputs = screen.getAllByRole('textbox', {name: /Classification level name/i}); + await user.clear(nameInputs[0]); + await user.type(nameInputs[0], 'MODIFIED'); + await user.tab(); + + // Custom should now appear as the selected dropdown value + expect(screen.getByText('Custom classification levels')).toBeInTheDocument(); + }); + test('should detect hasChanges when toggling enabled', async () => { jest.spyOn(Client4, 'getPropertyFields').mockResolvedValueOnce([]); @@ -501,7 +682,8 @@ describe('ClassificationMarkings component', () => { jest.spyOn(Client4, 'getPropertyFields'). mockResolvedValueOnce([field]). // template field - mockResolvedValueOnce([linkedField]); // linked field (existing, no banner actions) + mockResolvedValueOnce([linkedField]). // linked field (existing, no banner actions) + mockResolvedValueOnce([makeChannelLinkedField()]); // channel-linked field exists during save jest.spyOn(Client4, 'patchPropertyField'). mockResolvedValueOnce(patchedTemplate). // patch template mockResolvedValueOnce(linkedField); // patch linked @@ -521,8 +703,8 @@ describe('ClassificationMarkings component', () => { await waitFor(() => { expect(Client4.patchPropertyField).toHaveBeenCalledWith( - GROUP_NAME, - OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, 'field1', expect.objectContaining({ attrs: expect.objectContaining({ @@ -740,7 +922,8 @@ describe('GlobalClassificationIndicators section', () => { jest.spyOn(Client4, 'getPropertyFields'). mockResolvedValueOnce([field]). - mockResolvedValueOnce([linked]); + mockResolvedValueOnce([linked]). + mockResolvedValueOnce([makeChannelLinkedField()]); // channel-linked field already exists during save jest.spyOn(Client4, 'getSystemPropertyValues'). mockResolvedValueOnce([sysValue]); @@ -772,8 +955,8 @@ describe('GlobalClassificationIndicators section', () => { await waitFor(() => { // Template field patched without global_banner in attrs. expect(Client4.patchPropertyField).toHaveBeenCalledWith( - GROUP_NAME, - OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, 'field1', expect.objectContaining({ attrs: expect.objectContaining({options: expect.any(Array)}), @@ -781,7 +964,7 @@ describe('GlobalClassificationIndicators section', () => { ); expect(Client4.patchPropertyField).not.toHaveBeenCalledWith( expect.anything(), - OBJECT_TYPE, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, expect.anything(), expect.objectContaining({ attrs: expect.objectContaining({global_banner: expect.anything()}), @@ -790,8 +973,8 @@ describe('GlobalClassificationIndicators section', () => { // Linked field patched with updated actions (top_and_bottom). expect(Client4.patchPropertyField).toHaveBeenCalledWith( - GROUP_NAME, - LINKED_OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, 'linked_field1', expect.objectContaining({ attrs: expect.objectContaining({ @@ -819,7 +1002,8 @@ describe('GlobalClassificationIndicators section', () => { jest.spyOn(Client4, 'getPropertyFields'). mockResolvedValueOnce([field]). - mockResolvedValueOnce([linked]); + mockResolvedValueOnce([linked]). + mockResolvedValueOnce([makeChannelLinkedField()]); // channel-linked field already exists during save jest.spyOn(Client4, 'patchPropertyField'). mockResolvedValueOnce(patchedTemplate). @@ -840,8 +1024,8 @@ describe('GlobalClassificationIndicators section', () => { await waitFor(() => { // Template field saved without global_banner. expect(Client4.patchPropertyField).toHaveBeenCalledWith( - GROUP_NAME, - OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, 'field1', expect.not.objectContaining({ attrs: expect.objectContaining({global_banner: expect.anything()}), @@ -850,8 +1034,8 @@ describe('GlobalClassificationIndicators section', () => { // Linked field patched with empty actions (banner disabled). expect(Client4.patchPropertyField).toHaveBeenCalledWith( - GROUP_NAME, - LINKED_OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, 'linked_field1', expect.objectContaining({ attrs: expect.objectContaining({actions: []}), @@ -869,12 +1053,13 @@ describe('GlobalClassificationIndicators section', () => { jest.spyOn(Client4, 'getPropertyFields'). mockResolvedValueOnce([field]). - mockResolvedValueOnce([linked]); + mockResolvedValueOnce([linked]). + mockResolvedValueOnce([]); const deleteOrder: string[] = []; const deleteFieldSpy = jest.spyOn(Client4, 'deletePropertyField'); deleteFieldSpy.mockImplementation(async (_group, objectType, _id) => { - deleteOrder.push(objectType === LINKED_OBJECT_TYPE ? `linked:${_id}` : `template:${_id}`); + deleteOrder.push(objectType === CLASSIFICATIONS_SYSTEM_OBJECT_TYPE ? `linked:${_id}` : `template:${_id}`); return {status: 'OK'}; }); @@ -915,3 +1100,213 @@ describe('GlobalClassificationIndicators section', () => { } }); }); + +describe('Channel classification linked field branches', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test('should create channel-linked field when none exists during save', async () => { + const field = makePropertyField({ + attrs: {options: [{id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}]}, + }); + const linked = makeLinkedField({attrs: {actions: []}}); + const patchedTemplate = makePropertyField({ + attrs: {options: [{id: 'lvl1', name: 'MODIFIED', color: '#007A33', rank: 1}]}, + }); + const patchedLinked = makeLinkedField({attrs: {actions: []}}); + const createdChannelField = makeChannelLinkedField(); + + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce([field]). // template field load + mockResolvedValueOnce([linked]). // linked field load + mockResolvedValueOnce([]); // channel-linked field lookup during save -> none + + jest.spyOn(Client4, 'patchPropertyField'). + mockResolvedValueOnce(patchedTemplate). + mockResolvedValueOnce(patchedLinked); + + const createSpy = jest.spyOn(Client4, 'createPropertyField'). + mockResolvedValueOnce(createdChannelField); + + renderWithContext(, BASE_STATE); + await screen.findByText('Classification levels'); + + const user = userEvent.setup(); + const nameInput = screen.getByRole('textbox', {name: /Classification level name/i}); + await user.clear(nameInput); + await user.type(nameInput, 'MODIFIED'); + await user.tab(); + + await user.click(await screen.findByText('Save')); + + await waitFor(() => { + expect(createSpy).toHaveBeenCalledWith( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + expect.objectContaining({ + name: CLASSIFICATIONS_CHANNEL_FIELD_NAME, + linked_field_id: 'field1', + }), + ); + }); + await act(async () => {}); + }); + + test('should not create channel-linked field when one already exists during save', async () => { + const field = makePropertyField({ + attrs: {options: [{id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}]}, + }); + const linked = makeLinkedField({attrs: {actions: []}}); + const patchedTemplate = makePropertyField({ + attrs: {options: [{id: 'lvl1', name: 'MODIFIED', color: '#007A33', rank: 1}]}, + }); + const patchedLinked = makeLinkedField({attrs: {actions: []}}); + const existingChannelField = makeChannelLinkedField(); + + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce([field]). + mockResolvedValueOnce([linked]). + mockResolvedValueOnce([existingChannelField]); // channel field exists + + jest.spyOn(Client4, 'patchPropertyField'). + mockResolvedValueOnce(patchedTemplate). + mockResolvedValueOnce(patchedLinked); + + const createSpy = jest.spyOn(Client4, 'createPropertyField'); + + const {store} = renderWithContext(, BASE_STATE); + await screen.findByText('Classification levels'); + + const user = userEvent.setup(); + const nameInput = screen.getByRole('textbox', {name: /Classification level name/i}); + await user.clear(nameInput); + await user.type(nameInput, 'MODIFIED'); + await user.tab(); + + await user.click(await screen.findByText('Save')); + + await waitFor(() => { + expect(Client4.patchPropertyField).toHaveBeenCalled(); + }); + await act(async () => {}); + + // Channel field must not be created since one already exists. + expect(createSpy).not.toHaveBeenCalledWith( + expect.anything(), + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + expect.anything(), + ); + + // Existing channel field must be pushed into the store alongside the saved template + // and linked field so consumers that read from Redux get it immediately. + const fieldsById = store.getState().entities.properties.fields.byId; + expect(fieldsById[existingChannelField.id]).toEqual(existingChannelField); + }); + + test('should delete channel-linked field before linked and template when disabling', async () => { + const field = makePropertyField({ + attrs: {options: [{id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}]}, + }); + const linked = makeLinkedField({attrs: {actions: []}}); + const channel = makeChannelLinkedField(); + + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce([field]). // template field load + mockResolvedValueOnce([linked]). // linked field load + mockResolvedValueOnce([channel]); // channel field lookup during disable + + const deleteOrder: string[] = []; + jest.spyOn(Client4, 'deletePropertyField').mockImplementation(async (_group, objectType, id) => { + if (objectType === CLASSIFICATIONS_CHANNEL_OBJECT_TYPE) { + deleteOrder.push(`channel:${id}`); + } else if (objectType === CLASSIFICATIONS_SYSTEM_OBJECT_TYPE) { + deleteOrder.push(`linked:${id}`); + } else { + deleteOrder.push(`template:${id}`); + } + return {status: 'OK'}; + }); + + // Suppress noisy "not configured to support act" warnings from the bulk state reset. + const origError = console.error; + console.error = (...args: Parameters) => { + if (typeof args[0] === 'string' && args[0].includes('not configured to support act')) { + return; + } + origError(...args); + }; + + try { + renderWithContext(, BASE_STATE); + await screen.findByText('Global Classification Indicators'); + + const user = userEvent.setup(); + + await act(async () => { + await user.click(screen.getByTestId('classificationEnabledfalse')); + }); + + await act(async () => { + await user.click(screen.getByText('Save')); + }); + + await waitFor(() => { + expect(deleteOrder).toHaveLength(3); + }); + await act(async () => {}); + + expect(deleteOrder[0]).toBe(`channel:${channel.id}`); + expect(deleteOrder[1]).toBe('linked:linked_field1'); + expect(deleteOrder[2]).toBe('template:field1'); + } finally { + console.error = origError; + } + }); + + test('should not attempt to delete channel-linked field when none exists', async () => { + const field = makePropertyField({ + attrs: {options: [{id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}]}, + }); + const linked = makeLinkedField({attrs: {actions: []}}); + + jest.spyOn(Client4, 'getPropertyFields'). + mockResolvedValueOnce([field]). + mockResolvedValueOnce([linked]). + mockResolvedValueOnce([]); // no channel field exists + + const deletedTypes: string[] = []; + jest.spyOn(Client4, 'deletePropertyField').mockImplementation(async (_group, objectType) => { + deletedTypes.push(objectType); + return {status: 'OK'}; + }); + + const origError = console.error; + console.error = (...args: Parameters) => { + if (typeof args[0] === 'string' && args[0].includes('not configured to support act')) { + return; + } + origError(...args); + }; + + try { + renderWithContext(, BASE_STATE); + await screen.findByText('Global Classification Indicators'); + + const user = userEvent.setup(); + + await act(async () => { + await user.click(screen.getByTestId('classificationEnabledfalse')); + }); + await act(async () => { + await user.click(screen.getByText('Save')); + }); + + await waitFor(() => { + expect(deletedTypes).toEqual([CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE]); + }); + } finally { + console.error = origError; + } + }); +}); diff --git a/webapp/channels/src/components/admin_console/classification_markings/classification_markings.tsx b/webapp/channels/src/components/admin_console/classification_markings/classification_markings.tsx index e780ff8f274..50bc328ac6f 100644 --- a/webapp/channels/src/components/admin_console/classification_markings/classification_markings.tsx +++ b/webapp/channels/src/components/admin_console/classification_markings/classification_markings.tsx @@ -37,12 +37,15 @@ import { DEFAULT_GLOBAL_BANNER, DISPLAY_BANNER_TOP, actionsToGlobalBanner, + fetchChannelClassificationField, fetchClassificationField, fetchLinkedClassificationField, fetchSystemClassificationValue, processClassificationField, + saveCreateChannelLinkedField, saveCreateField, saveCreateLinkedField, + saveDeleteChannelLinkedField, saveDeleteField, saveDeleteLinkedField, savePatchField, @@ -216,20 +219,21 @@ export default function ClassificationMarkings({disabled}: Props) { }, []); const presetDropdownOptions = useMemo((): ValueType[] => { - return [ - ...presets.map((p) => ({value: p.id, label: p.label})), - { + const options = presets.map((p) => ({value: p.id, label: p.label})); + if (presetId === PRESET_CUSTOM) { + options.push({ value: PRESET_CUSTOM, label: formatMessage({ id: 'admin.classification_markings.preset.custom', defaultMessage: 'Custom classification levels', }), - }, - ]; - }, [formatMessage]); + }); + } + return options; + }, [formatMessage, presetId]); const presetDropdownValue = useMemo(() => { - return presetDropdownOptions.find((o) => o.value === presetId) ?? presetDropdownOptions[presetDropdownOptions.length - 1]!; + return presetDropdownOptions.find((o) => o.value === presetId) ?? presetDropdownOptions[0]!; }, [presetDropdownOptions, presetId]); const handlePresetDropdownChange = useCallback((selected: ValueType | null) => { @@ -237,10 +241,6 @@ export default function ClassificationMarkings({disabled}: Props) { return; } const newPresetId = selected.value; - if (newPresetId === PRESET_CUSTOM) { - setPresetId(PRESET_CUSTOM); - return; - } if (levels.length > 0) { setConfirmPresetSwitch(newPresetId); return; @@ -370,9 +370,16 @@ export default function ClassificationMarkings({disabled}: Props) { savedLinked = await savePatchLinkedField(savedLinked.id, effectiveBanner); } + // Ensure the channel_classification linked field exists as part of the set. // Push saved fields into Redux eagerly so the banner updates // atomically rather than waiting for out-of-order WS events. - dispatch({type: PropertyTypes.RECEIVED_PROPERTY_FIELDS, data: {fields: [savedTemplate, savedLinked]}}); + const existingChannelField = await fetchChannelClassificationField(); + if (existingChannelField) { + dispatch({type: PropertyTypes.RECEIVED_PROPERTY_FIELDS, data: {fields: [savedTemplate, savedLinked, existingChannelField]}}); + } else { + const savedChannelField = await saveCreateChannelLinkedField(savedTemplate.id); + dispatch({type: PropertyTypes.RECEIVED_PROPERTY_FIELDS, data: {fields: [savedTemplate, savedLinked, savedChannelField]}}); + } setExistingField(savedTemplate); setExistingLinkedField(savedLinked); @@ -383,7 +390,13 @@ export default function ClassificationMarkings({disabled}: Props) { setInitialGlobalBanner(effectiveBanner); setInitialEnabled(true); } else if (templateField) { - // Linked field must be deleted before the template (deletion protection). + // Linked fields must be deleted before the template (deletion protection). + // Order: channel field -> system field -> template. + const channelField = await fetchChannelClassificationField(); + if (channelField) { + await saveDeleteChannelLinkedField(channelField.id); + dispatch({type: PropertyTypes.PROPERTY_FIELD_DELETED, data: {fieldId: channelField.id}}); + } if (linkedField) { await saveDeleteLinkedField(linkedField.id); dispatch({type: PropertyTypes.PROPERTY_FIELD_DELETED, data: {fieldId: linkedField.id}}); diff --git a/webapp/channels/src/components/admin_console/classification_markings/utils/index.ts b/webapp/channels/src/components/admin_console/classification_markings/utils/index.ts index e8e8b7a1004..78631443dbe 100644 --- a/webapp/channels/src/components/admin_console/classification_markings/utils/index.ts +++ b/webapp/channels/src/components/admin_console/classification_markings/utils/index.ts @@ -8,30 +8,48 @@ import {Client4} from 'mattermost-redux/client'; import type {ClassificationLevel} from './presets'; import {PRESET_CUSTOM, presets} from './presets'; -export const GROUP_NAME = 'classification_markings'; +// --------------------------------------------------------------------------- +// Property-field identifiers for the classification-markings feature. +// +// Three logical fields participate: +// 1. Template field — canonical schema (Linked Properties template). The +// admin defines the level options here; per-channel +// and system fields link to it and inherit them. +// 2. System field — linked-to-template; drives the GLOBAL banner. Lives +// on the dedicated 'system' object-type path +// introduced in #36250. +// 3. Channel field — linked-to-template; drives PER-CHANNEL banners. +// +// All three fields are scoped server-side as system fields, so they share the +// same field-level target attributes (`target_type='system'`, `target_id=''`). +// Property *values* for the system field are stored on the dedicated system +// endpoint and use the sentinel target_id 'system'. +// --------------------------------------------------------------------------- -// OBJECT_TYPE is 'template' so the classification field acts as the canonical schema -// (a Linked Properties template). Per-channel fields will link to it and inherit its options. -export const OBJECT_TYPE = 'template'; -export const TARGET_TYPE = 'system'; +// Property-field group identifying all classification-markings entities. +export const CLASSIFICATIONS_GROUP_NAME = 'classification_markings'; -// TARGET_ID is intentionally empty for system-scoped template fields. -export const TARGET_ID = ''; -export const FIELD_NAME = 'classification'; -export const LINKED_FIELD_NAME = 'system_classification'; +// Field-level target attributes shared by template, system, and channel fields. +// `target_type` is always 'system'; `target_id` is empty for system-scoped +// field definitions (the server canonicalizes both). +export const CLASSIFICATIONS_FIELD_TARGET_TYPE = 'system'; +export const CLASSIFICATIONS_FIELD_TARGET_ID = ''; -// The linked field uses the 'system' object type introduced in #36250. -// System fields are canonicalized server-side: target_type='system', target_id=''. -// System values use the sentinel target_id 'system' and dedicated API routes. -export const LINKED_OBJECT_TYPE = 'system'; +// Template field — the canonical schema. +export const CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE = 'template'; +export const CLASSIFICATIONS_TEMPLATE_FIELD_NAME = 'classification'; -// System-scoped fields have target_id '' on the field definition. -export const SYSTEM_FIELD_TARGET_ID = ''; +// System field — drives the global banner. Property *values* live on the +// dedicated system endpoint and use the sentinel target_id 'system'. +export const CLASSIFICATIONS_SYSTEM_OBJECT_TYPE = 'system'; +export const CLASSIFICATIONS_SYSTEM_FIELD_NAME = 'system_classification'; +export const CLASSIFICATIONS_SYSTEM_VALUE_TARGET_ID = 'system'; -// The sentinel target_id used by the server for system-scoped property values. -export const SYSTEM_VALUE_TARGET_ID = 'system'; +// Channel field — drives the per-channel banner. +export const CLASSIFICATIONS_CHANNEL_OBJECT_TYPE = 'channel'; +export const CLASSIFICATIONS_CHANNEL_FIELD_NAME = 'channel_classification'; -// Actions stored on the linked field's attrs.actions to control banner display. +// Actions stored on the linked fields' attrs.actions to control banner placement. export const DISPLAY_BANNER_TOP = 'display_banner_top'; export const DISPLAY_BANNER_BOTTOM = 'display_banner_bottom'; @@ -143,8 +161,14 @@ export async function fetchClassificationField(): Promise f.name === FIELD_NAME && f.delete_at === 0); + const fields = await Client4.getPropertyFields( // eslint-disable-line no-await-in-loop + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + {cursorId, cursorCreateAt}, + ); + const found = fields.find((f: PropertyField) => f.name === CLASSIFICATIONS_TEMPLATE_FIELD_NAME && f.delete_at === 0); if (found || fields.length === 0) { return found; } @@ -160,11 +184,11 @@ export async function fetchClassificationField(): Promise { const options = levelsToOptions(levels); - return Client4.createPropertyField(GROUP_NAME, OBJECT_TYPE, { - name: FIELD_NAME, + return Client4.createPropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, { + name: CLASSIFICATIONS_TEMPLATE_FIELD_NAME, type: 'select' as PropertyField['type'], - target_type: TARGET_TYPE, - target_id: TARGET_ID, + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + target_id: CLASSIFICATIONS_FIELD_TARGET_ID, attrs: {options, managed: 'admin'}, permission_field: 'sysadmin', permission_values: 'sysadmin', @@ -173,17 +197,17 @@ export async function saveCreateField(levels: ClassificationLevel[]): Promise { - await Client4.deletePropertyField(GROUP_NAME, OBJECT_TYPE, fieldId); + await Client4.deletePropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, fieldId); } export async function savePatchField(fieldId: string, levels: ClassificationLevel[]): Promise { const options = levelsToOptions(levels); - return Client4.patchPropertyField(GROUP_NAME, OBJECT_TYPE, fieldId, { + return Client4.patchPropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, fieldId, { attrs: {options}, } as Partial); } -// --- Linked system classification field API --- +// --- System field API (drives the global banner) --- export async function fetchLinkedClassificationField(): Promise { const maxItems = 500; @@ -192,8 +216,14 @@ export async function fetchLinkedClassificationField(): Promise f.name === LINKED_FIELD_NAME && f.delete_at === 0 && f.linked_field_id); + const fields = await Client4.getPropertyFields( // eslint-disable-line no-await-in-loop + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + {cursorId, cursorCreateAt}, + ); + const found = fields.find((f: PropertyField) => f.name === CLASSIFICATIONS_SYSTEM_FIELD_NAME && f.delete_at === 0 && f.linked_field_id); if (found || fields.length === 0) { return found; } @@ -208,11 +238,11 @@ export async function fetchLinkedClassificationField(): Promise { - return Client4.createPropertyField(GROUP_NAME, LINKED_OBJECT_TYPE, { - name: LINKED_FIELD_NAME, + return Client4.createPropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, { + name: CLASSIFICATIONS_SYSTEM_FIELD_NAME, type: 'select' as PropertyField['type'], - target_type: TARGET_TYPE, - target_id: SYSTEM_FIELD_TARGET_ID, + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + target_id: CLASSIFICATIONS_FIELD_TARGET_ID, linked_field_id: templateFieldId, attrs: { actions: placementToActions(config), @@ -221,7 +251,7 @@ export async function saveCreateLinkedField(templateFieldId: string, config: Glo } export async function savePatchLinkedField(linkedFieldId: string, config: GlobalBannerConfig): Promise { - return Client4.patchPropertyField(GROUP_NAME, LINKED_OBJECT_TYPE, linkedFieldId, { + return Client4.patchPropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, linkedFieldId, { attrs: { actions: placementToActions(config), }, @@ -229,7 +259,7 @@ export async function savePatchLinkedField(linkedFieldId: string, config: Global } export async function saveDeleteLinkedField(fieldId: string): Promise { - await Client4.deletePropertyField(GROUP_NAME, LINKED_OBJECT_TYPE, fieldId); + await Client4.deletePropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, fieldId); } // --- System classification property value API --- @@ -239,7 +269,7 @@ export async function saveDeleteLinkedField(fieldId: string): Promise { * Uses the dedicated system values endpoint (no target_id in URL). */ export async function fetchSystemClassificationValue(linkedFieldId: string): Promise { - const values = await Client4.getSystemPropertyValues(GROUP_NAME); + const values = await Client4.getSystemPropertyValues(CLASSIFICATIONS_GROUP_NAME); const match = ((values as Array>) ?? []).find((v) => v.field_id === linkedFieldId); return match?.value; } @@ -250,7 +280,51 @@ export async function fetchSystemClassificationValue(linkedFieldId: string): Pro * Returns the saved property values so callers can eagerly update the store. */ export async function saveUpsertSystemValue(linkedFieldId: string, optionId: string): Promise>> { - return Client4.patchSystemPropertyValues(GROUP_NAME, [ + return Client4.patchSystemPropertyValues(CLASSIFICATIONS_GROUP_NAME, [ {field_id: linkedFieldId, value: optionId}, ]); } + +// --- Channel field API (drives per-channel banners) --- + +export async function fetchChannelClassificationField(): Promise { + const maxItems = 500; + let fetched = 0; + let cursorId: string | undefined; + let cursorCreateAt: number | undefined; + + while (fetched < maxItems) { + const fields = await Client4.getPropertyFields( // eslint-disable-line no-await-in-loop + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + {cursorId, cursorCreateAt}, + ); + const found = fields.find((f: PropertyField) => f.name === CLASSIFICATIONS_CHANNEL_FIELD_NAME && f.delete_at === 0 && f.linked_field_id); + if (found || fields.length === 0) { + return found; + } + + fetched += fields.length; + const last = fields[fields.length - 1]; + cursorId = last.id; + cursorCreateAt = last.create_at; + } + + return undefined; +} + +export async function saveCreateChannelLinkedField(templateFieldId: string): Promise { + return Client4.createPropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, { + name: CLASSIFICATIONS_CHANNEL_FIELD_NAME, + type: 'select' as PropertyField['type'], + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + target_id: CLASSIFICATIONS_FIELD_TARGET_ID, + linked_field_id: templateFieldId, + }); +} + +export async function saveDeleteChannelLinkedField(fieldId: string): Promise { + await Client4.deletePropertyField(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, fieldId); +} diff --git a/webapp/channels/src/components/admin_console/classification_markings/utils/preset_dropdown_styles.ts b/webapp/channels/src/components/admin_console/classification_markings/utils/preset_dropdown_styles.ts index 18c9ccb4c52..e8395f69f94 100644 --- a/webapp/channels/src/components/admin_console/classification_markings/utils/preset_dropdown_styles.ts +++ b/webapp/channels/src/components/admin_console/classification_markings/utils/preset_dropdown_styles.ts @@ -66,6 +66,6 @@ export const classificationPresetDropdownStyles: StylesConfig = { }), menuPortal: (provided) => ({ ...provided, - zIndex: 200, + zIndex: 1100, }), }; diff --git a/webapp/channels/src/components/channel_banner/channel_banner.tsx b/webapp/channels/src/components/channel_banner/channel_banner.tsx index cc0e3449a8d..e826cd7d0ff 100644 --- a/webapp/channels/src/components/channel_banner/channel_banner.tsx +++ b/webapp/channels/src/components/channel_banner/channel_banner.tsx @@ -6,12 +6,14 @@ import {useIntl} from 'react-intl'; import {useSelector} from 'react-redux'; import {WithTooltip} from '@mattermost/shared/components/tooltip'; +import type {ChannelBanner} from '@mattermost/types/channels'; import {selectShowChannelBanner} from 'mattermost-redux/selectors/entities/channel_banner'; import {getChannelBanner} from 'mattermost-redux/selectors/entities/channels'; import {getLicense} from 'mattermost-redux/selectors/entities/general'; import {getContrastingSimpleColor} from 'mattermost-redux/utils/theme_utils'; +import useChannelClassificationBanner from 'components/common/hooks/useChannelClassificationBanner'; import Markdown from 'components/markdown'; import {isMinimumEnterpriseAdvancedLicense} from 'utils/license_utils'; @@ -35,7 +37,17 @@ export default function ChannelBanner({channelId}: Props) { const license = useSelector(getLicense); const licenseEnabled = isMinimumEnterpriseAdvancedLicense(license); const channelBannerConfigured = useSelector((state: GlobalState) => selectShowChannelBanner(state, channelId)); - const showChannelBanner = licenseEnabled && channelBannerConfigured; + const showNativeBanner = licenseEnabled && channelBannerConfigured; + + const classificationBanner = useChannelClassificationBanner(channelId); + + // Classification property value takes priority over native banner_info + const effectiveBanner: ChannelBanner | undefined = classificationBanner.hasClassification ? + classificationBanner.classificationBanner : + channelBannerInfo; + + const showBanner = classificationBanner.hasClassification || showNativeBanner; + const textContainerRef = useRef(null); const [tooltipNeeded, setTooltipNeeded] = React.useState(false); @@ -48,31 +60,30 @@ export default function ChannelBanner({channelId}: Props) { const isOverflowingVertically = textContainerRef.current.offsetHeight < textContainerRef.current.scrollHeight; setTooltipNeeded(isOverflowingHorizontally || isOverflowingVertically); - }, [channelBannerInfo?.text]); + }, [effectiveBanner?.text]); const intl = useIntl(); const channelBannerTextAriaLabel = intl.formatMessage({id: 'channel_banner.aria_label', defaultMessage: 'Channel banner text'}); const content = ( ); const channelBannerStyle = useMemo(() => { return { - backgroundColor: channelBannerInfo?.background_color, + backgroundColor: effectiveBanner?.background_color, }; - }, [channelBannerInfo]); + }, [effectiveBanner]); const channelBannerTextStyle = useMemo(() => { - // this is just to satisfy type checks. - if (!channelBannerInfo || !channelBannerInfo.background_color) { + if (!effectiveBanner || !effectiveBanner.background_color) { return {}; } - const color = getContrastingSimpleColor(channelBannerInfo.background_color); + const color = getContrastingSimpleColor(effectiveBanner.background_color); // The CSS variable is declared here, and is being used in the stylesheet being imported in this component. // This is needed because if the user sets background color a share of blue similar to the default link color, @@ -82,9 +93,9 @@ export default function ChannelBanner({channelId}: Props) { color, '--channel-banner-text-color': color, }; - }, [channelBannerInfo]); + }, [effectiveBanner]); - if (!channelBannerInfo || !showChannelBanner) { + if (!effectiveBanner || !showBanner) { return null; } diff --git a/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.scss b/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.scss index b88c0b901cb..d01aa15cc76 100644 --- a/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.scss +++ b/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.scss @@ -8,6 +8,10 @@ flex-direction: column; gap: 32px; + &--with-save-panel { + padding-bottom: 80px; + } + &__configurationDivider { border: none; border-top: 1px solid rgba(var(--center-channel-color-rgb), 0.08); @@ -62,7 +66,8 @@ right: 5px; } - #channel_banner_banner_text_textbox { + #channel_banner_banner_text_textbox, + #channel_classification_banner_text_textbox { min-height: 40px; max-height: 200px; } @@ -105,5 +110,39 @@ .AdvancedTextbox { margin-top: 0 !important; } + + .DropdownInput.Input_container { + margin-top: 0; + + .Input_fieldset { + padding: 0; + border: none; + box-shadow: none; + + &:hover, + &:focus-within { + border: none; + box-shadow: none; + } + } + + .Input_wrapper { + padding: 0; + margin: 0; + } + + .DropDown__control { + height: 40px; + min-height: 40px; + } + + .DropDown__value-container { + height: 38px; + } + + .DropDown__indicators { + height: 38px; + } + } } } diff --git a/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.test.tsx b/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.test.tsx index 8eddbd5ef36..59172664479 100644 --- a/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.test.tsx +++ b/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.test.tsx @@ -2,6 +2,12 @@ // See LICENSE.txt for license information. import React from 'react'; +import type {MockStoreEnhanced} from 'redux-mock-store'; + +import {PropertyTypes} from 'mattermost-redux/action_types'; + +import useChannelClassificationBanner from 'components/common/hooks/useChannelClassificationBanner'; +import useClassificationMarkings from 'components/common/hooks/useClassificationMarkings'; import {renderWithContext, screen, userEvent, waitFor} from 'tests/react_testing_utils'; import {TestHelper} from 'utils/test_helper'; @@ -25,6 +31,8 @@ jest.mock('mattermost-redux/client', () => ({ {remote_id: 'remote1', name: 'nebula', display_name: 'Nebula Networks'}, {remote_id: 'remote2', name: 'cascade', display_name: 'Cascade Collaborative'}, ]), + getPropertyValues: jest.fn().mockResolvedValue([]), + patchPropertyValues: jest.fn().mockResolvedValue([]), }, })); @@ -36,6 +44,29 @@ jest.mock('mattermost-redux/selectors/entities/shared_channels', () => { }; }); +jest.mock('components/common/hooks/useChannelClassificationBanner'); +jest.mock('components/common/hooks/useClassificationMarkings'); + +const mockedUseClassificationMarkings = useClassificationMarkings as jest.MockedFunction; +const mockedUseChannelClassificationBanner = useChannelClassificationBanner as jest.MockedFunction; + +// Default classification state: feature unavailable. Individual tests can override. +beforeEach(() => { + mockedUseClassificationMarkings.mockReturnValue({ + available: false, + loading: false, + templateField: null, + channelField: null, + levels: [], + }); + mockedUseChannelClassificationBanner.mockReturnValue({ + hasClassification: false, + classificationBanner: undefined, + classificationId: undefined, + bannerText: undefined, + }); +}); + // Mock the ShowFormat component to make it easier to test jest.mock('components/advanced_text_editor/show_formatting/show_formatting', () => ( jest.fn().mockImplementation((props) => ( @@ -711,4 +742,338 @@ describe('ChannelSettingsConfigurationTab', () => { }); }); }); + + describe('Classification', () => { + const SYSADMIN_USER_ID = 'sysadmin_user_1'; + const sysAdminState = { + entities: { + users: { + currentUserId: SYSADMIN_USER_ID, + profiles: { + [SYSADMIN_USER_ID]: {id: SYSADMIN_USER_ID, roles: 'system_admin system_user'}, + }, + }, + }, + }; + + const TEMPLATE_FIELD_ID = 'template_field_1'; + const CHANNEL_FIELD_ID = 'channel_field_1'; + const LEVEL_UNCLASSIFIED = {id: 'lvl_unclass', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}; + const LEVEL_SECRET = {id: 'lvl_secret', name: 'SECRET', color: '#C8102E', rank: 2}; + + const templateField = { + id: TEMPLATE_FIELD_ID, + group_id: 'classification_markings', + name: 'classification', + type: 'select' as const, + attrs: {options: [LEVEL_UNCLASSIFIED, LEVEL_SECRET]}, + target_id: '', + target_type: 'system', + object_type: 'template', + create_at: 1, + update_at: 1, + delete_at: 0, + created_by: 'u1', + updated_by: 'u1', + }; + + const channelField = { + ...templateField, + id: CHANNEL_FIELD_ID, + name: 'channel_classification', + object_type: 'channel', + linked_field_id: TEMPLATE_FIELD_ID, + attrs: {}, + }; + + function enableClassification(initialBanner: {hasClassification: boolean; classificationId?: string; bannerText?: string} = {hasClassification: false}) { + mockedUseClassificationMarkings.mockReturnValue({ + available: true, + loading: false, + templateField, + channelField, + levels: [LEVEL_UNCLASSIFIED, LEVEL_SECRET], + }); + mockedUseChannelClassificationBanner.mockReturnValue({ + hasClassification: initialBanner.hasClassification, + classificationBanner: initialBanner.hasClassification ? { + enabled: true, + text: initialBanner.bannerText || '', + background_color: '#007A33', + } : undefined, + classificationId: initialBanner.classificationId, + bannerText: initialBanner.bannerText, + }); + } + + it('renders the Classification section when feature is available', () => { + enableClassification(); + renderWithContext( + , + sysAdminState, + ); + + expect(screen.getByText('Classification')).toBeInTheDocument(); + expect(screen.getByTestId('channelClassificationToggle-button')).toBeInTheDocument(); + }); + + it('does not render the Classification section when feature is unavailable', () => { + renderWithContext(, sysAdminState); + + expect(screen.queryByText('Classification')).not.toBeInTheDocument(); + }); + + it('does not render the Classification section for non-sysadmin users', () => { + enableClassification(); + renderWithContext( + , + ); + + expect(screen.queryByText('Classification')).not.toBeInTheDocument(); + }); + + it('auto-selects the lowest-rank level when classification is toggled on', async () => { + const {Client4} = require('mattermost-redux/client'); + const {patchChannel} = require('mattermost-redux/actions/channels'); + patchChannel.mockReturnValue({type: 'MOCK_ACTION', data: {}}); + Client4.patchPropertyValues.mockClear(); + enableClassification(); + + renderWithContext( + , + sysAdminState, + ); + + await userEvent.click(screen.getByTestId('channelClassificationToggle-button')); + + // The lowest-rank level (UNCLASSIFIED) should be auto-selected in the dropdown. + const dropdown = screen.getByTestId('channelClassificationLevel'); + expect(dropdown).toHaveTextContent(LEVEL_UNCLASSIFIED.name); + + // Save button should be enabled since a level is pre-selected. + const saveButton = await screen.findByRole('button', {name: 'Save'}); + expect(saveButton).toBeEnabled(); + }); + + it('saves banner_info via patchChannel when banner text is edited while classification is active', async () => { + const {patchChannel} = require('mattermost-redux/actions/channels'); + patchChannel.mockReturnValue({type: 'MOCK_ACTION', data: {}}); + + enableClassification({ + hasClassification: true, + classificationId: LEVEL_UNCLASSIFIED.id, + bannerText: `**${LEVEL_UNCLASSIFIED.name}**`, + }); + + renderWithContext( + , + sysAdminState, + ); + + const textInput = await screen.findByTestId('channel_banner_banner_text_textbox'); + await userEvent.clear(textInput); + await userEvent.type(textInput, 'Updated text'); + + const saveButton = await screen.findByRole('button', {name: 'Save'}); + await userEvent.click(saveButton); + + await waitFor(() => { + expect(patchChannel).toHaveBeenCalledWith( + 'channel1', + expect.objectContaining({ + banner_info: expect.objectContaining({ + enabled: true, + text: 'Updated text', + }), + }), + ); + }); + }); + + it('does not call patchPropertyValues when classification enabled/id has not changed', async () => { + const {Client4} = require('mattermost-redux/client'); + const {patchChannel} = require('mattermost-redux/actions/channels'); + patchChannel.mockReturnValue({type: 'MOCK_ACTION', data: {}}); + + enableClassification({ + hasClassification: true, + classificationId: LEVEL_UNCLASSIFIED.id, + bannerText: `**${LEVEL_UNCLASSIFIED.name}**`, + }); + + renderWithContext( + , + sysAdminState, + ); + + // Edit only the banner text without changing the classification toggle or level. + const textInput = await screen.findByTestId('channel_banner_banner_text_textbox'); + await userEvent.clear(textInput); + await userEvent.type(textInput, 'Edited banner'); + + const saveButton = await screen.findByRole('button', {name: 'Save'}); + await userEvent.click(saveButton); + + // patchChannel should be called (banner text changed), but patchPropertyValues + // should NOT be called because classification enabled/id are unchanged. + await waitFor(() => { + expect(patchChannel).toHaveBeenCalled(); + }); + + expect(Client4.patchPropertyValues).not.toHaveBeenCalled(); + }); + + it('removes classification by patching value to null and dispatching PROPERTY_VALUE_DELETED', async () => { + const {Client4} = require('mattermost-redux/client'); + Client4.patchPropertyValues.mockResolvedValueOnce([]); + enableClassification({ + hasClassification: true, + classificationId: LEVEL_UNCLASSIFIED.id, + bannerText: `**${LEVEL_UNCLASSIFIED.name}**`, + }); + + const {store} = renderWithContext( + , + sysAdminState, + {useMockedStore: true}, + ); + + // Toggle classification off (it starts on because of `hasClassification: true`). + await userEvent.click(screen.getByTestId('channelClassificationToggle-button')); + + const saveButton = await screen.findByRole('button', {name: 'Save'}); + await userEvent.click(saveButton); + + await waitFor(() => { + expect(Client4.patchPropertyValues).toHaveBeenCalledWith( + 'classification_markings', + 'channel', + 'channel1', + [{field_id: CHANNEL_FIELD_ID, value: null}], + ); + }); + + await waitFor(() => { + const actions = (store as unknown as MockStoreEnhanced).getActions(); + expect(actions.some((a) => a.type === PropertyTypes.PROPERTY_VALUE_DELETED)).toBe(true); + }); + }); + + it('resets classification form to initial state when Reset is clicked', async () => { + enableClassification({ + hasClassification: true, + classificationId: LEVEL_UNCLASSIFIED.id, + bannerText: `**${LEVEL_UNCLASSIFIED.name}**`, + }); + + renderWithContext( + , + sysAdminState, + ); + + // Toggle off → triggers changes → Save panel appears with Reset. + const toggle = screen.getByTestId('channelClassificationToggle-button'); + await userEvent.click(toggle); + + const resetButton = await screen.findByRole('button', {name: 'Reset'}); + await userEvent.click(resetButton); + + // After reset, the Save/Reset panel should be gone and the toggle re-enabled. + await waitFor(() => { + expect(screen.queryByRole('button', {name: 'Reset'})).not.toBeInTheDocument(); + }); + expect(toggle).toHaveClass('active'); + }); + + it('shows an error in the SaveChangesPanel when patchPropertyValues rejects', async () => { + const {Client4} = require('mattermost-redux/client'); + const {patchChannel} = require('mattermost-redux/actions/channels'); + patchChannel.mockReturnValue({type: 'MOCK_ACTION', data: {}}); + Client4.patchPropertyValues.mockRejectedValueOnce({message: 'Server boom'}); + + // Start without classification, so toggling it on creates a classification change. + enableClassification({hasClassification: false}); + + renderWithContext( + , + sysAdminState, + ); + + // Enable classification toggle — lowest-rank level is auto-selected. + await userEvent.click(screen.getByTestId('channelClassificationToggle-button')); + + // Save button should be enabled (level auto-selected). + const saveButton = await screen.findByRole('button', {name: 'Save'}); + expect(saveButton).toBeEnabled(); + + // Click save to trigger the patchPropertyValues rejection. + await userEvent.click(saveButton); + + await waitFor(() => { + expect(screen.getByText(/Server boom/)).toBeInTheDocument(); + }); + }); + + it('shows an error when patchPropertyValues rejects with pre-existing classification', async () => { + const {Client4} = require('mattermost-redux/client'); + const {patchChannel} = require('mattermost-redux/actions/channels'); + patchChannel.mockReturnValue({type: 'MOCK_ACTION', data: {}}); + Client4.patchPropertyValues.mockRejectedValueOnce({message: 'Server boom'}); + + // Start classified → toggle off → toggle back on triggers hasClassificationChanges. + enableClassification({ + hasClassification: true, + classificationId: LEVEL_UNCLASSIFIED.id, + bannerText: `**${LEVEL_UNCLASSIFIED.name}**`, + }); + + renderWithContext( + , + sysAdminState, + ); + + // Toggle off then on to create a classification state change. + const toggle = screen.getByTestId('channelClassificationToggle-button'); + await userEvent.click(toggle); + await userEvent.click(toggle); + + // Now toggle off again — this creates a "disable" change that calls patchPropertyValues(null). + await userEvent.click(toggle); + + const saveButton = await screen.findByRole('button', {name: 'Save'}); + await userEvent.click(saveButton); + + await waitFor(() => { + const errorPanel = screen.getByText(/Server boom/).closest('.SaveChangesPanel'); + expect(errorPanel).toHaveClass('error'); + }); + }); + }); }); diff --git a/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.tsx b/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.tsx index 8dc703ea236..9e0d9b70176 100644 --- a/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.tsx +++ b/webapp/channels/src/components/channel_settings_modal/channel_settings_configuration_tab.tsx @@ -2,21 +2,34 @@ // See LICENSE.txt for license information. import React, {useCallback, useEffect, useMemo, useRef, useState} from 'react'; -import {useIntl} from 'react-intl'; +import {FormattedMessage, useIntl} from 'react-intl'; import {useDispatch, useSelector} from 'react-redux'; import type {Channel} from '@mattermost/types/channels'; import type {ServerError} from '@mattermost/types/errors'; +import {PropertyTypes} from 'mattermost-redux/action_types'; import {patchChannel} from 'mattermost-redux/actions/channels'; import {fetchChannelRemotes} from 'mattermost-redux/actions/shared_channels'; import {Client4} from 'mattermost-redux/client'; import {isChannelAutotranslated as isChannelAutotranslatedSelector} from 'mattermost-redux/selectors/entities/channels'; import {getRemotesForChannel} from 'mattermost-redux/selectors/entities/shared_channels'; +import {isCurrentUserSystemAdmin} from 'mattermost-redux/selectors/entities/users'; +import {ColorSwatch, LevelOptionLabel} from 'components/admin_console/classification_markings/classification_markings_styled'; +import { + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, +} from 'components/admin_console/classification_markings/utils'; +import {classificationPresetDropdownStyles} from 'components/admin_console/classification_markings/utils/preset_dropdown_styles'; import ColorInput from 'components/color_input'; +import useChannelClassificationBanner from 'components/common/hooks/useChannelClassificationBanner'; +import useClassificationMarkings from 'components/common/hooks/useClassificationMarkings'; import useDidUpdate from 'components/common/hooks/useDidUpdate'; import ConfirmModal from 'components/confirm_modal'; +import DropdownInput from 'components/dropdown_input'; +import type {ValueType} from 'components/dropdown_input'; +import SectionNotice from 'components/section_notice'; import type {TextboxElement} from 'components/textbox'; import Toggle from 'components/toggle'; import AdvancedTextbox from 'components/widgets/advanced_textbox/advanced_textbox'; @@ -30,8 +43,8 @@ import type {WorkspaceWithStatus} from './share_channel_with_workspaces/types'; import './channel_settings_configuration_tab.scss'; -const CHANNEL_BANNER_MAX_CHARACTER_LIMIT = 1024; -const CHANNEL_BANNER_MIN_CHARACTER_LIMIT = 0; +export const CHANNEL_BANNER_MAX_CHARACTER_LIMIT = 1024; +export const CHANNEL_BANNER_MIN_CHARACTER_LIMIT = 0; const DEFAULT_CHANNEL_BANNER = { enabled: false, @@ -88,6 +101,96 @@ function ChannelSettingsConfigurationTab({ const [characterLimitExceeded, setCharacterLimitExceeded] = useState(false); const hasBannerChanges = bannerHasChanges(initialBannerInfo, updatedChannelBanner); + const classificationBanner = useChannelClassificationBanner(channel.id); + + const classification = useClassificationMarkings(); + const isSystemAdmin = useSelector(isCurrentUserSystemAdmin); + const canManageClassification = classification.available && isSystemAdmin; + const [classificationEnabled, setClassificationEnabled] = useState(classificationBanner.hasClassification); + const [selectedClassificationId, setSelectedClassificationId] = useState(classificationBanner.classificationId || ''); + + const bannerLockedByClassification = classificationEnabled && Boolean(selectedClassificationId); + + useEffect(() => { + setClassificationEnabled(classificationBanner.hasClassification); + setSelectedClassificationId(classificationBanner.classificationId || ''); + + if (classificationBanner.hasClassification && classificationBanner.classificationBanner) { + setUpdatedChannelBanner((prev) => ({ + ...prev, + enabled: true, + text: classificationBanner.classificationBanner?.text ?? prev.text, + background_color: classificationBanner.classificationBanner?.background_color || prev.background_color || DEFAULT_CHANNEL_BANNER.background_color, + })); + } + }, [classificationBanner.hasClassification, classificationBanner.classificationId, classificationBanner.classificationBanner]); + + const classificationOptions = useMemo(() => { + return classification.levels. + filter((l) => l.name.trim() !== ''). + map((l) => ({value: l.id, label: l.name.trim(), color: l.color})); + }, [classification.levels]); + + const selectedClassificationOption = useMemo(() => { + return classificationOptions.find((o) => o.value === selectedClassificationId); + }, [classificationOptions, selectedClassificationId]); + + const formatClassificationOptionLabel = useCallback((option: ValueType) => { + const levelOption = option as ValueType & {color: string}; + return ( + + + {levelOption.label} + + ); + }, []); + + const selectedClassificationColor = useMemo((): string => { + const level = classification.levels.find((l) => l.id === selectedClassificationId); + return level?.color || ''; + }, [classification.levels, selectedClassificationId]); + + const initialClassificationState = useMemo(() => ({ + enabled: classificationBanner.hasClassification, + classificationId: classificationBanner.classificationId || '', + }), [classificationBanner.hasClassification, classificationBanner.classificationId]); + + const hasClassificationChanges = classificationEnabled !== initialClassificationState.enabled || + selectedClassificationId !== initialClassificationState.classificationId; + + const handleClassificationToggle = useCallback(() => { + setClassificationEnabled((prev) => { + if (!prev) { + const lowestRank = classification.levels[0]; + if (lowestRank) { + setSelectedClassificationId(lowestRank.id); + setUpdatedChannelBanner((banner) => ({ + ...banner, + enabled: true, + text: `**${lowestRank.name}**`, + background_color: lowestRank.color, + })); + } else { + setUpdatedChannelBanner((banner) => ({...banner, enabled: true})); + } + } + return !prev; + }); + }, [classification.levels]); + + const handleClassificationLevelChange = useCallback((selected: ValueType) => { + setSelectedClassificationId(selected.value); + const level = classification.levels.find((l) => l.id === selected.value); + if (level) { + setUpdatedChannelBanner((prev) => ({ + ...prev, + enabled: true, + text: `**${level.name}**`, + background_color: level.color, + })); + } + }, [classification.levels]); + const handleBannerToggle = useCallback(() => { const newValue = !updatedChannelBanner.enabled; const toUpdate = { @@ -250,6 +353,7 @@ function ChannelSettingsConfigurationTab({ // Common const hasUnsavedChanges = hasBannerChanges || hasAutoTranslationChanges || + hasClassificationChanges || (canManageSharedChannels && hasWorkspaceChanges); useEffect(() => { @@ -297,7 +401,15 @@ function ChannelSettingsConfigurationTab({ updated.autotranslation = isChannelAutotranslated; } - if (hasAutoTranslationChanges || hasBannerChanges) { + if (hasClassificationChanges && classificationEnabled && selectedClassificationId) { + updated.banner_info = { + text: updatedChannelBanner.text?.trim() || '', + background_color: updatedChannelBanner.background_color?.trim() || '', + enabled: true, + }; + } + + if (hasAutoTranslationChanges || hasBannerChanges || (hasClassificationChanges && classificationEnabled && selectedClassificationId)) { const {error} = await dispatch(patchChannel(channel.id, updated)); if (error) { handleServerError(error as ServerError); @@ -305,6 +417,36 @@ function ChannelSettingsConfigurationTab({ } } + if (hasClassificationChanges && classification.channelField) { + if (classificationEnabled && selectedClassificationId) { + try { + const values = await Client4.patchPropertyValues( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + channel.id, + [{field_id: classification.channelField.id, value: selectedClassificationId}], + ); + dispatch({type: PropertyTypes.RECEIVED_PROPERTY_VALUES, data: {values}}); + } catch (err) { + handleServerError(err as ServerError); + return false; + } + } else if (!classificationEnabled && initialClassificationState.enabled) { + try { + await Client4.patchPropertyValues( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + channel.id, + [{field_id: classification.channelField.id, value: null}], + ); + dispatch({type: PropertyTypes.PROPERTY_VALUE_DELETED, data: {targetId: channel.id, fieldId: classification.channelField.id}}); + } catch (err) { + handleServerError(err as ServerError); + return false; + } + } + } + if (canManageSharedChannels && hasWorkspaceChanges) { const initialIds = new Set((initialRemotes || []).map((r) => r.remote_id || r.name)); const currentIds = new Set(workspaceRemotes.map((r) => r.remote_id || r.name)); @@ -352,16 +494,21 @@ function ChannelSettingsConfigurationTab({ }, [ canManageSharedChannels, channel, + classification.channelField, + classificationEnabled, dispatch, formatMessage, handleServerError, hasAutoTranslationChanges, hasBannerChanges, + hasClassificationChanges, hasWorkspaceChanges, initialBannerInfo, + initialClassificationState.enabled, initialIsChannelAutotranslated, initialRemotes, isChannelAutotranslated, + selectedClassificationId, updatedChannelBanner, workspaceRemotes, ]); @@ -414,6 +561,10 @@ function ChannelSettingsConfigurationTab({ setFormError(''); setSaveChangesPanelState(undefined); setCharacterLimitExceeded(false); + + setClassificationEnabled(initialClassificationState.enabled); + setSelectedClassificationId(initialClassificationState.classificationId); + if (canManageSharedChannels) { setSharingEnabled(initialSharingEnabled.current); if (initialRemotes) { @@ -421,19 +572,21 @@ function ChannelSettingsConfigurationTab({ setShareChannelKey(Date.now()); } } - }, [canManageSharedChannels, initialBannerInfo, initialRemotes]); + }, [canManageSharedChannels, initialBannerInfo, initialClassificationState, initialRemotes]); const handleClose = useCallback(() => { setSaveChangesPanelState(undefined); setRequireConfirm(false); }, []); + const classificationFormInvalid = classificationEnabled && !selectedClassificationId; const hasErrors = Boolean(formError) || characterLimitExceeded || + classificationFormInvalid || showTabSwitchError; return ( -
+
{canManageSharedChannels && ( <> )} - {canManageSharedChannels && canManageBanner && ( + {canManageSharedChannels && (canManageClassification || canManageBanner) && ( +
+ )} + + {canManageClassification && ( + <> +
+
+ + + + + + +
+ +
+ +
+
+ + {classificationEnabled && ( +
+
+ + } + text={formatMessage({id: 'admin.classification_markings.notice.body', defaultMessage: 'Markings are not tied to access control decisions at this time and are for display purposes only.'})} + /> +
+ +
+ + + +
+ +
+
+
+ )} + + )} + + {canManageClassification && canManageBanner && (
)} @@ -491,9 +723,9 @@ function ChannelSettingsConfigurationTab({ id='channelBannerToggle' ariaLabel={bannerHeading} size='btn-md' - disabled={false} + disabled={bannerLockedByClassification} onToggle={handleBannerToggle} - toggled={updatedChannelBanner.enabled} + toggled={bannerLockedByClassification || updatedChannelBanner.enabled} tabIndex={0} toggleClassName='btn-toggle-primary' /> @@ -501,7 +733,7 @@ function ChannelSettingsConfigurationTab({
{ - updatedChannelBanner.enabled && + (bannerLockedByClassification || updatedChannelBanner.enabled) &&
{/*Banner text section*/}
@@ -544,7 +776,8 @@ function ChannelSettingsConfigurationTab({
diff --git a/webapp/channels/src/components/channel_settings_modal/channel_settings_modal.scss b/webapp/channels/src/components/channel_settings_modal/channel_settings_modal.scss index 4cb085aca46..9a6b22e61bc 100644 --- a/webapp/channels/src/components/channel_settings_modal/channel_settings_modal.scss +++ b/webapp/channels/src/components/channel_settings_modal/channel_settings_modal.scss @@ -13,7 +13,7 @@ line-height: 20px; } - label.Input_subheading { + .Input_subheading { color: rgba(var(--center-channel-color-rgb), 0.64); font-family: Open Sans, sans-serif; font-size: 12px; @@ -30,10 +30,15 @@ border-radius: var(--radius-l); box-shadow: var(--elevation-6); + .modal-header { + flex-shrink: 0; + } + .modal-body { display: flex; width: auto; - min-height: 150px; + min-height: 0; + flex: 1 1 auto; flex-direction: column; margin: 0; gap: 24px; diff --git a/webapp/channels/src/components/color_input.tsx b/webapp/channels/src/components/color_input.tsx index c8764b955d8..a02d9250b23 100644 --- a/webapp/channels/src/components/color_input.tsx +++ b/webapp/channels/src/components/color_input.tsx @@ -122,23 +122,21 @@ const ColorInput = ({ disabled={isDisabled} data-testid='color-inputColorValue' /> - {!isDisabled && - - - - } + + + {isOpened && (
({ + __esModule: true, + ...jest.requireActual('react-redux'), +})); + +const CHANNEL_ID = 'channel_id_1'; +const FIELD_ID = 'channel_field_1'; + +function makeChannelField(overrides: Partial = {}): PropertyField { + return { + id: FIELD_ID, + group_id: CLASSIFICATIONS_GROUP_NAME, + name: CLASSIFICATIONS_CHANNEL_FIELD_NAME, + type: 'select', + attrs: {}, + target_id: '', + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + object_type: CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + linked_field_id: 'template1', + create_at: 1000, + update_at: 1000, + delete_at: 0, + created_by: 'user1', + updated_by: 'user1', + ...overrides, + }; +} + +function makePropertyValue(value: string | null): PropertyValue { + return { + id: 'value1', + target_id: CHANNEL_ID, + target_type: CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + group_id: CLASSIFICATIONS_GROUP_NAME, + field_id: FIELD_ID, + value: value as string, + create_at: 2000, + update_at: 2000, + delete_at: 0, + created_by: 'user1', + updated_by: 'user1', + }; +} + +const SAMPLE_LEVELS: ClassificationLevel[] = [ + {id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}, + {id: 'lvl2', name: 'SECRET', color: '#C8102E', rank: 2}, +]; + +type PartialState = Parameters[1]; + +function stateWithValue( + value: PropertyValue | undefined, + bannerInfo?: {enabled?: boolean; text?: string; background_color?: string}, +): PartialState { + return { + entities: { + channels: { + channels: { + [CHANNEL_ID]: { + id: CHANNEL_ID, + banner_info: bannerInfo, + }, + }, + }, + properties: { + values: { + byTargetId: value ? {[CHANNEL_ID]: {[FIELD_ID]: value}} : {}, + }, + }, + }, + } as PartialState; +} + +function mockClassification(overrides: Partial = {}) { + return jest.spyOn(ClassificationHook, 'default').mockReturnValue({ + available: true, + loading: false, + templateField: null, + channelField: makeChannelField(), + levels: SAMPLE_LEVELS, + ...overrides, + }); +} + +describe('useChannelClassificationBanner', () => { + const dispatchMock = jest.fn(); + + beforeAll(() => { + jest.spyOn(ReactRedux, 'useDispatch').mockImplementation(() => dispatchMock); + }); + + afterAll(() => { + jest.restoreAllMocks(); + }); + + beforeEach(() => { + dispatchMock.mockClear(); + jest.spyOn(Client4, 'getPropertyValues').mockResolvedValue([]); + }); + + afterEach(() => { + jest.restoreAllMocks(); + jest.spyOn(ReactRedux, 'useDispatch').mockImplementation(() => dispatchMock); + }); + + test('returns hasClassification=false when no property value exists for the channel', () => { + mockClassification(); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(undefined), + ); + + expect(result.current.hasClassification).toBe(false); + expect(result.current.classificationBanner).toBeUndefined(); + expect(result.current.classificationId).toBeUndefined(); + expect(result.current.bannerText).toBeUndefined(); + }); + + test('returns hasClassification=false when property value contains null value', () => { + mockClassification(); + const value = makePropertyValue(null); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(value), + ); + + expect(result.current.hasClassification).toBe(false); + expect(result.current.classificationBanner).toBeUndefined(); + }); + + test('maps a valid string classification_id to the matching level banner shape with text from banner_info', () => { + mockClassification(); + const value = makePropertyValue('lvl2'); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(value, {enabled: true, text: '**SECRET**', background_color: '#C8102E'}), + ); + + expect(result.current.hasClassification).toBe(true); + expect(result.current.classificationId).toBe('lvl2'); + expect(result.current.bannerText).toBe('**SECRET**'); + expect(result.current.classificationBanner).toEqual({ + enabled: true, + text: '**SECRET**', + background_color: '#C8102E', + }); + }); + + test('falls back to level name when banner_info.text is missing but classification is set', () => { + mockClassification(); + const value = makePropertyValue('lvl1'); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(value), + ); + + expect(result.current.hasClassification).toBe(true); + expect(result.current.classificationId).toBe('lvl1'); + expect(result.current.bannerText).toBe('**UNCLASSIFIED**'); + expect(result.current.classificationBanner).toEqual({ + enabled: true, + text: '**UNCLASSIFIED**', + background_color: '#007A33', + }); + }); + + test('returns hasClassification=false when the referenced level no longer exists', () => { + mockClassification(); + const value = makePropertyValue('deleted_lvl'); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(value), + ); + + expect(result.current.hasClassification).toBe(false); + expect(result.current.classificationBanner).toBeUndefined(); + }); + + test('returns hasClassification=false for legacy object-shaped property values', () => { + mockClassification(); + + // Simulate a pre-migration object-shaped value that should be treated as invalid + const legacyValue = { + id: 'value1', + target_id: CHANNEL_ID, + target_type: CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + group_id: CLASSIFICATIONS_GROUP_NAME, + field_id: FIELD_ID, + value: {classification_id: 'lvl1', banner_text: 'test'} as unknown as string, + create_at: 2000, + update_at: 2000, + delete_at: 0, + created_by: 'user1', + updated_by: 'user1', + }; + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(legacyValue as PropertyValue), + ); + + expect(result.current.hasClassification).toBe(false); + expect(result.current.classificationBanner).toBeUndefined(); + }); + + test('returns empty state and skips fetching when channelField is missing', () => { + mockClassification({channelField: null, available: false}); + + const fetchSpy = jest.spyOn(Client4, 'getPropertyValues'); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(undefined), + ); + + expect(result.current.hasClassification).toBe(false); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + test('returns empty state when classification is unavailable (feature flag/license off)', () => { + mockClassification({available: false, channelField: makeChannelField(), levels: []}); + + const fetchSpy = jest.spyOn(Client4, 'getPropertyValues'); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(undefined), + ); + + expect(result.current.hasClassification).toBe(false); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + test('does not attempt to fetch when channelId is empty', () => { + mockClassification(); + const fetchSpy = jest.spyOn(Client4, 'getPropertyValues'); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(''), + stateWithValue(undefined), + ); + + expect(result.current.hasClassification).toBe(false); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + test('fetches property values when none exist and classification is available', async () => { + mockClassification(); + const fetchSpy = jest.spyOn(Client4, 'getPropertyValues').mockResolvedValue([]); + + renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(undefined), + ); + + await Promise.resolve(); + expect(fetchSpy).toHaveBeenCalledWith(CLASSIFICATIONS_GROUP_NAME, CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, CHANNEL_ID); + }); + + test('silently ignores fetch errors (channel may not have classification set)', async () => { + mockClassification(); + jest.spyOn(Client4, 'getPropertyValues').mockRejectedValue(new Error('404')); + + const {result} = renderHookWithContext( + () => useChannelClassificationBanner(CHANNEL_ID), + stateWithValue(undefined), + ); + + await Promise.resolve(); + expect(result.current.hasClassification).toBe(false); + }); +}); diff --git a/webapp/channels/src/components/common/hooks/useChannelClassificationBanner.ts b/webapp/channels/src/components/common/hooks/useChannelClassificationBanner.ts new file mode 100644 index 00000000000..cc757b8d5c1 --- /dev/null +++ b/webapp/channels/src/components/common/hooks/useChannelClassificationBanner.ts @@ -0,0 +1,113 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import {useEffect, useMemo} from 'react'; +import {useDispatch, useSelector} from 'react-redux'; + +import type {ChannelBanner} from '@mattermost/types/channels'; +import type {PropertyValue} from '@mattermost/types/properties'; +import type {GlobalState} from '@mattermost/types/store'; + +import {PropertyTypes} from 'mattermost-redux/action_types'; +import {Client4} from 'mattermost-redux/client'; +import {getChannelBanner} from 'mattermost-redux/selectors/entities/channels'; +import {getPropertyValueForTargetField} from 'mattermost-redux/selectors/entities/properties'; + +import { + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, +} from 'components/admin_console/classification_markings/utils'; + +import useClassificationMarkings from './useClassificationMarkings'; + +export type ChannelClassificationBannerState = { + hasClassification: boolean; + classificationBanner: ChannelBanner | undefined; + classificationId: string | undefined; + bannerText: string | undefined; +}; + +/** + * Resolves the effective banner display for a channel by checking whether a + * classification property value exists. If one does, its color (from the level + * definition) and text (from the channel's banner_info) take priority over + * the channel's native banner_info. + * + * The PropertyValue stores only the classification_id (a plain string). + * The banner text lives in channel.banner_info.text so that the property + * value stays a single scalar. + */ +export default function useChannelClassificationBanner(channelId: string): ChannelClassificationBannerState { + const dispatch = useDispatch(); + const classification = useClassificationMarkings(); + + const fieldId = classification.channelField?.id ?? ''; + + const propertyValue = useSelector((state: GlobalState) => { + if (!fieldId || !channelId) { + return undefined; + } + return getPropertyValueForTargetField(state, channelId, fieldId) as PropertyValue | undefined; + }); + + const channelBannerInfo = useSelector((state: GlobalState) => getChannelBanner(state, channelId)); + + useEffect(() => { + if (!channelId || !classification.available || !classification.channelField) { + return; + } + + if (!propertyValue) { + Client4.getPropertyValues( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + channelId, + ).then((values) => { + if (values && values.length > 0) { + dispatch({ + type: PropertyTypes.RECEIVED_PROPERTY_VALUES, + data: {values}, + }); + } + }).catch(() => { + // Silently ignore - channel may not have a classification set + }); + } + }, [channelId, classification.available, classification.channelField, propertyValue, dispatch]); + + return useMemo((): ChannelClassificationBannerState => { + const noClassification: ChannelClassificationBannerState = { + hasClassification: false, + classificationBanner: undefined, + classificationId: undefined, + bannerText: undefined, + }; + + if (!propertyValue || !propertyValue.value) { + return noClassification; + } + + const classificationId = propertyValue.value; + if (typeof classificationId !== 'string') { + return noClassification; + } + + const level = classification.levels.find((l) => l.id === classificationId); + if (!level) { + return noClassification; + } + + const bannerText = channelBannerInfo?.text ?? `**${level.name}**`; + + return { + hasClassification: true, + classificationBanner: { + enabled: true, + text: bannerText, + background_color: level.color, + }, + classificationId, + bannerText, + }; + }, [propertyValue, classification.levels, channelBannerInfo]); +} diff --git a/webapp/channels/src/components/common/hooks/useClassificationMarkings.test.ts b/webapp/channels/src/components/common/hooks/useClassificationMarkings.test.ts new file mode 100644 index 00000000000..bd7557dccfa --- /dev/null +++ b/webapp/channels/src/components/common/hooks/useClassificationMarkings.test.ts @@ -0,0 +1,298 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import * as ReactRedux from 'react-redux'; + +import type {PropertyField} from '@mattermost/types/properties'; +import type {GlobalState} from '@mattermost/types/store'; + +import { + CLASSIFICATIONS_CHANNEL_FIELD_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_FIELD_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, +} from 'components/admin_console/classification_markings/utils'; + +import {renderHookWithContext} from 'tests/react_testing_utils'; + +import useClassificationMarkings, {selectClassificationTemplateField} from './useClassificationMarkings'; + +type PartialState = Parameters[1]; + +jest.mock('react-redux', () => ({ + __esModule: true, + ...jest.requireActual('react-redux'), +})); + +function makeTemplateField(overrides: Partial = {}): PropertyField { + return { + id: 'template1', + group_id: CLASSIFICATIONS_GROUP_NAME, + name: CLASSIFICATIONS_TEMPLATE_FIELD_NAME, + type: 'select', + attrs: {options: [{id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}]}, + target_id: '', + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + object_type: CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + create_at: 1000, + update_at: 1000, + delete_at: 0, + created_by: 'user1', + updated_by: 'user1', + ...overrides, + }; +} + +function makeChannelField(overrides: Partial = {}): PropertyField { + return { + id: 'channel1', + group_id: CLASSIFICATIONS_GROUP_NAME, + name: CLASSIFICATIONS_CHANNEL_FIELD_NAME, + type: 'select', + attrs: {}, + target_id: '', + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + object_type: CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + linked_field_id: 'template1', + create_at: 2000, + update_at: 2000, + delete_at: 0, + created_by: 'user1', + updated_by: 'user1', + ...overrides, + }; +} + +const ENTERPRISE_LICENSE = {IsLicensed: 'true', SkuShortName: 'enterprise'}; +const STARTER_LICENSE = {IsLicensed: 'true', SkuShortName: 'starter'}; + +function stateWith({featureFlag, license, fields = {}}: { + featureFlag?: string; + license?: typeof ENTERPRISE_LICENSE | typeof STARTER_LICENSE | Record; + fields?: Record; +}): PartialState { + return { + entities: { + general: { + config: featureFlag === undefined ? {} : {FeatureFlagClassificationMarkings: featureFlag}, + license: license ?? {}, + }, + properties: { + fields: {byId: fields}, + }, + }, + } as PartialState; +} + +describe('useClassificationMarkings', () => { + const dispatchMock = jest.fn(); + + beforeAll(() => { + jest.spyOn(ReactRedux, 'useDispatch').mockImplementation(() => dispatchMock); + }); + + afterAll(() => { + jest.restoreAllMocks(); + }); + + beforeEach(() => { + dispatchMock.mockClear(); + }); + + test('returns available=false when feature flag is disabled', () => { + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({featureFlag: 'false', license: ENTERPRISE_LICENSE}), + ); + + expect(result.current.available).toBe(false); + expect(result.current.loading).toBe(false); + expect(result.current.levels).toEqual([]); + expect(dispatchMock).not.toHaveBeenCalled(); + }); + + test('returns available=false when feature flag is missing from config', () => { + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({license: ENTERPRISE_LICENSE}), + ); + + expect(result.current.available).toBe(false); + expect(result.current.loading).toBe(false); + expect(dispatchMock).not.toHaveBeenCalled(); + }); + + test('returns available=false when license is not Enterprise', () => { + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({featureFlag: 'true', license: STARTER_LICENSE}), + ); + + expect(result.current.available).toBe(false); + expect(result.current.loading).toBe(false); + expect(dispatchMock).not.toHaveBeenCalled(); + }); + + test('returns available=false when license is missing entirely', () => { + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({featureFlag: 'true', license: {}}), + ); + + expect(result.current.available).toBe(false); + expect(dispatchMock).not.toHaveBeenCalled(); + }); + + test('returns loading=true and dispatches fetches when flag and license are on but no fields are loaded', () => { + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({featureFlag: 'true', license: ENTERPRISE_LICENSE}), + ); + + expect(result.current.loading).toBe(true); + expect(result.current.available).toBe(false); + expect(result.current.templateField).toBeNull(); + expect(result.current.channelField).toBeNull(); + expect(result.current.levels).toEqual([]); + + // The hook dispatches one fetch for the template field and one for the channel field. + expect(dispatchMock).toHaveBeenCalledTimes(2); + }); + + test('returns available=true and derives levels when template field is loaded', () => { + const template = makeTemplateField({ + attrs: { + options: [ + {id: 'lvl2', name: 'SECRET', color: '#C8102E', rank: 2}, + {id: 'lvl1', name: 'UNCLASSIFIED', color: '#007A33', rank: 1}, + ], + }, + }); + + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({featureFlag: 'true', license: ENTERPRISE_LICENSE, fields: {template1: template}}), + ); + + expect(result.current.available).toBe(true); + expect(result.current.loading).toBe(false); + expect(result.current.templateField).toBe(template); + expect(result.current.levels).toHaveLength(2); + + // Levels are sorted by rank ascending. + expect(result.current.levels[0].name).toBe('UNCLASSIFIED'); + expect(result.current.levels[1].name).toBe('SECRET'); + }); + + test('returns available=false when template field exists but has no levels', () => { + const template = makeTemplateField({attrs: {options: []}}); + + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({featureFlag: 'true', license: ENTERPRISE_LICENSE, fields: {template1: template}}), + ); + + expect(result.current.available).toBe(false); + expect(result.current.loading).toBe(false); + expect(result.current.templateField).toBe(template); + expect(result.current.levels).toEqual([]); + }); + + test('exposes channelField when it exists in the store', () => { + const template = makeTemplateField(); + const channel = makeChannelField(); + + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({ + featureFlag: 'true', + license: ENTERPRISE_LICENSE, + fields: {template1: template, channel1: channel}, + }), + ); + + expect(result.current.channelField).toBe(channel); + expect(result.current.templateField).toBe(template); + }); + + test('returns channelField=null when channel-linked field is missing linked_field_id', () => { + const template = makeTemplateField(); + const orphan = makeChannelField({id: 'orphan', linked_field_id: ''}); + + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({ + featureFlag: 'true', + license: ENTERPRISE_LICENSE, + fields: {template1: template, orphan}, + }), + ); + + expect(result.current.channelField).toBeNull(); + }); + + test('returns channelField=null when channel-linked field is soft-deleted', () => { + const template = makeTemplateField(); + const deleted = makeChannelField({delete_at: 9999}); + + const {result} = renderHookWithContext( + () => useClassificationMarkings(), + stateWith({ + featureFlag: 'true', + license: ENTERPRISE_LICENSE, + fields: {template1: template, channel1: deleted}, + }), + ); + + expect(result.current.channelField).toBeNull(); + }); + + test('does not dispatch fetch when both fields are already in the store', () => { + const template = makeTemplateField(); + const channel = makeChannelField(); + + renderHookWithContext( + () => useClassificationMarkings(), + stateWith({ + featureFlag: 'true', + license: ENTERPRISE_LICENSE, + fields: {template1: template, channel1: channel}, + }), + ); + + expect(dispatchMock).not.toHaveBeenCalled(); + }); +}); + +describe('selectClassificationTemplateField', () => { + function fullState(fields: Record = {}): GlobalState { + return stateWith({fields}) as unknown as GlobalState; + } + + test('returns undefined when properties store is empty', () => { + expect(selectClassificationTemplateField(fullState())).toBeUndefined(); + }); + + test('returns undefined when properties.fields.byId is missing', () => { + const state = {entities: {properties: {fields: {}}}} as unknown as GlobalState; + expect(selectClassificationTemplateField(state)).toBeUndefined(); + }); + + test('returns the matching template field by name and object_type', () => { + const template = makeTemplateField(); + expect(selectClassificationTemplateField(fullState({template1: template}))).toBe(template); + }); + + test('ignores soft-deleted template fields', () => { + const deleted = makeTemplateField({delete_at: 9999}); + expect(selectClassificationTemplateField(fullState({template1: deleted}))).toBeUndefined(); + }); + + test('ignores fields with different object_type or name', () => { + const wrongName = makeTemplateField({name: 'something_else'}); + const wrongType = makeTemplateField({id: 'other', object_type: 'system'}); + expect(selectClassificationTemplateField(fullState({wrongName, wrongType}))).toBeUndefined(); + }); +}); diff --git a/webapp/channels/src/components/common/hooks/useClassificationMarkings.ts b/webapp/channels/src/components/common/hooks/useClassificationMarkings.ts new file mode 100644 index 00000000000..d2264b0349a --- /dev/null +++ b/webapp/channels/src/components/common/hooks/useClassificationMarkings.ts @@ -0,0 +1,110 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import {useEffect, useMemo} from 'react'; +import {useDispatch, useSelector} from 'react-redux'; + +import type {PropertyField, PropertyFieldOption} from '@mattermost/types/properties'; +import type {GlobalState} from '@mattermost/types/store'; + +import {fetchPropertyFields} from 'mattermost-redux/actions/properties'; +import {getFeatureFlagValue, getLicense} from 'mattermost-redux/selectors/entities/general'; + +import { + CLASSIFICATIONS_CHANNEL_FIELD_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_FIELD_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + optionsToLevels, +} from 'components/admin_console/classification_markings/utils'; +import type {ClassificationLevel} from 'components/admin_console/classification_markings/utils/presets'; + +import {isEnterpriseLicense} from 'utils/license_utils'; + +export function selectClassificationTemplateField(state: GlobalState): PropertyField | undefined { + const byId = state.entities.properties?.fields?.byId; + if (!byId) { + return undefined; + } + return Object.values(byId).find( + (f) => f.object_type === CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE && f.name === CLASSIFICATIONS_TEMPLATE_FIELD_NAME && f.delete_at === 0, + ); +} + +function selectChannelClassificationField(state: GlobalState): PropertyField | undefined { + const byId = state.entities.properties?.fields?.byId; + if (!byId) { + return undefined; + } + return Object.values(byId).find( + (f) => f.object_type === CLASSIFICATIONS_CHANNEL_OBJECT_TYPE && f.name === CLASSIFICATIONS_CHANNEL_FIELD_NAME && f.linked_field_id && f.delete_at === 0, + ); +} + +export type ClassificationMarkingsState = { + available: boolean; + loading: boolean; + templateField: PropertyField | null; + channelField: PropertyField | null; + levels: ClassificationLevel[]; +}; + +/** + * Reusable hook that gates classification markings availability. + * Returns available=true only when all 3 conditions are met: + * 1. ClassificationMarkings feature flag is enabled + * 2. Enterprise license is active + * 3. Template classification field exists with at least one level configured + * + * Also fetches the channel_classification linked field for consumers that need it. + */ +export default function useClassificationMarkings(): ClassificationMarkingsState { + const dispatch = useDispatch(); + + const featureEnabled = useSelector( + (state: GlobalState) => getFeatureFlagValue(state, 'ClassificationMarkings') === 'true', + ); + const license = useSelector(getLicense); + const hasEnterpriseLicense = isEnterpriseLicense(license); + const templateField = useSelector(selectClassificationTemplateField) ?? null; + const channelField = useSelector(selectChannelClassificationField) ?? null; + + useEffect(() => { + if (!featureEnabled || !hasEnterpriseLicense) { + return; + } + if (!templateField) { + dispatch(fetchPropertyFields( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + )); + } + if (!channelField) { + dispatch(fetchPropertyFields( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + )); + } + }, [featureEnabled, hasEnterpriseLicense, templateField, channelField, dispatch]); + + const levels = useMemo((): ClassificationLevel[] => { + if (!templateField) { + return []; + } + const options = (templateField.attrs?.options as PropertyFieldOption[]) || []; + return optionsToLevels(options); + }, [templateField]); + + const loading = featureEnabled && hasEnterpriseLicense && !templateField; + + const available = featureEnabled && hasEnterpriseLicense && levels.length > 0; + + return {available, loading, templateField, channelField, levels}; +} diff --git a/webapp/channels/src/components/global_classification_banner/global_classification_banner.test.tsx b/webapp/channels/src/components/global_classification_banner/global_classification_banner.test.tsx index 74ae87c750c..77270c20d90 100644 --- a/webapp/channels/src/components/global_classification_banner/global_classification_banner.test.tsx +++ b/webapp/channels/src/components/global_classification_banner/global_classification_banner.test.tsx @@ -11,12 +11,12 @@ import {Client4} from 'mattermost-redux/client'; import { DISPLAY_BANNER_BOTTOM, DISPLAY_BANNER_TOP, - GROUP_NAME, - LINKED_OBJECT_TYPE, - OBJECT_TYPE, - SYSTEM_FIELD_TARGET_ID, - SYSTEM_VALUE_TARGET_ID, - TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_SYSTEM_VALUE_TARGET_ID, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, } from 'components/admin_console/classification_markings/utils'; import {renderWithContext, screen} from 'tests/react_testing_utils'; @@ -35,11 +35,11 @@ const LINKED_FIELD_ID = 'linked_field1'; function makeTemplateField(options: Array<{id: string; name: string; color: string}>): PropertyField { return { id: TEMPLATE_FIELD_ID, - group_id: GROUP_NAME, + group_id: CLASSIFICATIONS_GROUP_NAME, name: 'classification', type: 'select', - object_type: OBJECT_TYPE, - target_type: TARGET_TYPE, + object_type: CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, target_id: '', create_at: 1000, update_at: 1000, @@ -56,12 +56,12 @@ function makeTemplateField(options: Array<{id: string; name: string; color: stri function makeLinkedField(actions: string[], options: Array<{id: string; name: string; color: string}> = []): PropertyField { return { id: LINKED_FIELD_ID, - group_id: GROUP_NAME, + group_id: CLASSIFICATIONS_GROUP_NAME, name: 'system_classification', type: 'select', - object_type: LINKED_OBJECT_TYPE, - target_type: TARGET_TYPE, - target_id: SYSTEM_FIELD_TARGET_ID, + object_type: CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + target_type: CLASSIFICATIONS_FIELD_TARGET_TYPE, + target_id: CLASSIFICATIONS_FIELD_TARGET_ID, linked_field_id: TEMPLATE_FIELD_ID, create_at: 2000, update_at: 2000, @@ -78,9 +78,9 @@ function makeLinkedField(actions: string[], options: Array<{id: string; name: st function makeSystemValue(optionId: string): PropertyValue { return { id: 'value1', - target_id: SYSTEM_VALUE_TARGET_ID, - target_type: LINKED_OBJECT_TYPE, - group_id: GROUP_NAME, + target_id: CLASSIFICATIONS_SYSTEM_VALUE_TARGET_ID, + target_type: CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + group_id: CLASSIFICATIONS_GROUP_NAME, field_id: LINKED_FIELD_ID, value: optionId, create_at: 3000, @@ -317,10 +317,10 @@ describe('GlobalClassificationBanner', () => { ); expect(Client4.getPropertyFields).toHaveBeenCalledWith( - GROUP_NAME, - LINKED_OBJECT_TYPE, - TARGET_TYPE, - SYSTEM_FIELD_TARGET_ID, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, expect.anything(), ); }); diff --git a/webapp/channels/src/components/global_classification_banner/global_classification_banner.tsx b/webapp/channels/src/components/global_classification_banner/global_classification_banner.tsx index 4ed33322b7c..a7d81a46f66 100644 --- a/webapp/channels/src/components/global_classification_banner/global_classification_banner.tsx +++ b/webapp/channels/src/components/global_classification_banner/global_classification_banner.tsx @@ -13,19 +13,18 @@ import {getPropertyValueForTargetField} from 'mattermost-redux/selectors/entitie import {getContrastingSimpleColor} from 'mattermost-redux/utils/theme_utils'; import { + CLASSIFICATIONS_FIELD_TARGET_ID, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_FIELD_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_SYSTEM_VALUE_TARGET_ID, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, DISPLAY_BANNER_BOTTOM, DISPLAY_BANNER_TOP, - FIELD_NAME, - GROUP_NAME, - LINKED_FIELD_NAME, - LINKED_OBJECT_TYPE, - OBJECT_TYPE, - SYSTEM_FIELD_TARGET_ID, - SYSTEM_VALUE_TARGET_ID, - TARGET_ID, - TARGET_TYPE, findOptionById, } from 'components/admin_console/classification_markings/utils'; +import {selectClassificationTemplateField} from 'components/common/hooks/useClassificationMarkings'; import './global_classification_banner.scss'; @@ -35,16 +34,6 @@ type Props = { position: 'top' | 'bottom'; }; -function selectClassificationTemplateField(state: GlobalState): PropertyField | undefined { - const byId = state.entities.properties?.fields?.byId; - if (!byId) { - return undefined; - } - return Object.values(byId).find( - (f) => f.object_type === OBJECT_TYPE && f.name === FIELD_NAME && f.delete_at === 0, - ); -} - function selectLinkedSystemField(state: GlobalState): PropertyField | undefined { const byId = state.entities.properties?.fields?.byId; if (!byId) { @@ -53,7 +42,7 @@ function selectLinkedSystemField(state: GlobalState): PropertyField | undefined // The linked system field has object_type 'system' and a linked_field_id set. return Object.values(byId).find( - (f) => f.object_type === LINKED_OBJECT_TYPE && f.name === LINKED_FIELD_NAME && f.linked_field_id && f.delete_at === 0, + (f) => f.object_type === CLASSIFICATIONS_SYSTEM_OBJECT_TYPE && f.name === CLASSIFICATIONS_SYSTEM_FIELD_NAME && f.linked_field_id && f.delete_at === 0, ); } @@ -66,7 +55,7 @@ export default function GlobalClassificationBanner({position}: Props) { if (!linkedField) { return undefined; } - return getPropertyValueForTargetField(state, SYSTEM_VALUE_TARGET_ID, linkedField.id) as PropertyValue | undefined; + return getPropertyValueForTargetField(state, CLASSIFICATIONS_SYSTEM_VALUE_TARGET_ID, linkedField.id) as PropertyValue | undefined; }); // Bootstrap: fetch template fields, the linked system field, and system property values. @@ -80,13 +69,23 @@ export default function GlobalClassificationBanner({position}: Props) { return; } if (!templateField) { - dispatch(fetchPropertyFields(GROUP_NAME, OBJECT_TYPE, TARGET_TYPE, TARGET_ID)); + dispatch(fetchPropertyFields( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_TEMPLATE_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + )); } if (!linkedField) { - dispatch(fetchPropertyFields(GROUP_NAME, LINKED_OBJECT_TYPE, TARGET_TYPE, SYSTEM_FIELD_TARGET_ID)); + dispatch(fetchPropertyFields( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_SYSTEM_OBJECT_TYPE, + CLASSIFICATIONS_FIELD_TARGET_TYPE, + CLASSIFICATIONS_FIELD_TARGET_ID, + )); } if (linkedField && !systemValue) { - dispatch(fetchSystemPropertyValues(GROUP_NAME)); + dispatch(fetchSystemPropertyValues(CLASSIFICATIONS_GROUP_NAME)); } }, [featureEnabled, templateField, linkedField, systemValue, dispatch]); diff --git a/webapp/channels/src/components/new_channel_modal/new_channel_modal.scss b/webapp/channels/src/components/new_channel_modal/new_channel_modal.scss index ad60b6fc00d..c50702ff858 100644 --- a/webapp/channels/src/components/new_channel_modal/new_channel_modal.scss +++ b/webapp/channels/src/components/new_channel_modal/new_channel_modal.scss @@ -1,8 +1,124 @@ .new-channel-modal { + .modal-content { + display: flex; + overflow: hidden; + flex-direction: column; + } + + .GenericModal__wrapper { + display: flex; + overflow: hidden; + min-height: 0; + flex: 1 1 auto; + flex-direction: column; + } + + .modal-header { + flex-shrink: 0; + } + + .modal-body { + min-height: 0; + flex: 1 1 auto; + overflow-y: auto; + } + + .modal-footer { + flex-shrink: 0; + } + + .new-channel-modal-classification__fields .DropdownInput.Input_container { + margin-top: 0; + + .Input_fieldset { + padding: 0; + border: none; + box-shadow: none; + + &:hover, + &:focus-within { + border: none; + box-shadow: none; + } + } + + .Input_wrapper { + padding: 0; + margin: 0; + } + + .DropdownInput__indicatorsContainer { + margin-right: 0; + } + } + .new-channel-modal-type-selector { margin-top: 24px; } + .new-channel-modal-classification { + padding-top: 16px; + border-top: 1px solid rgba(var(--center-channel-color-rgb), 0.08); + margin-top: 24px; + + &__header { + display: flex; + align-items: center; + justify-content: space-between; + + h4 { + margin: 0; + font-size: 14px; + font-weight: 600; + line-height: 20px; + } + } + + &__description { + margin: 4px 0 0; + color: rgba(var(--center-channel-color-rgb), 0.72); + font-size: 12px; + line-height: 16px; + } + + &__fields { + margin-top: 16px; + } + + &__field-row { + display: flex; + align-items: center; + gap: 16px; + + &:first-child { + margin-bottom: 16px; + } + } + + &__field-label { + width: 140px; + flex-shrink: 0; + margin: 0; + color: var(--center-channel-color); + font-size: 14px; + font-weight: 600; + line-height: 20px; + } + + &__field-input { + min-width: 0; + flex: 1; + + .AdvancedTextbox #PreviewInputTextButton { + position: absolute; + z-index: 10; + top: 7px; + right: 7px; + } + } + + } + .new-channel-modal-purpose-container { margin-top: 28px; diff --git a/webapp/channels/src/components/new_channel_modal/new_channel_modal.tsx b/webapp/channels/src/components/new_channel_modal/new_channel_modal.tsx index b9ecc8ec33b..8e0a62c924e 100644 --- a/webapp/channels/src/components/new_channel_modal/new_channel_modal.tsx +++ b/webapp/channels/src/components/new_channel_modal/new_channel_modal.tsx @@ -2,7 +2,7 @@ // See LICENSE.txt for license information. import classNames from 'classnames'; -import React, {useCallback, useRef, useState} from 'react'; +import React, {useCallback, useMemo, useRef, useState} from 'react'; import {FormattedMessage, useIntl} from 'react-intl'; import {useDispatch, useSelector} from 'react-redux'; @@ -14,18 +14,36 @@ import type {ServerError} from '@mattermost/types/errors'; import {setNewChannelWithBoardPreference} from 'mattermost-redux/actions/boards'; import {createChannel} from 'mattermost-redux/actions/channels'; +import {Client4} from 'mattermost-redux/client'; import Permissions from 'mattermost-redux/constants/permissions'; import Preferences from 'mattermost-redux/constants/preferences'; import {areManagedCategoriesEnabled, isChannelCategorySortingEnabled, makeGetSidebarCategoryNamesForTeam} from 'mattermost-redux/selectors/entities/channel_categories'; import {get as getPreference} from 'mattermost-redux/selectors/entities/preferences'; import {haveICurrentChannelPermission} from 'mattermost-redux/selectors/entities/roles'; import {getCurrentTeam} from 'mattermost-redux/selectors/entities/teams'; +import {isCurrentUserSystemAdmin} from 'mattermost-redux/selectors/entities/users'; import {switchToChannel} from 'actions/views/channel'; import {closeModal} from 'actions/views/modals'; +import {ColorSwatch, LevelOptionLabel} from 'components/admin_console/classification_markings/classification_markings_styled'; +import { + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + CLASSIFICATIONS_GROUP_NAME, +} from 'components/admin_console/classification_markings/utils'; +import {classificationPresetDropdownStyles} from 'components/admin_console/classification_markings/utils/preset_dropdown_styles'; import CategorySelector from 'components/category_selector/category_selector'; import ChannelNameFormField from 'components/channel_name_form_field/channel_name_form_field'; +import { + CHANNEL_BANNER_MAX_CHARACTER_LIMIT, + CHANNEL_BANNER_MIN_CHARACTER_LIMIT, +} from 'components/channel_settings_modal/channel_settings_configuration_tab'; +import useClassificationMarkings from 'components/common/hooks/useClassificationMarkings'; +import DropdownInput from 'components/dropdown_input'; +import type {ValueType} from 'components/dropdown_input'; +import type {TextboxElement} from 'components/textbox'; +import Toggle from 'components/toggle'; +import AdvancedTextbox from 'components/widgets/advanced_textbox/advanced_textbox'; import Input from 'components/widgets/inputs/input/input'; import PublicPrivateSelector from 'components/widgets/public-private-selector/public-private-selector'; @@ -82,6 +100,46 @@ const NewChannelModal = () => { const [defaultCategoryName, setDefaultCategoryName] = useState(undefined); const [managedCategoryName, setManagedCategoryName] = useState(undefined); + const classification = useClassificationMarkings(); + const isSystemAdmin = useSelector(isCurrentUserSystemAdmin); + const canManageClassification = classification.available && isSystemAdmin; + const [classificationEnabled, setClassificationEnabled] = useState(false); + const [selectedClassificationId, setSelectedClassificationId] = useState(''); + const [bannerText, setBannerText] = useState(''); + const [bannerTextPreview, setBannerTextPreview] = useState(false); + + const classificationOptions = useMemo(() => { + return classification.levels. + filter((l) => l.name.trim() !== ''). + map((l) => ({value: l.id, label: l.name.trim(), color: l.color})); + }, [classification.levels]); + + const selectedClassificationOption = useMemo((): ValueType | undefined => { + return classificationOptions.find((o) => o.value === selectedClassificationId); + }, [classificationOptions, selectedClassificationId]); + + const selectedClassificationLevel = useMemo(() => { + return classification.levels.find((l) => l.id === selectedClassificationId); + }, [classification.levels, selectedClassificationId]); + + const handleClassificationLevelChange = useCallback((selected: ValueType) => { + setSelectedClassificationId(selected.value); + const level = classification.levels.find((l) => l.id === selected.value); + if (level) { + setBannerText(`**${level.name}**`); + } + }, [classification.levels]); + + const formatClassificationOptionLabel = useCallback((option: ValueType) => { + const levelOption = option as ValueType & {color: string}; + return ( + + + {levelOption.label} + + ); + }, []); + // create a board along with the channel const createBoardFromChannelPlugin = useSelector((state: GlobalState) => state.plugins.components.CreateBoardFromTemplate); const newChannelWithBoardPulsatingDotState = useSelector((state: GlobalState) => getPreference(state, Preferences.APP_BAR, Preferences.NEW_CHANNEL_WITH_BOARD_TOUR_SHOWED, '')); @@ -117,6 +175,13 @@ const NewChannelModal = () => { update_at: 0, default_category_name: defaultCategoryName, managed_category_name: managedCategoryName, + ...(classificationEnabled && selectedClassificationId && bannerText ? { + banner_info: { + enabled: true, + text: bannerText, + background_color: selectedClassificationLevel?.color || '', + }, + } : {}), }; try { @@ -126,6 +191,19 @@ const NewChannelModal = () => { return; } + if (classificationEnabled && selectedClassificationId && classification.channelField && bannerText) { + try { + await Client4.patchPropertyValues( + CLASSIFICATIONS_GROUP_NAME, + CLASSIFICATIONS_CHANNEL_OBJECT_TYPE, + newChannel!.id, + [{field_id: classification.channelField.id, value: selectedClassificationId}], + ); + } catch { + // Classification save failure should not block channel creation + } + } + handleOnModalCancel(); // If template selected, create a new board from this template @@ -227,7 +305,8 @@ const NewChannelModal = () => { e.stopPropagation(); }; - const canCreate = displayName && !urlError && type && !purposeError && !serverError && canCreateFromPluggable && !channelInputError; + const classificationValid = !classificationEnabled || (Boolean(selectedClassificationId) && bannerText.trim().length > 0); + const canCreate = displayName && !urlError && type && !purposeError && !serverError && canCreateFromPluggable && !channelInputError && classificationValid; const newBoardInfoIcon = ( { /> }
+ {canManageClassification && ( +
+
+
+

+ +

+
+ setClassificationEnabled(!classificationEnabled)} + toggleClassName='btn-toggle-primary' + size='btn-md' + ariaLabel={formatMessage({id: 'channel_modal.classification.toggle_label', defaultMessage: 'Channel classification'})} + /> +
+

+ +

+ {classificationEnabled && ( +
+
+ + + +
+ +
+
+ {selectedClassificationLevel && ( +
+ + + +
+ {}} + useChannelMentions={false} + onChange={(e: React.ChangeEvent) => setBannerText(e.target.value)} + preview={bannerTextPreview} + togglePreview={() => setBannerTextPreview(!bannerTextPreview)} + createMessage={formatMessage({id: 'channel_modal.classification.banner_placeholder', defaultMessage: 'Banner text'})} + maxLength={CHANNEL_BANNER_MAX_CHARACTER_LIMIT} + minLength={CHANNEL_BANNER_MIN_CHARACTER_LIMIT} + /> +
+
+ )} +
+ )} +
+ )}
); diff --git a/webapp/channels/src/i18n/en.json b/webapp/channels/src/i18n/en.json index cc13c1b910c..a8450a2e1f4 100644 --- a/webapp/channels/src/i18n/en.json +++ b/webapp/channels/src/i18n/en.json @@ -4071,6 +4071,11 @@ "channel_menu.bookmarks.addLink": "Add a link", "channel_modal.alreadyExist": "A channel with that URL already exists", "channel_modal.cancel": "Cancel", + "channel_modal.classification.banner_label": "Banner text", + "channel_modal.classification.banner_placeholder": "Banner text", + "channel_modal.classification.level_label": "Classification level", + "channel_modal.classification.toggle_description": "When enabled, classification markings will appear for this channel. Individual channels cannot have a classification level lower than the global classification level.", + "channel_modal.classification.toggle_label": "Channel classification", "channel_modal.create_board.tooltip_description": "Use any of our templates to manage your tasks or start from scratch with your own!", "channel_modal.create_board.tooltip_title": "Manage your task with a board", "channel_modal.createNew": "Create channel", @@ -4159,6 +4164,9 @@ "channel_settings.archive.button": "Archive this channel", "channel_settings.archive.warning": "Archiving a channel removes it from the user interface, but doesn't permanently delete the channel. New messages can't be posted to archived channels.", "channel_settings.channel_info_tab.name": "Channel Info", + "channel_settings.classification.description": "When enabled, a classification level can be set for the channel with configurable indicators.", + "channel_settings.classification.level_label": "Classification level", + "channel_settings.classification.title": "Classification", "channel_settings.error_banner_color_required": "Banner color is required", "channel_settings.error_banner_text_required": "Banner text is required", "channel_settings.error_display_name_required": "Channel name is required", diff --git a/webapp/channels/src/sass/components/_color-input.scss b/webapp/channels/src/sass/components/_color-input.scss index 1d65c4015cb..b403e60625e 100644 --- a/webapp/channels/src/sass/components/_color-input.scss +++ b/webapp/channels/src/sass/components/_color-input.scss @@ -1,4 +1,5 @@ @use "utils/functions"; +@use "utils/variables"; .color-input { position: relative; @@ -30,11 +31,20 @@ border-radius: 0 4px 4px 0; background: functions.v(center-channel-bg) !important; line-height: 0; + + &--disabled { + background: rgba(var(--center-channel-color-rgb), 0.1) !important; + cursor: not-allowed; + + .color-icon { + cursor: not-allowed; + } + } } .color-popover { position: absolute; - z-index: 12; + z-index: variables.$z-index-popover; top: 100%; right: 0; padding-top: 8px; diff --git a/webapp/platform/client/src/client4.ts b/webapp/platform/client/src/client4.ts index 1c7d16a424a..d464978ddaa 100644 --- a/webapp/platform/client/src/client4.ts +++ b/webapp/platform/client/src/client4.ts @@ -2210,6 +2210,13 @@ export default class Client4 { ); }; + patchPropertyValues = (groupName: string, objectType: string, targetId: string, items: Array<{field_id: string; value: T}>) => { + return this.doFetch>>( + `${this.getBaseRoute()}/properties/groups/${groupName}/${objectType}/values/${targetId}`, + {method: 'PATCH', body: JSON.stringify(items)}, + ); + }; + // Remote Clusters Routes getRemoteClusters = (options: { From 1d1580cb3c62681346120a1689fa392952fbb899 Mon Sep 17 00:00:00 2001 From: sabril <5334504+saturninoabril@users.noreply.github.com> Date: Tue, 19 May 2026 10:57:53 +0800 Subject: [PATCH 28/80] chore: update reusable workflows to specific commit sha (#36600) --- .github/workflows/e2e-tests-cypress-template-v2.yml | 6 +++--- .github/workflows/e2e-tests-playwright-template-v2.yml | 6 +++--- .github/workflows/pr-test-analysis-override.yml | 3 +-- .github/workflows/pr-test-analysis.yml | 3 +-- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/workflows/e2e-tests-cypress-template-v2.yml b/.github/workflows/e2e-tests-cypress-template-v2.yml index 3923c2c9e6a..16351715089 100644 --- a/.github/workflows/e2e-tests-cypress-template-v2.yml +++ b/.github/workflows/e2e-tests-cypress-template-v2.yml @@ -194,7 +194,7 @@ jobs: echo "workers=$(jq -nc --argjson n ${{ inputs.workers }} '[range(1; $n+1)]')" >> $GITHUB_OUTPUT echo "start_time=$(date +%s)" >> $GITHUB_OUTPUT - name: ci/dispatch-begin - uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-begin@main + uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-begin@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: cypress @@ -278,7 +278,7 @@ jobs: working-directory: e2e-tests/cypress run: npm ci - name: ci/dispatch-run - uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-run@main + uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-run@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: cypress @@ -321,7 +321,7 @@ jobs: - name: ci/run-summary id: summary continue-on-error: true - uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-summary@main + uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-summary@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} composite-identity: ${{ needs.dispatch-begin.outputs.composite-identity-json }} diff --git a/.github/workflows/e2e-tests-playwright-template-v2.yml b/.github/workflows/e2e-tests-playwright-template-v2.yml index 606332c36fc..1573e1d4616 100644 --- a/.github/workflows/e2e-tests-playwright-template-v2.yml +++ b/.github/workflows/e2e-tests-playwright-template-v2.yml @@ -159,7 +159,7 @@ jobs: echo "workers=$(jq -nc --argjson n ${{ inputs.workers }} '[range(1; $n+1)]')" >> $GITHUB_OUTPUT echo "start_time=$(date +%s)" >> $GITHUB_OUTPUT - name: ci/dispatch-begin - uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-begin@main + uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-begin@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: playwright @@ -234,7 +234,7 @@ jobs: npm run build npx playwright test --project=setup - name: ci/dispatch-run - uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-run@main + uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-run@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: playwright @@ -278,7 +278,7 @@ jobs: - name: ci/run-summary id: summary continue-on-error: true - uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-summary@main + uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-summary@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} composite-identity: ${{ needs.dispatch-begin.outputs.composite-identity-json }} diff --git a/.github/workflows/pr-test-analysis-override.yml b/.github/workflows/pr-test-analysis-override.yml index 9372742d3d5..65e16c0df74 100644 --- a/.github/workflows/pr-test-analysis-override.yml +++ b/.github/workflows/pr-test-analysis-override.yml @@ -21,8 +21,7 @@ jobs: github.event.issue.pull_request && startsWith(github.event.comment.body, '/test-analysis-override') && contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association) - # Pin to a commit SHA once the reusable workflow is stable. Using @main during initial rollout. - uses: mattermost/mattermost-test-automation-toolkit/.github/workflows/pr-test-analysis-override.yml@main + uses: mattermost/mattermost-test-automation-toolkit/.github/workflows/pr-test-analysis-override.yml@93d73f4f101e10dc03c7ed6b76b35eb5ff5babb7 # 2026-05-16 with: pr_number: ${{ github.event.issue.number }} target_repo: mattermost/mattermost diff --git a/.github/workflows/pr-test-analysis.yml b/.github/workflows/pr-test-analysis.yml index fc232212bec..03828d7dc87 100644 --- a/.github/workflows/pr-test-analysis.yml +++ b/.github/workflows/pr-test-analysis.yml @@ -36,8 +36,7 @@ jobs: github.event_name == 'workflow_dispatch' || (github.event.pull_request.draft == false && github.event.pull_request.head.repo.full_name == 'mattermost/mattermost') - # Pin to a commit SHA once the reusable workflow is stable. Using @main during initial rollout. - uses: mattermost/mattermost-test-automation-toolkit/.github/workflows/pr-test-analysis.yml@main + uses: mattermost/mattermost-test-automation-toolkit/.github/workflows/pr-test-analysis.yml@93d73f4f101e10dc03c7ed6b76b35eb5ff5babb7 # 2026-05-16 with: pr_number: ${{ github.event.pull_request.number || inputs.pr_number }} target_repo: mattermost/mattermost From 9d318dc4cdabe6c9fd76a286ba7c07515a991fec Mon Sep 17 00:00:00 2001 From: sabril <5334504+saturninoabril@users.noreply.github.com> Date: Tue, 19 May 2026 11:26:39 +0800 Subject: [PATCH 29/80] refactor: speed up E2E test workflows and eliminate npm cache-restore failures (#36599) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Workers no longer run `npm ci` — `node_modules` and framework binaries are restored from actions/cache populated once by a new `prep-deps` job. This closes the intermittent EEXIST/ENOENT failure inside npm's own cacache writer that occasionally fails `npm ci` on a runner. Removing `npm ci` from workers also cuts ~5 min of duplicated install work per worker. dispatch-begin now runs as its own job after prep-deps so it fires once the per-worker test-server setup is the only remaining work before dispatch-run. --- .github/e2e-tests-workflows.md | 112 +++++++++++-- .github/workflows/e2e-tests-check.yml | 6 + .../e2e-tests-cypress-template-v2.yml | 84 +++++++--- .../e2e-tests-playwright-template-v2.yml | 151 +++++++++++++++--- .../e2e-tests-playwright-template.yml | 9 +- 5 files changed, 314 insertions(+), 48 deletions(-) diff --git a/.github/e2e-tests-workflows.md b/.github/e2e-tests-workflows.md index b6c6a6577fe..95660924dfd 100644 --- a/.github/e2e-tests-workflows.md +++ b/.github/e2e-tests-workflows.md @@ -16,25 +16,119 @@ All pipelines follow the **smoke-then-full** pattern: smoke tests run first, ful ``` .github/workflows/ -├── e2e-tests-ci.yml # PR orchestrator -├── e2e-tests-on-merge.yml # Merge orchestrator (master/release branches) -├── e2e-tests-on-release.yml # Release cut orchestrator -├── e2e-tests-cypress.yml # Shared wrapper: cypress smoke -> full -├── e2e-tests-playwright.yml # Shared wrapper: playwright smoke -> full -├── e2e-tests-cypress-template.yml # Template: actual cypress test execution -└── e2e-tests-playwright-template.yml # Template: actual playwright test execution +├── e2e-tests-ci.yml # PR orchestrator +├── e2e-tests-on-merge.yml # Merge orchestrator (master/release branches) +├── e2e-tests-on-release.yml # Release cut orchestrator +├── e2e-tests-cypress.yml # Shared wrapper: routes to v1 or v2 template +├── e2e-tests-playwright.yml # Shared wrapper: routes to v1 or v2 template +├── e2e-tests-cypress-template-v2.yml # Active: cypress + test-system-io dispatch +├── e2e-tests-playwright-template-v2.yml # Active: playwright + test-system-io dispatch +├── e2e-tests-cypress-template.yml # Deprecated v1 (legacy in-job execution) +└── e2e-tests-playwright-template.yml # Deprecated v1 (legacy in-job execution) ``` +> **v1 templates are deprecated.** They remain available behind a feature flag during cutover but receive no further changes. New work targets the v2 templates exclusively. The wrappers route by `vars.E2E_USE_TEST_IO_DISPATCH` — `'true'` selects v2, anything else falls back to v1. + ### Call hierarchy ``` e2e-tests-ci.yml ─────────────────┐ -e2e-tests-on-merge.yml ───────────┤──► e2e-tests-cypress.yml ──► e2e-tests-cypress-template.yml -e2e-tests-on-release.yml ─────────┘ e2e-tests-playwright.yml ──► e2e-tests-playwright-template.yml +e2e-tests-on-merge.yml ───────────┤──► e2e-tests-cypress.yml ─────┐ +e2e-tests-on-release.yml ─────────┘ e2e-tests-playwright.yml ──┤ + │ + ┌──────────────────────────┘ + │ routes on E2E_USE_TEST_IO_DISPATCH + ▼ + v2 (active) ──► e2e-tests-{cypress,playwright}-template-v2.yml + v1 (legacy) ──► e2e-tests-{cypress,playwright}-template.yml ``` --- +## Workflow Architecture (v2) + +v2 splits the template into five jobs — `prepare-run`, `prep-deps`, `dispatch-begin`, `workers` (matrix), and `report` — and pushes spec-level execution to [Test System IO](https://github.com/mattermost/mattermost-test-system-io) so workers stay thin and identical. + +``` +┌──────────────────────────────────────────────────────────────────────────┐ +│ Template v2: e2e-tests-{cypress,playwright}-template-v2.yml │ +│ │ +│ ┌───────────────────┐ ┌──────────────────────────────┐ │ +│ │ prepare-run │ │ prep-deps │ │ +│ │ (1 runner) │ parallel │ (1 runner) │ │ +│ │ │ ◄────────────► │ │ │ +│ │ • build workers │ │ Cypress: │ │ +│ │ matrix [1..N] │ │ • cypress/node_modules │ │ +│ │ • compute commit │ │ • ~/.cache/Cypress (binary)│ │ +│ │ status context │ │ │ │ +│ │ • emit composite │ │ Playwright: │ │ +│ │ identity │ │ • webapp/platform/{client, │ │ +│ │ │ │ types}/{lib,node_mod} │ │ +│ │ │ │ • playwright/node_modules │ │ +│ │ │ │ • playwright/lib/dist │ │ +│ │ │ │ • ~/.cache/ms-playwright │ │ +│ │ │ │ (chromium only) │ │ +│ │ │ │ │ │ +│ │ │ │ → saved to actions/cache │ │ +│ └─────────┬─────────┘ └───────────────┬──────────────┘ │ +│ │ │ │ +│ │ ▼ │ +│ │ ┌──────────────────────────────┐ │ +│ │ │ dispatch-begin │ │ +│ │ │ • register run with │ │ +│ │ │ Test System IO │ │ +│ │ │ • runs immediately before │ │ +│ │ │ workers to minimise the │ │ +│ │ │ inactivity-timeout window │ │ +│ │ └───────────────┬──────────────┘ │ +│ │ │ │ +│ └────────────────────┬─────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ workers (matrix, fail-fast: false) │ │ +│ │ Cypress full: N=40 | Playwright full: N=10 │ │ +│ │ │ │ +│ │ each worker, in parallel: │ │ +│ │ 1. sparse-checkout actions + full checkout-repo │ │ +│ │ 2. setup-node │ │ +│ │ 3. restore caches ◄─── actions/cache (from prep-deps) │ │ +│ │ (fail-on-cache-miss: true) │ │ +│ │ 4. cloud-init + start-server (docker compose stack) │ │ +│ │ 5. prepare-cypress | prepare-playwright (run setup project) │ │ +│ │ 6. dispatch-run ──────────────────────────────────┐ │ │ +│ │ (pulls specs from Test System IO, runs locally, │ │ │ +│ │ posts result, loops until queue is empty) │ │ │ +│ │ 7. cloud-teardown │ │ │ +│ └────────────────────┬────────────────────────────────────┼───────┘ │ +│ │ │ │ +│ ▼ │ │ +│ ┌─────────────────────────────────────────────────┐ │ │ +│ │ report │ │ │ +│ │ • pull aggregated results from Test System IO │ ◄────┘ │ +│ │ • post commit status │ │ +│ │ • send webhook notification │ │ +│ └─────────────────────────────────────────────────┘ │ +└────────────────────────────────────────────────────────────────────────┬─┘ + │ + ┌───────────────────────────▼───┐ + │ Test System IO (external) │ + │ • spec-level dispatch │ + │ • result aggregation │ + │ • retry orchestration │ + └───────────────────────────────┘ +``` + +### Key properties + +- **Spec-level vs. job-level parallelism.** The matrix sizes the runner pool; Test System IO does the spec assignment. Slow specs don't block a worker — fast workers keep pulling the next spec from the queue. +- **Cache-only workers.** `prep-deps` installs once per workflow run and saves to `actions/cache`. Every worker restores with `fail-on-cache-miss: true` and runs zero `npm ci`. Eliminates the 40-way `EEXIST/ENOENT` race in npm's shared cacache writer. +- **dispatch-begin runs late.** It depends on `prep-deps` so the gap between Test System IO run registration and the first worker calling `dispatch-run` is just per-worker setup (~3–5 min). Registering earlier risks the run timing out before any worker checks in, bulk-failing every spec. +- **Playwright slim slice.** Playwright only consumes `@mattermost/client` and `@mattermost/types` from webapp, so prep-deps caches just those two packages' built `lib/` and `node_modules` (~10–30 MB) instead of the full `webapp/node_modules` tree (~1–2 GB). +- **Browser/binary caches.** Cypress caches `~/.cache/Cypress` (cypress binary lives outside node_modules); playwright caches `~/.cache/ms-playwright` (chromium only). Both keyed on the framework's lockfile so they invalidate on version bumps. +- **No retry plumbing in the template.** Test System IO handles per-spec retries; the workflow only sees aggregated results. + +--- + ## Pipeline 1: PR (`e2e-tests-ci.yml`) Runs E2E tests for every PR commit after the enterprise docker image is built. Fails if the commit is not associated with an open PR. diff --git a/.github/workflows/e2e-tests-check.yml b/.github/workflows/e2e-tests-check.yml index 53b59b726f7..13cd9b45fca 100644 --- a/.github/workflows/e2e-tests-check.yml +++ b/.github/workflows/e2e-tests-check.yml @@ -25,6 +25,12 @@ jobs: cache-dependency-path: | e2e-tests/cypress/package-lock.json e2e-tests/playwright/package-lock.json + webapp/package-lock.json + - name: ci/npm-cache-verify + # Heal any partial/dangling entries left in the restored ~/.npm cache + # before running `npm ci`. Avoids the intermittent EEXIST/ENOENT + # failures in npm's cacache writer. + run: npm cache verify # Cypress check - name: ci/cypress/npm-install diff --git a/.github/workflows/e2e-tests-cypress-template-v2.yml b/.github/workflows/e2e-tests-cypress-template-v2.yml index 16351715089..f08b7ec9f32 100644 --- a/.github/workflows/e2e-tests-cypress-template-v2.yml +++ b/.github/workflows/e2e-tests-cypress-template-v2.yml @@ -135,7 +135,7 @@ env: SERVER_IMAGE: "${{ inputs.server_image_repo }}/${{ inputs.server_edition == 'fips' && 'mattermost-enterprise-fips-edition' || inputs.server_edition == 'team' && 'mattermost-team-edition' || 'mattermost-enterprise-edition' }}:${{ inputs.server_image_tag }}" jobs: - dispatch-begin: + prepare-run: runs-on: ubuntu-24.04 permissions: contents: read @@ -193,13 +193,62 @@ jobs: run: | echo "workers=$(jq -nc --argjson n ${{ inputs.workers }} '[range(1; $n+1)]')" >> $GITHUB_OUTPUT echo "start_time=$(date +%s)" >> $GITHUB_OUTPUT + + # Install cypress node_modules once, then workers restore from cache. + prep-deps: + name: prep-deps + runs-on: ubuntu-24.04 + timeout-minutes: 10 + permissions: + contents: read + steps: + - name: ci/checkout-repo + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + ref: ${{ inputs.commit_sha }} + fetch-depth: 1 + - name: ci/setup-node + uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + with: + node-version-file: ".nvmrc" + - name: ci/cache-cypress-deps + # node_modules + the cypress binary (downloaded to ~/.cache/Cypress by + # cypress's postinstall, not into node_modules). Both must be cached; + # otherwise workers see "cypress npm package installed but binary missing". + id: cache-cypress + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: | + e2e-tests/cypress/node_modules + ~/.cache/Cypress + key: e2e-cypress-deps-${{ runner.os }}-${{ hashFiles('e2e-tests/cypress/package-lock.json') }} + - name: ci/install-cypress-deps + if: steps.cache-cypress.outputs.cache-hit != 'true' + working-directory: e2e-tests/cypress + run: npm ci + + # Register the Test System IO run AFTER prep-deps so workers reach + # dispatch-run within Test System IO's inactivity window. + dispatch-begin: + runs-on: ubuntu-24.04 + needs: [prepare-run, prep-deps] + permissions: + contents: read + id-token: write + statuses: write + steps: + - name: ci/checkout-repo + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + ref: ${{ inputs.commit_sha }} + fetch-depth: 1 - name: ci/dispatch-begin uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-begin@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: cypress repo-dir: ${{ github.workspace }} - composite-identity: ${{ steps.composite-identity.outputs.composite-identity-json }} + composite-identity: ${{ needs.prepare-run.outputs.composite-identity-json }} total-reports-expected: ${{ inputs.workers }} retest-on-fail: ${{ inputs.retest_on_fail }} cypress-stage: ${{ inputs.cypress_stage }} @@ -217,16 +266,16 @@ jobs: name: dispatch-run-${{ matrix.worker_index }} runs-on: ubuntu-24.04 timeout-minutes: 30 - needs: dispatch-begin + needs: [prepare-run, dispatch-begin] permissions: contents: read id-token: write strategy: fail-fast: false matrix: - worker_index: ${{ fromJSON(needs.dispatch-begin.outputs.workers-matrix) }} + worker_index: ${{ fromJSON(needs.prepare-run.outputs.workers-matrix) }} env: - COMPOSITE_IDENTITY: ${{ needs.dispatch-begin.outputs.composite-identity-json }} + COMPOSITE_IDENTITY: ${{ needs.prepare-run.outputs.composite-identity-json }} SERVER: "${{ inputs.server }}" MM_LICENSE: "${{ secrets.MM_LICENSE }}" ENABLED_DOCKER_SERVICES: "${{ inputs.enabled_docker_services }}" @@ -260,29 +309,26 @@ jobs: uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: node-version-file: ".nvmrc" - cache: npm - cache-dependency-path: "e2e-tests/cypress/package-lock.json" - - name: ci/get-webapp-node-modules - working-directory: webapp - run: make node_modules + - name: ci/restore-cypress-deps + uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: | + e2e-tests/cypress/node_modules + ~/.cache/Cypress + key: e2e-cypress-deps-${{ runner.os }}-${{ hashFiles('e2e-tests/cypress/package-lock.json') }} + fail-on-cache-miss: true - name: ci/cloud-init working-directory: e2e-tests run: make cloud-init - name: ci/start-server working-directory: e2e-tests run: make start-server - # `npm ci` in the host context replaces the container-built native - # binaries with host-built ones, since the dispatch adapter spawns - # `npx cypress run` on the host. - - name: ci/prepare-cypress - working-directory: e2e-tests/cypress - run: npm ci - name: ci/dispatch-run uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-run@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: cypress - composite-identity: ${{ needs.dispatch-begin.outputs.composite-identity-json }} + composite-identity: ${{ needs.prepare-run.outputs.composite-identity-json }} repo-dir: ${{ github.workspace }} artifacts-root: ${{ github.workspace }}/worker-artifacts github-token: ${{ secrets.GITHUB_TOKEN }} @@ -306,7 +352,7 @@ jobs: report: runs-on: ubuntu-24.04 - needs: [dispatch-begin, workers] + needs: [prepare-run, dispatch-begin, workers] if: always() permissions: contents: read @@ -324,7 +370,7 @@ jobs: uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-summary@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} - composite-identity: ${{ needs.dispatch-begin.outputs.composite-identity-json }} + composite-identity: ${{ needs.prepare-run.outputs.composite-identity-json }} framework: cypress report-type: ${{ inputs.report_type }} image-tag: ${{ inputs.server_image_tag }} diff --git a/.github/workflows/e2e-tests-playwright-template-v2.yml b/.github/workflows/e2e-tests-playwright-template-v2.yml index 1573e1d4616..d2d1966209b 100644 --- a/.github/workflows/e2e-tests-playwright-template-v2.yml +++ b/.github/workflows/e2e-tests-playwright-template-v2.yml @@ -101,7 +101,7 @@ env: SERVER_IMAGE: "${{ inputs.server_image_repo }}/${{ inputs.server_edition == 'fips' && 'mattermost-enterprise-fips-edition' || inputs.server_edition == 'team' && 'mattermost-team-edition' || 'mattermost-enterprise-edition' }}:${{ inputs.server_image_tag }}" jobs: - dispatch-begin: + prepare-run: runs-on: ubuntu-24.04 permissions: contents: read @@ -158,13 +158,104 @@ jobs: run: | echo "workers=$(jq -nc --argjson n ${{ inputs.workers }} '[range(1; $n+1)]')" >> $GITHUB_OUTPUT echo "start_time=$(date +%s)" >> $GITHUB_OUTPUT + + # Build @mattermost/client + @mattermost/types and install playwright deps once, + # then workers restore from cache. Playwright only consumes those two packages + # from webapp, so we cache just their built lib/ instead of all of webapp/node_modules. + prep-deps: + name: prep-deps + runs-on: ubuntu-24.04 + timeout-minutes: 15 + permissions: + contents: read + steps: + - name: ci/checkout-repo + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + ref: ${{ inputs.commit_sha }} + fetch-depth: 1 + - name: ci/setup-node + uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + with: + node-version-file: ".nvmrc" + - name: ci/cache-platform-pkgs + # `webapp/node_modules/@mattermost/{client,types}` are the workspace + # symlinks Node walks up to find when platform/client requires + # @mattermost/types. Without them, module resolution fails inside + # the slim slice. + id: cache-platform-pkgs + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: | + webapp/node_modules/@mattermost/client + webapp/node_modules/@mattermost/types + webapp/platform/client/lib + webapp/platform/client/node_modules + webapp/platform/types/lib + webapp/platform/types/node_modules + key: e2e-platform-pkgs-${{ runner.os }}-${{ hashFiles('webapp/package-lock.json', 'webapp/platform/client/src/**', 'webapp/platform/client/tsconfig*.json', 'webapp/platform/types/src/**', 'webapp/platform/types/tsconfig*.json') }} + - name: ci/build-platform-pkgs + # Full webapp install is needed for tsc + workspace linking; the + # postinstall builds platform/{client,types}/lib. We only cache those. + if: steps.cache-platform-pkgs.outputs.cache-hit != 'true' + working-directory: webapp + run: make node_modules + - name: ci/cache-playwright-deps + # Caches node_modules + the rolled-up @mattermost/playwright-lib dist + # so workers don't re-run rollup on every job. + id: cache-playwright + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: | + e2e-tests/playwright/node_modules + e2e-tests/playwright/lib/dist + e2e-tests/playwright/lib/node_modules + key: e2e-playwright-deps-${{ runner.os }}-${{ hashFiles('e2e-tests/playwright/package-lock.json', 'e2e-tests/playwright/lib/src/**', 'e2e-tests/playwright/lib/package.json', 'e2e-tests/playwright/lib/rollup.config.js', 'e2e-tests/playwright/lib/tsconfig.json') }} + - name: ci/install-playwright-deps + # `npm ci` creates symlinks at node_modules/@mattermost/{client,types} + # → webapp/platform/{client,types}; targets must already be built. + # The postinstall then builds lib/dist via rollup. Skip browser + # download here — chromium is cached separately below. + if: steps.cache-playwright.outputs.cache-hit != 'true' + working-directory: e2e-tests/playwright + env: + PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: "1" + run: npm ci + - name: ci/cache-playwright-browsers + # Cache chromium binary (~150MB) keyed on the playwright lockfile so a + # version bump invalidates. Restored by workers; no docker image needed. + id: cache-pw-browsers + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: ~/.cache/ms-playwright + key: playwright-browsers-${{ runner.os }}-${{ hashFiles('e2e-tests/playwright/package-lock.json') }} + - name: ci/install-playwright-chromium + if: steps.cache-pw-browsers.outputs.cache-hit != 'true' + working-directory: e2e-tests/playwright + run: npx playwright install chromium + + # Register the Test System IO run AFTER prep-deps so workers reach + # dispatch-run within Test System IO's inactivity window. + dispatch-begin: + runs-on: ubuntu-24.04 + needs: [prepare-run, prep-deps] + permissions: + contents: read + id-token: write + statuses: write + steps: + - name: ci/checkout-repo + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + ref: ${{ inputs.commit_sha }} + fetch-depth: 1 - name: ci/dispatch-begin uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-begin@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: playwright repo-dir: ${{ github.workspace }} - composite-identity: ${{ steps.composite-identity.outputs.composite-identity-json }} + composite-identity: ${{ needs.prepare-run.outputs.composite-identity-json }} total-reports-expected: ${{ inputs.workers }} retest-on-fail: ${{ inputs.retest_on_fail }} playwright-project: ${{ inputs.playwright_project }} @@ -177,16 +268,16 @@ jobs: name: dispatch-run-${{ matrix.worker_index }} runs-on: ubuntu-24.04 timeout-minutes: 30 - needs: dispatch-begin + needs: [prepare-run, dispatch-begin] permissions: contents: read id-token: write strategy: fail-fast: false matrix: - worker_index: ${{ fromJSON(needs.dispatch-begin.outputs.workers-matrix) }} + worker_index: ${{ fromJSON(needs.prepare-run.outputs.workers-matrix) }} env: - COMPOSITE_IDENTITY: ${{ needs.dispatch-begin.outputs.composite-identity-json }} + COMPOSITE_IDENTITY: ${{ needs.prepare-run.outputs.composite-identity-json }} SERVER: "${{ inputs.server }}" MM_LICENSE: "${{ secrets.MM_LICENSE }}" ENABLED_DOCKER_SERVICES: "${{ inputs.enabled_docker_services }}" @@ -214,31 +305,53 @@ jobs: uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: node-version-file: ".nvmrc" - cache: npm - cache-dependency-path: "e2e-tests/playwright/package-lock.json" - - name: ci/get-webapp-node-modules - working-directory: webapp - run: make node_modules + - name: ci/restore-platform-pkgs + # Built lib/ for @mattermost/client and @mattermost/types, plus the + # webapp workspace symlinks under webapp/node_modules/@mattermost/. + uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: | + webapp/node_modules/@mattermost/client + webapp/node_modules/@mattermost/types + webapp/platform/client/lib + webapp/platform/client/node_modules + webapp/platform/types/lib + webapp/platform/types/node_modules + key: e2e-platform-pkgs-${{ runner.os }}-${{ hashFiles('webapp/package-lock.json', 'webapp/platform/client/src/**', 'webapp/platform/client/tsconfig*.json', 'webapp/platform/types/src/**', 'webapp/platform/types/tsconfig*.json') }} + fail-on-cache-miss: true + - name: ci/restore-playwright-deps + uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: | + e2e-tests/playwright/node_modules + e2e-tests/playwright/lib/dist + e2e-tests/playwright/lib/node_modules + key: e2e-playwright-deps-${{ runner.os }}-${{ hashFiles('e2e-tests/playwright/package-lock.json', 'e2e-tests/playwright/lib/src/**', 'e2e-tests/playwright/lib/package.json', 'e2e-tests/playwright/lib/rollup.config.js', 'e2e-tests/playwright/lib/tsconfig.json') }} + fail-on-cache-miss: true + - name: ci/restore-playwright-browsers + uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: ~/.cache/ms-playwright + key: playwright-browsers-${{ runner.os }}-${{ hashFiles('e2e-tests/playwright/package-lock.json') }} + fail-on-cache-miss: true - name: ci/cloud-init working-directory: e2e-tests run: make cloud-init - name: ci/start-server working-directory: e2e-tests run: make start-server - # Build once + run the `setup` project so per-spec dispatches can - # pass --no-deps and skip plugin-load + server-deployment checks. + # Run the `setup` project so per-spec dispatches can pass --no-deps + # and skip plugin-load + server-deployment checks. node_modules, + # lib/dist, and chromium are all restored from cache. - name: ci/prepare-playwright working-directory: e2e-tests/playwright - run: | - npm ci - npm run build - npx playwright test --project=setup + run: npx playwright test --project=setup - name: ci/dispatch-run uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-dispatch-run@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} framework: playwright - composite-identity: ${{ needs.dispatch-begin.outputs.composite-identity-json }} + composite-identity: ${{ needs.prepare-run.outputs.composite-identity-json }} repo-dir: ${{ github.workspace }} artifacts-root: ${{ github.workspace }}/worker-artifacts github-token: ${{ secrets.GITHUB_TOKEN }} @@ -263,7 +376,7 @@ jobs: report: runs-on: ubuntu-24.04 - needs: [dispatch-begin, workers] + needs: [prepare-run, dispatch-begin, workers] if: always() permissions: contents: read @@ -281,7 +394,7 @@ jobs: uses: mattermost/mattermost-test-system-io/.github/actions/test-system-io-summary@a2ea7f005484c28fedf51e16645f6d3bd683fd63 # 2026-05-16 with: use-staging: ${{ vars.E2E_USE_STAGING_TEST_IO_URL != 'false' }} - composite-identity: ${{ needs.dispatch-begin.outputs.composite-identity-json }} + composite-identity: ${{ needs.prepare-run.outputs.composite-identity-json }} framework: playwright report-type: ${{ inputs.report_type }} image-tag: ${{ inputs.server_image_tag }} diff --git a/.github/workflows/e2e-tests-playwright-template.yml b/.github/workflows/e2e-tests-playwright-template.yml index af976a3110b..80d92d71794 100644 --- a/.github/workflows/e2e-tests-playwright-template.yml +++ b/.github/workflows/e2e-tests-playwright-template.yml @@ -179,7 +179,14 @@ jobs: with: node-version-file: ".nvmrc" cache: npm - cache-dependency-path: "e2e-tests/playwright/package-lock.json" + cache-dependency-path: | + e2e-tests/playwright/package-lock.json + webapp/package-lock.json + - name: ci/npm-cache-verify + # Heal any partial/dangling entries left in the restored ~/.npm cache + # before running `npm ci`. Avoids the intermittent EEXIST/ENOENT + # failures in npm's cacache writer. + run: npm cache verify - name: ci/get-webapp-node-modules working-directory: webapp run: make node_modules From 0675d0ea0b38e29976af4774041216d182f7e9fc Mon Sep 17 00:00:00 2001 From: Amy Blais <29708087+amyblais@users.noreply.github.com> Date: Tue, 19 May 2026 12:04:25 +0300 Subject: [PATCH 30/80] Automations for config.json, API, audit log event, and Go release notes (#36075) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Create config-change-checker.yml * Create check_config_changes_ci.py * Update config-change-checker.yml * Update check_config_changes_ci.py * Update check_config_changes_ci.py * Update check_config_changes_ci.py * Update check_config_changes_ci.py * Update config-change-checker.yml * Update check_config_changes_ci.py * Update config-change-checker.yml * Update config.go * Fix check_api to detect multi-line and multi-method endpoints The previous implementation matched the .Handle(...).Methods(...) regex line-by-line against diff lines. This silently missed two real and common patterns in api4/: 1. Multi-line .Handle(...) declarations — e.g. group.go has 18 of them, where the path lives on one line and the wrapper/handler on the next. The regex never matched, so PRs adding such endpoints produced empty release-note entries. 2. Multi-method declarations like .Methods(http.MethodGet, http.MethodHead) (4 instances in file.go) — the old regex required a closing paren immediately after the first method. The fix: - Add a file_at(ref, path) helper that snapshots a file at a git ref via 'git show', so checkers can compare full file states instead of pattern-matching diff text. - Add _scan_endpoints() that whitespace-collapses the file before matching, letting the regex span what were originally multiple lines. - Loosen _HANDLE_RE to capture the methods list as a substring and extract individual HTTP verbs with a known-method allowlist, so multi-method declarations produce one entry per verb. - Switch check_api to set-diff (after - before) / (before - after) on the parsed endpoint sets. This also cleanly handles routes that move within a file (no fragile add/remove dedup needed). - Anchor the new/deleted file detection to '^new file mode \d+' to avoid false positives from stray text in source files. Made-with: Cursor * Track enclosing struct in check_config to avoid dedup collisions The previous check_config keyed its add/remove dedup on the bare field name. The dedup intent was to ignore fields that were merely reordered within config.go (which appear in the diff as both '-Foo' and '+Foo'). But because the key was just the field name, an unrelated rename in one struct could silently cancel out a real new field with the same name in a different struct. For example, in a single PR: - EnableFoo *bool // removed from ServiceSettings + EnableFooV2 *bool - EnableBar *bool // removed from EmailSettings + EnableFoo *bool // newly added — but wrongly cancelled below The dedup would see 'EnableFoo' in both lists and drop both entries, hiding the brand-new EmailSettings.EnableFoo from the release-note output. The fix tracks each field's enclosing struct using a brace-depth stack that walks the file at BASE_SHA and HEAD_SHA. Fields are keyed as (struct_name, field_name) tuples, so identically-named fields in different structs are distinct, and the dedup only collapses true reorderings. As a side benefit the rendered output is now 'StructName.FieldName' which is much more useful to reviewers. Switching to file-at-revision scanning + set diff also removes the custom dedup logic entirely — set arithmetic handles "moved within file" naturally. Made-with: Cursor * Switch remaining checkers to file-at-revision style; drop lines_by_sign check_audit_events and check_go_version still parsed +/- diff lines directly, with the same brittle dedup-and-cancel logic that was used in the previous check_config. After the previous two commits the rest of the file uses the file_at(ref, path) helper to compare full file states between BASE_SHA and HEAD_SHA, which: - removes the entire moved-within-file dedup dance (set arithmetic handles it for free), - aligns all four checkers on a single, easy-to-reason-about pattern, - is robust to whitespace-only or reordering edits in the watched files. For Dockerfile.buildenv the helper also avoids a subtle case where the old code only inspected +/- lines: an edit to an unrelated RUN line that didn't touch the FROM line could in theory leave both old_ver and new_ver as None even though the version was effectively unchanged. Reading the file at each revision compares the actual current and previous FROM line directly. The lines_by_sign helper now has no callers, so remove it. Made-with: Cursor * Update config.go * Update config.go * Update check_config_changes_ci.py * Update check_config_changes_ci.py * Update check_config_changes_ci.py * Update check_config_changes_ci.py * Tighten check_config_changes_ci.py: regex coverage + idempotency - Restore tolerant `_HANDLE_RE` so 2-arg wrappers (e.g. `api.APISessionRequired(handler, handlerParamFileAPI)`) are not silently dropped from the api4 endpoint scan; broaden the `.Methods(...)` capture so string-literal variants (`Methods("GET")`) work too. Filtering moves back to the `_HTTP_METHODS` allowlist in `_parse_methods` to keep stray identifiers from being treated as HTTP verbs. - Make `strip_old_note` also remove auto-generated lines that landed outside the ```release-note fence (the inject_note fallback paths) so reruns no longer accumulate duplicates when a PR has no fence. - Skip the GitHub PATCH when the PR description is already up to date, so every commit no longer triggers an unconditional write. - Wire up `check_go_version`'s `additions` path in `_format_lines` and `_AUTO_LINE_RE` so a freshly-added Dockerfile.buildenv emits a note. - Remove the now-dead `CheckResult.to_markdown` method (replaced by `_format_lines`). Made-with: Cursor * Restore ExperimentalSettings.EnableWatermark The field was removed in f71527f0b1 but `server/config/client.go`, `server/config/client_test.go`, and `server/public/model/config_test.go` still reference it (added on master in #36025). Restoring the field makes the branch compile again so CI can go green. Made-with: Cursor * Replace placeholder release-note content (NONE / N/A) on injection The script previously appended its auto-detected lines INSIDE the ```release-note fence but never displaced template placeholders, so PRs that only had `NONE` ended up with output like: NONE Added `Foo.Bar` configuration setting. Go runtime updated from 1.25.8 to 1.25.9. When the existing fence content is empty or consists only of placeholder tokens (NONE, N/A, NA, dashes — case-insensitive), replace it entirely with the auto-detected entries. User-written human content is still preserved by appending instead. Idempotent: stripping followed by re-injection keeps the placeholder visible when there's nothing to inject, and replaces it again when there is. Made-with: Cursor * Update config-change-checker.yml * Update check_config_changes_ci.py --------- Co-authored-by: Your Name Co-authored-by: Mattermost Build --- .github/scripts/check_config_changes_ci.py | 590 ++++++++++++++++++++ .github/workflows/config-change-checker.yml | 66 +++ 2 files changed, 656 insertions(+) create mode 100644 .github/scripts/check_config_changes_ci.py create mode 100644 .github/workflows/config-change-checker.yml diff --git a/.github/scripts/check_config_changes_ci.py b/.github/scripts/check_config_changes_ci.py new file mode 100644 index 00000000000..be00bc7fef8 --- /dev/null +++ b/.github/scripts/check_config_changes_ci.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python3 +""" +.github/scripts/check_config_changes_ci.py + +CI script that detects notable changes across several Mattermost source files +and appends structured release-note entries to the PR description. + +Checkers +──────── +1. config.go — exported struct field additions/removals +2. api4/ — API endpoint additions/removals (Handle() calls) +3. audit_events.go — AuditEvent* constant additions/removals +4. Dockerfile.buildenv — Go (base-image) version changes + +All inputs come from environment variables set by the GitHub Actions workflow: + GITHUB_TOKEN — built-in Actions token (pull-requests: write scope) + PR_NUMBER — pull request number + BASE_SHA — base commit SHA + HEAD_SHA — head commit SHA + REPO — owner/repo (e.g. mattermost/mattermost) +""" + +import os +import re +import sys +import subprocess +import requests +from dataclasses import dataclass, field +from typing import Optional + +# ── Environment ──────────────────────────────────────────────────────────────── + +GITHUB_TOKEN = os.environ["GITHUB_TOKEN"] +PR_NUMBER = int(os.environ["PR_NUMBER"]) +BASE_SHA = os.environ["BASE_SHA"] +HEAD_SHA = os.environ["HEAD_SHA"] +REPO = os.environ.get("REPO", "mattermost/mattermost") + +BASE_URL = "https://api.github.com" +HEADERS = { + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + +# Timeout for all GitHub API requests: (connect seconds, read seconds). +# Prevents the workflow from hanging indefinitely on a slow/unresponsive API. +_TIMEOUT = (5, 30) + +# Paths watched by this script (must align with `paths:` in the workflow YAML) +WATCHED_PATHS = [ + "server/public/model/config.go", + "server/channels/api4/", + "server/public/model/audit_events.go", + "server/build/Dockerfile.buildenv", +] + + +# ── Data types ───────────────────────────────────────────────────────────────── + +@dataclass +class CheckResult: + """Holds the findings from one checker.""" + label: str # Section heading, e.g. "`config.json` Changes" + additions: list = field(default_factory=list) + removals: list = field(default_factory=list) + changes: list = field(default_factory=list) # for free-form entries (version bumps) + + def has_findings(self) -> bool: + return bool(self.additions or self.removals or self.changes) + + def to_markdown(self) -> str: + lines = [f"### {self.label}"] + if self.additions: + lines.append("**Added:** " + ", ".join(self.additions)) + if self.removals: + lines.append("**Removed:** " + ", ".join(self.removals)) + for change in self.changes: + lines.append(change) + return "\n".join(lines) + + +# ── Diff helpers ─────────────────────────────────────────────────────────────── + +def get_full_patch() -> str: + """Return unified diff for all watched paths between base and head.""" + result = subprocess.run( + ["git", "diff", f"{BASE_SHA}...{HEAD_SHA}", "--"] + WATCHED_PATHS, + capture_output=True, + text=True, + check=True, + ) + return result.stdout + + +def split_patch_by_file(full_patch: str) -> dict[str, str]: + """ + Split a multi-file unified diff into {filename: patch} mapping. + Filenames are the b-side (new) path, stripped of the 'b/' prefix. + """ + patches: dict[str, str] = {} + current_file: Optional[str] = None + current_lines: list[str] = [] + + for line in full_patch.splitlines(keepends=True): + if line.startswith("diff --git "): + if current_file: + patches[current_file] = "".join(current_lines) + current_lines = [line] + # Extract filename from "diff --git a/foo b/foo" + m = re.search(r" b/(.+)$", line.rstrip()) + current_file = m.group(1) if m else None + else: + current_lines.append(line) + + if current_file: + patches[current_file] = "".join(current_lines) + + return patches + + +def file_at(ref: str, path: str) -> str: + """Return the full contents of `path` at git ref `ref`, or '' if absent.""" + try: + return subprocess.run( + ["git", "show", f"{ref}:{path}"], + capture_output=True, text=True, check=True, + ).stdout + except subprocess.CalledProcessError: + return "" + + +# ── Checker 1 — config.go ────────────────────────────────────────────────────── + +_CONFIG_PATH = "server/public/model/config.go" +_STRUCT_DECL_RE = re.compile(r"^type\s+(\w+)\s+struct\s*\{") +_FIELD_LINE_RE = re.compile(r"^\t([A-Z][A-Za-z0-9_]*)\s+\S") + + +def _scan_struct_fields(src: str) -> set[tuple[str, str]]: + """ + Walk Go source and return {(StructName, FieldName)} for every exported + field in every struct. + + Uses a brace-depth stack so nested anonymous structs, interface bodies, + and function literals don't corrupt the enclosing struct context. + Named type declarations cannot be nested in Go, so the struct_stack + never grows beyond one entry for named structs. + """ + fields: set[tuple[str, str]] = set() + # Each entry: (struct_name, brace_depth_when_opened) + struct_stack: list[tuple[str, int]] = [] + depth = 0 + + for line in src.splitlines(): + sm = _STRUCT_DECL_RE.match(line) + if sm: + # Record depth *before* counting this line's braces + struct_stack.append((sm.group(1), depth)) + + depth += line.count("{") - line.count("}") + + # Pop any structs whose closing brace has been passed + while struct_stack and depth <= struct_stack[-1][1]: + struct_stack.pop() + + # Record fields only when we're directly inside exactly one named struct + if len(struct_stack) == 1: + fm = _FIELD_LINE_RE.match(line) + if fm: + fields.add((struct_stack[0][0], fm.group(1))) + + return fields + + +def check_config(patches: dict[str, str]) -> CheckResult: + """ + Detect exported Go struct field additions/removals in config.go. + + Compares full-file snapshots at BASE_SHA and HEAD_SHA so that fields + are always attributed to the correct struct regardless of which diff + hunks are present. + """ + result = CheckResult(label="`config.json` Field Changes") + if _CONFIG_PATH not in patches: + return result + + base_fields = _scan_struct_fields(file_at(BASE_SHA, _CONFIG_PATH)) + head_fields = _scan_struct_fields(file_at(HEAD_SHA, _CONFIG_PATH)) + + added = head_fields - base_fields + removed = base_fields - head_fields + + result.additions = sorted(f"`{s}.{f}`" for s, f in added) + result.removals = sorted(f"`{s}.{f}`" for s, f in removed) + return result + + +# ── Checker 2 — api4/ ───────────────────────────────────────────────────────── + +# Matches Handle() route registrations after whitespace-collapsing the source. +# Whitespace collapse makes multi-line declarations single-searchable. +# Group 1: path Group 2: handler func Group 3: raw Methods(...) content +# +# The wrapper pattern uses [^)]* so it tolerates any middleware arguments +# (e.g. r.APIHandler(...), r.ApiSessionRequired(..., isLocal=true), etc.) +# without having to enumerate every possible wrapper signature. +_HANDLE_RE = re.compile( + r'\.Handle\("([^"]*)"' # path + r',\s*[^)]*\((\w+)\)\)' # wrapper(...handlerFunc)) + r'\.Methods\(([^)]+)\)', # .Methods(one or more methods) +) + +_METHOD_RE = re.compile(r'(?:http\.Method)?(\w+)') +_HTTP_METHODS = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"} + + +def _parse_methods(raw: str) -> list[str]: + """Split raw Methods(...) content into individual uppercase HTTP verbs. + + Filters against the set of known HTTP methods so that incidental + identifiers (handler names, constants, etc.) that happen to appear + inside Methods(...) don't produce spurious results. + """ + return [ + verb + for token in raw.split(",") + if (m := _METHOD_RE.search(token.strip())) + and (verb := m.group(1).upper()) in _HTTP_METHODS + ] + + +def _format_endpoint(path: str, handler: str, method: str) -> str: + return f"`{method.upper()} {path or '/'}` (`{handler}`)" + + +def _parse_endpoints(src: str) -> set[tuple[str, str, str]]: + """ + Parse Handle() registrations from a Go source file. + + Whitespace-collapses the entire file first so multi-line declarations + (e.g. the 18 in group.go) are matched as a single token sequence. + Returns {(path, handler, method)} tuples. + """ + blob = " ".join(src.split()) + endpoints: set[tuple[str, str, str]] = set() + for m in _HANDLE_RE.finditer(blob): + path, handler, methods_raw = m.group(1), m.group(2), m.group(3) + for method in _parse_methods(methods_raw): + endpoints.add((path or "/", handler, method)) + return endpoints + + +def check_api(patches: dict[str, str]) -> CheckResult: + """ + Detect API endpoint additions/removals in the api4/ directory. + + Compares full-file snapshots at BASE_SHA and HEAD_SHA via set arithmetic, + so multi-line and multi-method registrations are handled correctly. + """ + result = CheckResult(label="API Changes (`api4`)") + + api4_patches = { + fname: patch + for fname, patch in patches.items() + if fname.startswith("server/channels/api4/") and fname.endswith(".go") + } + if not api4_patches: + return result + + added_eps: set[tuple[str, str, str]] = set() + removed_eps: set[tuple[str, str, str]] = set() + + for fname, patch in api4_patches.items(): + base_eps = _parse_endpoints(file_at(BASE_SHA, fname)) + head_eps = _parse_endpoints(file_at(HEAD_SHA, fname)) + added_eps |= head_eps - base_eps + removed_eps |= base_eps - head_eps + + # Anchor the check to avoid false positives from unrelated source text + if re.search(r"^new file mode \d+", patch, re.MULTILINE): + result.changes.append(f"🆕 New API file: `{fname.split('/')[-1]}`") + if re.search(r"^deleted file mode \d+", patch, re.MULTILINE): + result.changes.append(f"🗑️ Removed API file: `{fname.split('/')[-1]}`") + + result.additions = sorted(_format_endpoint(p, h, m) for p, h, m in added_eps) + result.removals = sorted(_format_endpoint(p, h, m) for p, h, m in removed_eps) + return result + + +# ── Checker 3 — audit_events.go ─────────────────────────────────────────────── + +_AUDIT_EVENT_PATH = "server/public/model/audit_events.go" +_AUDIT_CONST_RE = re.compile(r"^\t(AuditEvent\w+)\s*=") + + +def _parse_audit_events(src: str) -> set[str]: + return {m.group(1) for line in src.splitlines() if (m := _AUDIT_CONST_RE.match(line))} + + +def check_audit_events(patches: dict[str, str]) -> CheckResult: + """ + Detect AuditEvent* constant additions/removals. + + Uses full-file snapshots at BASE_SHA/HEAD_SHA so reorderings and + cross-constant name collisions don't produce false results. + """ + result = CheckResult(label="Audit Log Event Changes") + if _AUDIT_EVENT_PATH not in patches: + return result + + base_events = _parse_audit_events(file_at(BASE_SHA, _AUDIT_EVENT_PATH)) + head_events = _parse_audit_events(file_at(HEAD_SHA, _AUDIT_EVENT_PATH)) + + result.additions = sorted(f"`{e}`" for e in head_events - base_events) + result.removals = sorted(f"`{e}`" for e in base_events - head_events) + return result + + +# ── Checker 4 — Dockerfile.buildenv (Go version) ────────────────────────────── + +# The Go version lives in the base image tag, e.g.: +# FROM mattermost/golang-bullseye:1.25.8@sha256:... +_DOCKERFILE_PATH = "server/build/Dockerfile.buildenv" +_IMAGE_VER_RE = re.compile(r"^FROM \S+:([0-9]+\.[0-9]+(?:\.[0-9]+)?)") + + +def _parse_go_version(src: str) -> Optional[str]: + for line in src.splitlines(): + m = _IMAGE_VER_RE.match(line.strip()) + if m: + return m.group(1) + return None + + +def check_go_version(patches: dict[str, str]) -> CheckResult: + """ + Detect Go runtime version changes via the base image tag. + + Uses full-file snapshots so the version is read from the actual file + state at each ref rather than reconstructed from patch lines. + """ + result = CheckResult(label="Go Runtime Version") + if _DOCKERFILE_PATH not in patches: + return result + + old_ver = _parse_go_version(file_at(BASE_SHA, _DOCKERFILE_PATH)) + new_ver = _parse_go_version(file_at(HEAD_SHA, _DOCKERFILE_PATH)) + + if old_ver and new_ver and old_ver != new_ver: + result.changes.append(f"Go updated: `{old_ver}` → `{new_ver}`") + elif new_ver and not old_ver: + result.additions.append(f"`{new_ver}`") + return result + + +# ── PR description helpers ───────────────────────────────────────────────────── + +# Matches lines that were auto-generated by this script so they can be stripped +# before re-injecting a fresh set on subsequent commits. +_AUTO_LINE_RE = re.compile( + r"^(Added|Removed) `[^`]+`.*(configuration setting|API endpoint|audit log event)\." + r"|^Go runtime updated from \S+ to \S+\." + r"|^Go runtime set to `[^`]+`\." + r"|^🆕 New API file:" + r"|^🗑️ Removed API file:" +) + +# Matches placeholder content inside a release-note fence that means "nothing +# to report yet" (e.g. NONE, N/A, ---). When detected, we replace the +# placeholder rather than appending alongside it. +_PLACEHOLDER_RE = re.compile(r"^\s*(?:NONE|N/?A|-+)\s*$", re.IGNORECASE) + + +def _format_lines(result: CheckResult) -> list[str]: + """Produce natural-language lines for one checker result.""" + lines = [] + + if "`config.json`" in result.label: + for item in result.additions: + lines.append(f"Added {item} configuration setting.") + for item in result.removals: + lines.append(f"Removed {item} configuration setting.") + + elif "API Changes" in result.label: + for item in result.additions: + lines.append(f"Added {item} API endpoint.") + for item in result.removals: + lines.append(f"Removed {item} API endpoint.") + lines.extend(result.changes) # new/deleted file entries + + elif "Audit" in result.label: + for item in result.additions: + lines.append(f"Added {item} audit log event.") + for item in result.removals: + lines.append(f"Removed {item} audit log event.") + + elif "Go Runtime" in result.label: + for item in result.additions: + # item is e.g. "`1.22`" — strip backticks for the prose form + lines.append(f"Go runtime set to {item}.") + for c in result.changes: + # c arrives as "Go updated: `1.21` → `1.22`" — rewrite it + m = re.search(r"`([^`]+)`\s*→\s*`([^`]+)`", c) + if m: + lines.append(f"Go runtime updated from {m.group(1)} to {m.group(2)}.") + else: + lines.append(c) + + return lines + + +def build_pr_note(results: list[CheckResult]) -> str: + """Assemble all findings into a clean plain-text block.""" + lines = [] + for r in results: + if r.has_findings(): + lines.extend(_format_lines(r)) + return "\n".join(lines) + + +def strip_old_note(body: str) -> str: + """ + Remove previously auto-generated lines from the PR description. + + Primary path — lines inside the ```release-note ... ``` fence. + Fallback path — auto-generated lines that were appended outside any fence + (e.g. via the ## Release Notes section on earlier runs). + + Lines are identified by pattern rather than visible markers, so the PR + description stays clean for human readers. + """ + def _clean_fence(m: re.Match) -> str: + open_tag, content, close_tag = m.group(1), m.group(2), m.group(3) + cleaned_lines = [ + line for line in content.split("\n") + if not _AUTO_LINE_RE.match(line.strip()) + ] + return open_tag + "\n".join(cleaned_lines) + close_tag + + cleaned = re.sub( + r"(```release-note)(.*?)(```)", + _clean_fence, + body or "", + flags=re.DOTALL | re.IGNORECASE, + ) + + # Fallback: strip any auto-generated lines that appear outside a fence + # (written by an older version of this script or via the header-inject path). + cleaned_lines = [ + line for line in cleaned.splitlines() + if not _AUTO_LINE_RE.match(line.strip()) + ] + return "\n".join(cleaned_lines).rstrip() + + +def inject_note(body: str, note: str) -> str: + """ + Insert `note` using this priority order: + + 1. INSIDE the ```release-note block, before its closing ``` + (Mattermost convention — keeps everything in one place for reviewers) + 2. After a recognised release-notes section header (## Release Notes, etc.) + 3. Fallback: append a new ## Release Notes section at the end + """ + body = strip_old_note(body) + if not note: + return body + + # 1. Mattermost-style ```release-note ... ``` block — inject INSIDE the fence. + # If the fence currently contains only a placeholder (NONE / N/A / ---), + # replace the placeholder rather than appending alongside it. + release_note_block = re.search( + r"(```release-note)(.*?)(```)", + body, + flags=re.DOTALL | re.IGNORECASE, + ) + if release_note_block: + open_tag = release_note_block.group(1) + content = release_note_block.group(2) # everything between the fences + close_tag = release_note_block.group(3) + block_start = release_note_block.start() + block_end = release_note_block.end() + + # Strip leading/trailing newlines inside the fence for comparison + inner = content.strip() + if _PLACEHOLDER_RE.match(inner): + # Replace the entire fence with a fresh one + new_block = f"{open_tag}\n{note}\n{close_tag}" + else: + # Append before the closing fence + new_block = open_tag + content + note + "\n" + close_tag + + return body[:block_start] + new_block + body[block_end:] + + # 2. Markdown section headers + for header in ["## Release Notes", "## Changelog", "## What Changed", "## What's Changed"]: + if header.lower() in body.lower(): + idx = body.lower().index(header.lower()) + len(header) + return body[:idx] + "\n\n" + note + body[idx:] + + # 3. Fallback — append + return body + "\n\n## Release Notes\n\n" + note + + +# ── GitHub API ───────────────────────────────────────────────────────────────── + +def get_pr_body() -> str: + r = requests.get( + f"{BASE_URL}/repos/{REPO}/pulls/{PR_NUMBER}", + headers=HEADERS, + timeout=_TIMEOUT, + ) + r.raise_for_status() + return r.json().get("body") or "" + + +def update_pr_body(new_body: str) -> None: + r = requests.patch( + f"{BASE_URL}/repos/{REPO}/pulls/{PR_NUMBER}", + headers=HEADERS, + json={"body": new_body}, + timeout=_TIMEOUT, + ) + r.raise_for_status() + + +# ── Main ─────────────────────────────────────────────────────────────────────── + +def main(): + print(f"📋 PR #{PR_NUMBER} | base {BASE_SHA[:8]} → head {HEAD_SHA[:8]}") + print("🔍 Collecting diffs …") + + full_patch = get_full_patch() + if not full_patch.strip(): + print("ℹ️ No changes in watched paths. Nothing to do.") + return + + patches = split_patch_by_file(full_patch) + print(f" {len(patches)} file(s) changed in watched paths.\n") + + # Run all checkers + checkers = [ + check_config, + check_api, + check_audit_events, + check_go_version, + ] + results: list[CheckResult] = [fn(patches) for fn in checkers] + + for r in results: + if r.has_findings(): + print(f" ✅ {r.label}") + if r.additions: + print(f" Added: {', '.join(r.additions)}") + if r.removals: + print(f" Removed: {', '.join(r.removals)}") + for c in r.changes: + print(f" {c}") + else: + print(f" – {r.label}: no changes") + + note = build_pr_note(results) + if not note: + print("\nℹ️ No notable changes found across all checkers.") + return + + print("\n🔄 Fetching PR description …") + body = get_pr_body() + new_body = inject_note(body, note) + + if new_body == body: + print("ℹ️ PR description already up to date — no changes needed.") + return + + update_pr_body(new_body) + print(f"✅ PR #{PR_NUMBER} description updated.") + + +if __name__ == "__main__": + try: + main() + except subprocess.CalledProcessError as e: + print(f"❌ git diff failed:\n{e.stderr}", file=sys.stderr) + sys.exit(1) + except requests.HTTPError as e: + # Avoid dumping the full response body (can be large / noisy). + # status + reason gives enough context for debugging (e.g. "403 Forbidden"). + reason = e.response.reason or "unknown" + print(f"❌ GitHub API error: {e.response.status_code} {reason}", file=sys.stderr) + sys.exit(1) diff --git a/.github/workflows/config-change-checker.yml b/.github/workflows/config-change-checker.yml new file mode 100644 index 00000000000..0772baaf0e6 --- /dev/null +++ b/.github/workflows/config-change-checker.yml @@ -0,0 +1,66 @@ +# .github/workflows/config-change-checker.yml +# +# Automatically detects notable additions/removals across four source files +# and appends structured release-note entries to the PR description under +# the "## Release Notes" section. +# +# Tracked files / directories: +# • server/public/model/config.go — config struct field changes +# • server/channels/api4/ — API endpoint additions/removals +# • server/public/model/audit_events.go — audit log event constant changes +# • server/build/Dockerfile.buildenv — Go runtime version changes +# +# No secrets needed — uses the built-in GITHUB_TOKEN. + +name: Config Change Checker + +on: + pull_request: + types: [opened, synchronize, reopened] + paths: + - 'server/public/model/config.go' + - 'server/channels/api4/**' + - 'server/public/model/audit_events.go' + - 'server/build/Dockerfile.buildenv' + +# Cancel any in-progress run for the same PR when a new commit is pushed. +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + check-release-notes: + name: Detect release-note-worthy changes + runs-on: ubuntu-latest + # Skip bot-authored PRs (Dependabot, mattermost-bot, etc.) — they will + # not touch these paths intentionally and cannot receive description updates + # via GITHUB_TOKEN anyway (fork-like restrictions apply to most bots). + if: github.event.pull_request.user.type != 'Bot' + + permissions: + pull-requests: write # needed to update the PR description + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + # Fetch enough history to diff against the base branch + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: '3.11' + + - name: Install dependencies + run: pip install requests==2.32.3 --quiet + + - name: Detect changes and update PR description + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + BASE_SHA: ${{ github.event.pull_request.base.sha }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + REPO: ${{ github.repository }} + run: python3 .github/scripts/check_config_changes_ci.py From 92f6870a2b97636876a3a6ebedf7ddac962a1e9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Espino=20Garc=C3=ADa?= Date: Tue, 19 May 2026 15:22:04 +0200 Subject: [PATCH 31/80] Add "last used" field for incoming webhooks (#36416) * Add "last used" field for incoming webhooks * Address feedback * Rename migrations * Fix web lint --- api/v4/source/definitions.yaml | 4 ++ server/channels/app/webhook.go | 12 +++++- server/channels/db/migrations/migrations.list | 2 + ...add_lastused_to_incoming_webhooks.down.sql | 1 + ...4_add_lastused_to_incoming_webhooks.up.sql | 1 + .../store/localcachelayer/webhook_layer.go | 10 +++++ .../channels/store/sqlstore/webhook_store.go | 18 +++++++- server/channels/store/store.go | 1 + .../store/storetest/mocks/WebhookStore.go | 18 ++++++++ .../channels/store/storetest/webhook_store.go | 41 +++++++++++++++++++ server/channels/web/webhook_test.go | 4 ++ server/public/model/incoming_webhook.go | 2 + .../installed_incoming_webhook.test.tsx.snap | 9 ++++ .../abstract_incoming_hook.test.tsx | 1 + .../abstract_incoming_webhook.tsx | 1 + .../edit_incoming_webhook.test.tsx | 1 + .../installed_incoming_webhook.test.tsx | 21 ++++++++++ .../installed_incoming_webhook.tsx | 18 ++++++++ webapp/channels/src/i18n/en.json | 2 + webapp/channels/src/utils/test_helper.ts | 1 + webapp/platform/types/src/integrations.ts | 1 + 21 files changed, 166 insertions(+), 3 deletions(-) create mode 100644 server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.down.sql create mode 100644 server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.up.sql diff --git a/api/v4/source/definitions.yaml b/api/v4/source/definitions.yaml index 0968f660efe..ba31ba5325f 100644 --- a/api/v4/source/definitions.yaml +++ b/api/v4/source/definitions.yaml @@ -1151,6 +1151,10 @@ components: description: The time in milliseconds a incoming webhook was deleted type: integer format: int64 + last_used: + description: The time in milliseconds this incoming webhook was last used to post a message + type: integer + format: int64 channel_id: description: The ID of a public channel or private group that receives the webhook payloads diff --git a/server/channels/app/webhook.go b/server/channels/app/webhook.go index 8ee15553f59..92775e70d50 100644 --- a/server/channels/app/webhook.go +++ b/server/channels/app/webhook.go @@ -445,6 +445,7 @@ func (a *App) UpdateIncomingWebhook(oldHook, updatedHook *model.IncomingWebhook) updatedHook.UpdateAt = model.GetMillis() updatedHook.TeamId = oldHook.TeamId updatedHook.DeleteAt = oldHook.DeleteAt + updatedHook.LastUsed = oldHook.LastUsed newWebhook, err := a.Srv().Store().Webhook().UpdateIncoming(updatedHook) if err != nil { @@ -903,7 +904,16 @@ func (a *App) HandleIncomingWebhook(rctx request.CTX, hookID string, req *model. } _, err := a.CreateWebhookPost(rctx, hook.UserId, channel, text, overrideUsername, overrideIconURL, req.IconEmoji, req.Props, webhookType, threadRootID, req.Priority) - return err + if err != nil { + return err + } + + now := model.GetMillis() + if nErr := a.Srv().Store().Webhook().UpdateIncomingLastUsed(hook.Id, now); nErr != nil { + rctx.Logger().Warn("Failed to update incoming webhook LastUsed", mlog.String("hook_id", hook.Id), mlog.Err(nErr)) + } + + return nil } func (a *App) CreateCommandWebhook(commandID string, args *model.CommandArgs) (*model.CommandWebhook, *model.AppError) { diff --git a/server/channels/db/migrations/migrations.list b/server/channels/db/migrations/migrations.list index 04fc52052f5..73a7ed10fd1 100644 --- a/server/channels/db/migrations/migrations.list +++ b/server/channels/db/migrations/migrations.list @@ -363,3 +363,5 @@ channels/db/migrations/postgres/000182_create_channel_join_requests_channel_stat channels/db/migrations/postgres/000182_create_channel_join_requests_channel_status_index.up.sql channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.down.sql channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql +channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.down.sql +channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.up.sql diff --git a/server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.down.sql b/server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.down.sql new file mode 100644 index 00000000000..a0dc89dc227 --- /dev/null +++ b/server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.down.sql @@ -0,0 +1 @@ +ALTER TABLE incomingwebhooks DROP COLUMN IF EXISTS lastused; diff --git a/server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.up.sql b/server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.up.sql new file mode 100644 index 00000000000..8031f7eeee4 --- /dev/null +++ b/server/channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.up.sql @@ -0,0 +1 @@ +ALTER TABLE incomingwebhooks ADD COLUMN IF NOT EXISTS lastused bigint NOT NULL DEFAULT 0; diff --git a/server/channels/store/localcachelayer/webhook_layer.go b/server/channels/store/localcachelayer/webhook_layer.go index 9a753b4b4ff..75d11d89671 100644 --- a/server/channels/store/localcachelayer/webhook_layer.go +++ b/server/channels/store/localcachelayer/webhook_layer.go @@ -87,3 +87,13 @@ func (s LocalCacheWebhookStore) PermanentDeleteIncomingByChannel(channelId strin s.ClearCaches() return nil } + +func (s LocalCacheWebhookStore) UpdateIncomingLastUsed(webhookID string, lastUsed int64) error { + err := s.WebhookStore.UpdateIncomingLastUsed(webhookID, lastUsed) + if err != nil { + return err + } + + s.InvalidateWebhookCache(webhookID) + return nil +} diff --git a/server/channels/store/sqlstore/webhook_store.go b/server/channels/store/sqlstore/webhook_store.go index 4a7857f878f..5b3f4feca41 100644 --- a/server/channels/store/sqlstore/webhook_store.go +++ b/server/channels/store/sqlstore/webhook_store.go @@ -46,6 +46,7 @@ func newSqlWebhookStore(sqlStore *SqlStore, metrics einterfaces.MetricsInterface "Username", "IconURL", "ChannelLocked", + "LastUsed", ). From("IncomingWebhooks") @@ -88,9 +89,9 @@ func (s SqlWebhookStore) SaveIncoming(webhook *model.IncomingWebhook) (*model.In } if _, err := s.GetMaster().NamedExec(`INSERT INTO IncomingWebhooks - (Id, CreateAt, UpdateAt, DeleteAt, UserId, ChannelId, TeamId, DisplayName, Description, Username, IconURL, ChannelLocked) + (Id, CreateAt, UpdateAt, DeleteAt, UserId, ChannelId, TeamId, DisplayName, Description, Username, IconURL, ChannelLocked, LastUsed) VALUES - (:Id, :CreateAt, :UpdateAt, :DeleteAt, :UserId, :ChannelId, :TeamId, :DisplayName, :Description, :Username, :IconURL, :ChannelLocked)`, webhook); err != nil { + (:Id, :CreateAt, :UpdateAt, :DeleteAt, :UserId, :ChannelId, :TeamId, :DisplayName, :Description, :Username, :IconURL, :ChannelLocked, :LastUsed)`, webhook); err != nil { return nil, errors.Wrapf(err, "failed to save IncomingWebhook with id=%s", webhook.Id) } @@ -111,6 +112,19 @@ func (s SqlWebhookStore) UpdateIncoming(hook *model.IncomingWebhook) (*model.Inc return hook, nil } +func (s SqlWebhookStore) UpdateIncomingLastUsed(webhookID string, lastUsed int64) error { + _, err := s.GetMaster().Exec( + `UPDATE IncomingWebhooks SET LastUsed = ? WHERE Id = ? AND DeleteAt = 0`, + lastUsed, + webhookID, + ) + if err != nil { + return errors.Wrapf(err, "failed to update LastUsed for IncomingWebhook id=%s", webhookID) + } + + return nil +} + func (s SqlWebhookStore) GetIncoming(id string, allowFromCache bool) (*model.IncomingWebhook, error) { var webhook model.IncomingWebhook diff --git a/server/channels/store/store.go b/server/channels/store/store.go index c93eb8a2725..04715debd3d 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -655,6 +655,7 @@ type WebhookStore interface { GetIncomingByTeam(teamID string, offset, limit int) ([]*model.IncomingWebhook, error) GetIncomingByTeamByUser(teamID string, userID string, offset, limit int) ([]*model.IncomingWebhook, error) UpdateIncoming(webhook *model.IncomingWebhook) (*model.IncomingWebhook, error) + UpdateIncomingLastUsed(webhookID string, lastUsed int64) error GetIncomingByChannel(channelID string) ([]*model.IncomingWebhook, error) DeleteIncoming(webhookID string, timestamp int64) error PermanentDeleteIncomingByChannel(channelID string) error diff --git a/server/channels/store/storetest/mocks/WebhookStore.go b/server/channels/store/storetest/mocks/WebhookStore.go index 77dd12fd254..16f563681f7 100644 --- a/server/channels/store/storetest/mocks/WebhookStore.go +++ b/server/channels/store/storetest/mocks/WebhookStore.go @@ -668,6 +668,24 @@ func (_m *WebhookStore) UpdateIncoming(webhook *model.IncomingWebhook) (*model.I return r0, r1 } +// UpdateIncomingLastUsed provides a mock function with given fields: webhookID, lastUsed +func (_m *WebhookStore) UpdateIncomingLastUsed(webhookID string, lastUsed int64) error { + ret := _m.Called(webhookID, lastUsed) + + if len(ret) == 0 { + panic("no return value specified for UpdateIncomingLastUsed") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, int64) error); ok { + r0 = rf(webhookID, lastUsed) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateOutgoing provides a mock function with given fields: hook func (_m *WebhookStore) UpdateOutgoing(hook *model.OutgoingWebhook) (*model.OutgoingWebhook, error) { ret := _m.Called(hook) diff --git a/server/channels/store/storetest/webhook_store.go b/server/channels/store/storetest/webhook_store.go index a0884698bce..a026cbf1d04 100644 --- a/server/channels/store/storetest/webhook_store.go +++ b/server/channels/store/storetest/webhook_store.go @@ -18,6 +18,8 @@ import ( func TestWebhookStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("SaveIncoming", func(t *testing.T) { testWebhookStoreSaveIncoming(t, rctx, ss) }) t.Run("UpdateIncoming", func(t *testing.T) { testWebhookStoreUpdateIncoming(t, rctx, ss) }) + t.Run("UpdateIncomingPreservesLastUsed", func(t *testing.T) { testWebhookStoreUpdateIncomingPreservesLastUsed(t, rctx, ss) }) + t.Run("UpdateIncomingLastUsed", func(t *testing.T) { testWebhookStoreUpdateIncomingLastUsed(t, rctx, ss) }) t.Run("GetIncoming", func(t *testing.T) { testWebhookStoreGetIncoming(t, rctx, ss) }) t.Run("GetIncomingList", func(t *testing.T) { testWebhookStoreGetIncomingList(t, rctx, ss) }) t.Run("GetIncomingListByUser", func(t *testing.T) { testWebhookStoreGetIncomingListByUser(t, rctx, ss) }) @@ -73,6 +75,45 @@ func testWebhookStoreUpdateIncoming(t *testing.T, rctx request.CTX, ss store.Sto require.Equal(t, "TestHook", webhook.DisplayName, "display name is not updated") } +// testWebhookStoreUpdateIncomingPreservesLastUsed ensures the generic UpdateIncoming path does not +// overwrite LastUsed; only UpdateIncomingLastUsed should change that column. +func testWebhookStoreUpdateIncomingPreservesLastUsed(t *testing.T, rctx request.CTX, ss store.Store) { + o1 := buildIncomingWebhook() + saved, err := ss.Webhook().SaveIncoming(o1) + require.NoError(t, err) + require.Zero(t, saved.LastUsed, "new webhook should have LastUsed 0") + + lastUsed := model.GetMillis() + err = ss.Webhook().UpdateIncomingLastUsed(saved.Id, lastUsed) + require.NoError(t, err) + + withStaleLastUsed := *saved + withStaleLastUsed.DisplayName = "RenamedHook" + withStaleLastUsed.LastUsed = 0 + + _, err = ss.Webhook().UpdateIncoming(&withStaleLastUsed) + require.NoError(t, err) + + fromDB, err := ss.Webhook().GetIncoming(saved.Id, false) + require.NoError(t, err) + require.Equal(t, lastUsed, fromDB.LastUsed, "UpdateIncoming must not clear LastUsed when struct has LastUsed 0") + require.Equal(t, "RenamedHook", fromDB.DisplayName) +} + +func testWebhookStoreUpdateIncomingLastUsed(t *testing.T, rctx request.CTX, ss store.Store) { + o1 := buildIncomingWebhook() + o1, err := ss.Webhook().SaveIncoming(o1) + require.NoError(t, err) + + lastUsed := model.GetMillis() + err = ss.Webhook().UpdateIncomingLastUsed(o1.Id, lastUsed) + require.NoError(t, err) + + updated, err := ss.Webhook().GetIncoming(o1.Id, false) + require.NoError(t, err) + require.Equal(t, lastUsed, updated.LastUsed) +} + func testWebhookStoreGetIncoming(t *testing.T, rctx request.CTX, ss store.Store) { var err error diff --git a/server/channels/web/webhook_test.go b/server/channels/web/webhook_test.go index aab7cf79010..b7e33eb12e6 100644 --- a/server/channels/web/webhook_test.go +++ b/server/channels/web/webhook_test.go @@ -42,6 +42,10 @@ func TestIncomingWebhook(t *testing.T) { require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + refreshed, appErr := th.App.GetIncomingWebhook(hook.Id) + require.Nil(t, appErr) + require.NotZero(t, refreshed.LastUsed) + payload = "payload={\"text\": \"\"}" resp, err = http.Post(url, "application/x-www-form-urlencoded", strings.NewReader(payload)) require.NoError(t, err) diff --git a/server/public/model/incoming_webhook.go b/server/public/model/incoming_webhook.go index ac28a3260e2..2e2b3ded574 100644 --- a/server/public/model/incoming_webhook.go +++ b/server/public/model/incoming_webhook.go @@ -28,6 +28,7 @@ type IncomingWebhook struct { Username string `json:"username"` IconURL string `json:"icon_url"` ChannelLocked bool `json:"channel_locked"` + LastUsed int64 `json:"last_used"` } func (o *IncomingWebhook) Auditable() map[string]any { @@ -44,6 +45,7 @@ func (o *IncomingWebhook) Auditable() map[string]any { "username": o.Username, "icon_url:": o.IconURL, "channel_locked": o.ChannelLocked, + "last_used": o.LastUsed, } } diff --git a/webapp/channels/src/components/integrations/__snapshots__/installed_incoming_webhook.test.tsx.snap b/webapp/channels/src/components/integrations/__snapshots__/installed_incoming_webhook.test.tsx.snap index e300c6f1cef..9989a4c5336 100644 --- a/webapp/channels/src/components/integrations/__snapshots__/installed_incoming_webhook.test.tsx.snap +++ b/webapp/channels/src/components/integrations/__snapshots__/installed_incoming_webhook.test.tsx.snap @@ -69,6 +69,15 @@ exports[`components/integrations/InstalledIncomingWebhook should match snapshot Created by creator on Friday, August 11, 2017
+
+ + Never used + +
diff --git a/webapp/channels/src/components/integrations/abstract_incoming_hook.test.tsx b/webapp/channels/src/components/integrations/abstract_incoming_hook.test.tsx index e871af7bc57..b71264acc59 100644 --- a/webapp/channels/src/components/integrations/abstract_incoming_hook.test.tsx +++ b/webapp/channels/src/components/integrations/abstract_incoming_hook.test.tsx @@ -50,6 +50,7 @@ describe('components/integrations/AbstractIncomingWebhook', () => { username: '', icon_url: '', channel_locked: false, + last_used: 0, }; const enablePostUsernameOverride = true; const enablePostIconOverride = true; diff --git a/webapp/channels/src/components/integrations/abstract_incoming_webhook.tsx b/webapp/channels/src/components/integrations/abstract_incoming_webhook.tsx index 9f76a5a8f63..8f1d35ca5f4 100644 --- a/webapp/channels/src/components/integrations/abstract_incoming_webhook.tsx +++ b/webapp/channels/src/components/integrations/abstract_incoming_webhook.tsx @@ -142,6 +142,7 @@ export default class AbstractIncomingWebhook extends PureComponent delete_at: this.props.initialHook?.delete_at || 0, team_id: this.props.initialHook?.team_id || '', user_id: this.props.initialHook?.user_id || '', + last_used: this.props.initialHook?.last_used || 0, }; this.props.action(hook).then(() => this.setState({saving: false})); diff --git a/webapp/channels/src/components/integrations/edit_incoming_webhook/edit_incoming_webhook.test.tsx b/webapp/channels/src/components/integrations/edit_incoming_webhook/edit_incoming_webhook.test.tsx index c33d7457a70..e8209839a39 100644 --- a/webapp/channels/src/components/integrations/edit_incoming_webhook/edit_incoming_webhook.test.tsx +++ b/webapp/channels/src/components/integrations/edit_incoming_webhook/edit_incoming_webhook.test.tsx @@ -63,6 +63,7 @@ describe('components/integrations/EditIncomingWebhook', () => { username: 'username', icon_url: 'http://test/icon.png', channel_locked: false, + last_used: 0, }; const updateIncomingHook = jest.fn(); diff --git a/webapp/channels/src/components/integrations/installed_incoming_webhook.test.tsx b/webapp/channels/src/components/integrations/installed_incoming_webhook.test.tsx index ebe36f5abc4..898dd6b299d 100644 --- a/webapp/channels/src/components/integrations/installed_incoming_webhook.test.tsx +++ b/webapp/channels/src/components/integrations/installed_incoming_webhook.test.tsx @@ -15,6 +15,7 @@ describe('components/integrations/InstalledIncomingWebhook', () => { id: '9w96t4nhbfdiij64wfqors4i1r', channel_id: '1jiw9kphbjrntfyrm7xpdcya4o', create_at: 1502455422406, + last_used: 0, delete_at: 0, description: 'build status', display_name: 'build', @@ -163,4 +164,24 @@ describe('components/integrations/InstalledIncomingWebhook', () => { ); expect(container.querySelector('.item-details')).not.toBeNull(); }); + + test('should show Last used on when last_used is non-zero', () => { + const lastUsedMs = 1704067200000; + const hookWithLastUsed: IncomingWebhook = { + ...incomingWebhook, + last_used: lastUsedMs, + }; + + renderWithContext( + , + initialState, + ); + + expect(screen.getByText(/Last used on/i)).toBeInTheDocument(); + expect(screen.queryByText('Never used')).not.toBeInTheDocument(); + }); }); diff --git a/webapp/channels/src/components/integrations/installed_incoming_webhook.tsx b/webapp/channels/src/components/integrations/installed_incoming_webhook.tsx index d3e8a680ba6..970924bba29 100644 --- a/webapp/channels/src/components/integrations/installed_incoming_webhook.tsx +++ b/webapp/channels/src/components/integrations/installed_incoming_webhook.tsx @@ -181,6 +181,24 @@ export default class InstalledIncomingWebhook extends React.PureComponent /> +
+ + {incomingWebhook.last_used > 0 ? ( + + ) : ( + + )} + +
); diff --git a/webapp/channels/src/i18n/en.json b/webapp/channels/src/i18n/en.json index a8450a2e1f4..ba270f32a04 100644 --- a/webapp/channels/src/i18n/en.json +++ b/webapp/channels/src/i18n/en.json @@ -5109,6 +5109,8 @@ "installed_integrations.edit": "Edit", "installed_integrations.fromApp": "Managed by Apps Framework", "installed_integrations.hideSecret": "Hide Secret", + "installed_integrations.last_used": "Last used on {lastUsed, date, full}", + "installed_integrations.never_used": "Never used", "installed_integrations.regenSecret": "Regenerate Secret", "installed_integrations.regenToken": "Regenerate Token", "installed_integrations.showSecret": "Show Secret", diff --git a/webapp/channels/src/utils/test_helper.ts b/webapp/channels/src/utils/test_helper.ts index 8dacc70d202..cf3b21d2587 100644 --- a/webapp/channels/src/utils/test_helper.ts +++ b/webapp/channels/src/utils/test_helper.ts @@ -285,6 +285,7 @@ export class TestHelper { create_at: 1, update_at: 1, delete_at: 1, + last_used: 0, user_id: '', channel_id: '', team_id: '', diff --git a/webapp/platform/types/src/integrations.ts b/webapp/platform/types/src/integrations.ts index 3cdc5e592a7..b78cff05e8c 100644 --- a/webapp/platform/types/src/integrations.ts +++ b/webapp/platform/types/src/integrations.ts @@ -9,6 +9,7 @@ export type IncomingWebhook = { create_at: number; update_at: number; delete_at: number; + last_used: number; user_id: string; channel_id: string; team_id: string; From 7bb6fb347b75b99779351615bdb0968b94099c81 Mon Sep 17 00:00:00 2001 From: Asaad Mahmood Date: Tue, 19 May 2026 18:41:43 +0500 Subject: [PATCH 32/80] Fix AI toolbar separator visibility (#36356) Co-authored-by: Cursor Agent --- .../advanced_text_editor.tsx | 8 ++++++- .../formatting_bar/formatting_bar.test.tsx | 22 +++++++++++++++++++ .../formatting_bar/formatting_bar.tsx | 2 +- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/webapp/channels/src/components/advanced_text_editor/advanced_text_editor.tsx b/webapp/channels/src/components/advanced_text_editor/advanced_text_editor.tsx index ddb52be5bc4..86bbc3b23ac 100644 --- a/webapp/channels/src/components/advanced_text_editor/advanced_text_editor.tsx +++ b/webapp/channels/src/components/advanced_text_editor/advanced_text_editor.tsx @@ -181,6 +181,7 @@ const AdvancedTextEditor = ({ const teammateDisplayName = useSelector((state: GlobalState) => (teammateId ? getDisplayName(state, teammateId) : '')); const showDndWarning = useSelector((state: GlobalState) => (teammateId ? getStatusForUserId(state, teammateId) === UserStatuses.DND : false)); const selectedPostFocussedAt = useSelector((state: GlobalState) => getSelectedPostFocussedAt(state)); + const aiActionMenuItems = useSelector((state: GlobalState) => state.plugins.components.AIActionMenuItem); const {available: aiRewriteEnabled} = useGetAgentsBridgeEnabled(); const canPost = useSelector((state: GlobalState) => { @@ -317,6 +318,7 @@ const AdvancedTextEditor = ({ rewriteMenuProps, isProcessing: rewriteIsProcessing, } = useRewrite(draft, handleDraftChange, textboxRef, focusTextbox, setServerError); + const hasAIActionsMenu = (aiActionMenuItems?.length ?? 0) > 0 || (aiRewriteEnabled && Boolean(rewriteMenuProps)); const isDisabled = Boolean(readOnlyChannel || (!enableSharedChannelsDMs && isDMOrGMRemote) || rewriteIsProcessing); const [attachmentPreview, fileUploadJSX] = useUploadFiles( @@ -707,6 +709,10 @@ const AdvancedTextEditor = ({ }, [handleDraftChange, draft]); const aiActionsMenu = useMemo(() => { + if (!hasAIActionsMenu) { + return null; + } + return ( ); - }, [draft, getSelectedText, updateText, channelId, location, rewriteMenuProps, aiRewriteEnabled]); + }, [draft, getSelectedText, updateText, channelId, location, rewriteMenuProps, aiRewriteEnabled, hasAIActionsMenu]); const formattingBar = ( { expect(fireEvent.mouseDown(screen.getByLabelText('code'))).toBe(false); }); + + test('should only render separator before bold when AI actions menu is present', () => { + jest.spyOn(Hooks, 'useFormattingBarControls').mockReturnValue({layoutMode: LayoutModes.Wide, ...splitFormattingBarControls('wide')}); + + const {container, rerender} = renderWithContext( + {'AI Actions'}} + />, + ); + + expect(container.querySelectorAll('[data-testid="formatting-bar-separator"]')).toHaveLength(2); + + rerender( + , + ); + + expect(container.querySelectorAll('[data-testid="formatting-bar-separator"]')).toHaveLength(1); + }); }); diff --git a/webapp/channels/src/components/advanced_text_editor/formatting_bar/formatting_bar.tsx b/webapp/channels/src/components/advanced_text_editor/formatting_bar/formatting_bar.tsx index aa006e19cca..2e389136bb8 100644 --- a/webapp/channels/src/components/advanced_text_editor/formatting_bar/formatting_bar.tsx +++ b/webapp/channels/src/components/advanced_text_editor/formatting_bar/formatting_bar.tsx @@ -16,7 +16,7 @@ import type {ApplyMarkdownOptions, MarkdownMode} from 'utils/markdown/apply_mark import FormattingIcon, {IconContainer} from './formatting_icon'; import {LayoutModes, useFormattingBarControls} from './hooks'; -export const Separator = styled.div` +export const Separator = styled.div.attrs({'data-testid': 'formatting-bar-separator'})` display: block; position: relative; width: 1px; From 5cd26002d3e91da57ab565af9941f12c7ecfd407 Mon Sep 17 00:00:00 2001 From: Maria A Nunez Date: Tue, 19 May 2026 09:55:31 -0400 Subject: [PATCH 33/80] Hide Download Apps link when running in Desktop app (#36614) * Hide Download Apps UI when running in Desktop app Co-authored-by: Maria A Nunez * Fix ESLint import order for Desktop app visibility changes Co-authored-by: Maria A Nunez --------- Co-authored-by: Cursor Agent --- .../product_menu_list.test.tsx | 19 +++++++++ .../product_menu_list/product_menu_list.tsx | 3 +- .../mobile_sidebar_right_items.test.tsx | 28 +++++++++++-- .../mobile_sidebar_right_items.tsx | 4 +- .../onboarding_tasklist_completed.test.tsx | 21 ++++++++++ .../onboarding_tasklist_completed.tsx | 39 ++++++++++--------- .../onboarding_tasks_manager.test.tsx | 20 ++++++++++ .../onboarding_tasks_manager.tsx | 6 +++ 8 files changed, 117 insertions(+), 23 deletions(-) diff --git a/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.test.tsx b/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.test.tsx index 2ddcec4280c..17a0b6cc39d 100644 --- a/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.test.tsx +++ b/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.test.tsx @@ -3,6 +3,7 @@ import React from 'react'; +import * as UserAgent from '@mattermost/shared/utils/user_agent'; import type {UserProfile} from '@mattermost/types/users'; import type {DeepPartial} from '@mattermost/types/utilities'; @@ -14,6 +15,11 @@ import type {GlobalState} from 'types/store'; import ProductMenuList from './product_menu_list'; import type {Props as ProductMenuListProps} from './product_menu_list'; +const isDesktopAppMock = jest.mocked(UserAgent.isDesktopApp); + +jest.mock('@mattermost/shared/utils/user_agent', () => ({ + isDesktopApp: jest.fn(() => false), +})); jest.mock('components/widgets/menu/menu_items/menu_cloud_trial', () => () => null); jest.mock('components/widgets/menu/menu_items/menu_item_cloud_limit', () => () => null); jest.mock('components/permissions_gates/system_permission_gate', () => ({children}: {children: React.ReactNode}) => <>{children}); @@ -238,4 +244,17 @@ describe('components/global/product_switcher_menu', () => { expect(container).toMatchSnapshot(); }); }); + + test('shows Download Apps link when appDownloadLink configured and not in desktop app', () => { + isDesktopAppMock.mockReturnValue(false); + const {container} = renderWithContext(, adminState); + expect(container.querySelector('#nativeAppLink')).not.toBeNull(); + }); + + test('hides Download Apps link when in desktop app', () => { + isDesktopAppMock.mockReturnValue(true); + const {container} = renderWithContext(, adminState); + expect(container.querySelector('#nativeAppLink')).toBeNull(); + isDesktopAppMock.mockReturnValue(false); + }); }); diff --git a/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.tsx b/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.tsx index 71b78712ac1..9df7b1cc883 100644 --- a/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.tsx +++ b/webapp/channels/src/components/global_header/left_controls/product_menu/product_menu_list/product_menu_list.tsx @@ -13,6 +13,7 @@ import { ViewGridPlusOutlineIcon, WebhookIncomingIcon, } from '@mattermost/compass-icons/components'; +import {isDesktopApp} from '@mattermost/shared/utils/user_agent'; import type {UserProfile} from '@mattermost/types/users'; import {Permissions} from 'mattermost-redux/constants'; @@ -208,7 +209,7 @@ const ProductMenuList = (props: Props): JSX.Element | null => { } diff --git a/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.test.tsx b/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.test.tsx index ccd287426f0..54fb1f5991e 100644 --- a/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.test.tsx +++ b/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.test.tsx @@ -3,6 +3,8 @@ import React from 'react'; +import * as UserAgent from '@mattermost/shared/utils/user_agent'; + import {Permissions} from 'mattermost-redux/constants'; import type {MockIntl} from 'tests/helpers/intl-test-helper'; @@ -11,6 +13,12 @@ import {renderWithContext, screen} from 'tests/react_testing_utils'; import {MobileSidebarRightItems} from './mobile_sidebar_right_items'; import type {Props} from './mobile_sidebar_right_items'; +const isDesktopAppMock = jest.mocked(UserAgent.isDesktopApp); + +jest.mock('@mattermost/shared/utils/user_agent', () => ({ + isDesktopApp: jest.fn(() => false), +})); + describe('MobileSidebarRightItems', () => { const defaultProps: Props = { teamId: 'team-id', @@ -160,14 +168,28 @@ describe('MobileSidebarRightItems', () => { expect(screen.getByText('Help')).toBeInTheDocument(); }); - test('should show report link when provided', () => { + test('should show Download Apps link when appDownloadLink is set and not in desktop app', () => { + isDesktopAppMock.mockReturnValue(false); renderWithContext( , defaultState, ); - expect(screen.getByText('Report a Problem')).toBeInTheDocument(); + expect(screen.getByText('Download Apps')).toBeInTheDocument(); + }); + + test('should hide Download Apps link when in desktop app', () => { + isDesktopAppMock.mockReturnValue(true); + renderWithContext( + , + defaultState, + ); + expect(screen.queryByText('Download Apps')).not.toBeInTheDocument(); + isDesktopAppMock.mockReturnValue(false); }); }); diff --git a/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.tsx b/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.tsx index 48cf7eb6b64..f6e6aafb3be 100644 --- a/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.tsx +++ b/webapp/channels/src/components/mobile_sidebar_right/mobile_sidebar_right_items/mobile_sidebar_right_items.tsx @@ -5,6 +5,8 @@ import React from 'react'; import {injectIntl} from 'react-intl'; import type {WrappedComponentProps} from 'react-intl'; +import {isDesktopApp} from '@mattermost/shared/utils/user_agent'; + import {Permissions} from 'mattermost-redux/constants'; import {emitUserLoggedOutEvent} from 'actions/global_actions'; @@ -407,7 +409,7 @@ export class MobileSidebarRightItems extends React.PureComponent { /> ({ + isDesktopApp: jest.fn(() => false), +})); + jest.mock('mattermost-redux/actions/admin', () => ({ ...jest.requireActual('mattermost-redux/actions/admin'), getPrevTrialLicense: () => ({type: 'MOCK_GET_PREV_TRIAL_LICENSE'}), @@ -67,4 +75,17 @@ describe('components/onboarding_tasklist/onboarding_tasklist_completed.tsx', () await userEvent.click(noThanksLink[0]); expect(dismissMockFn).toHaveBeenCalledTimes(1); }); + + test('displays download apps link when not in desktop app', () => { + isDesktopAppMock.mockReturnValue(false); + const {container} = renderWithContext(, initialState); + expect(container.querySelectorAll('.download-apps')).toHaveLength(1); + }); + + test('hides download apps link when in desktop app', () => { + isDesktopAppMock.mockReturnValue(true); + const {container} = renderWithContext(, initialState); + expect(container.querySelectorAll('.download-apps')).toHaveLength(0); + isDesktopAppMock.mockReturnValue(false); + }); }); diff --git a/webapp/channels/src/components/onboarding_tasklist/onboarding_tasklist_completed.tsx b/webapp/channels/src/components/onboarding_tasklist/onboarding_tasklist_completed.tsx index 7d3c9c37365..f3bcae1ce6d 100644 --- a/webapp/channels/src/components/onboarding_tasklist/onboarding_tasklist_completed.tsx +++ b/webapp/channels/src/components/onboarding_tasklist/onboarding_tasklist_completed.tsx @@ -7,6 +7,7 @@ import {useSelector, useDispatch} from 'react-redux'; import {CSSTransition} from 'react-transition-group'; import styled from 'styled-components'; +import {isDesktopApp} from '@mattermost/shared/utils/user_agent'; import type {GlobalState} from '@mattermost/types/store'; import {getPrevTrialLicense} from 'mattermost-redux/actions/admin'; @@ -212,24 +213,26 @@ const Completed = (props: Props): JSX.Element => { /> )} -
- - ( - - {msg} - - ), - }} - /> - -
+ {!isDesktopApp() && ( +
+ + ( + + {msg} + + ), + }} + /> + +
+ )} {showStartTrialBtn &&
({ + isDesktopApp: jest.fn(() => false), +})); + const WrapperComponent = (): JSX.Element => { const taskList = useTasksList(); return ( @@ -87,4 +95,16 @@ describe('onboarding tasks manager', () => { // verify visit_system_console and start_trial were removed expect(screen.queryByText('invite_people')).not.toBeInTheDocument(); }); + + it('Removes download_app task when running in desktop app', () => { + isDesktopAppMock.mockReturnValue(true); + renderWithContext( + , + initialState, + ); + + expect(screen.getAllByRole('listitem')).toHaveLength(5); + expect(screen.queryByText('download_app')).not.toBeInTheDocument(); + isDesktopAppMock.mockReturnValue(false); + }); }); diff --git a/webapp/channels/src/components/onboarding_tasks/onboarding_tasks_manager.tsx b/webapp/channels/src/components/onboarding_tasks/onboarding_tasks_manager.tsx index 2ef52001d99..1483cd77503 100644 --- a/webapp/channels/src/components/onboarding_tasks/onboarding_tasks_manager.tsx +++ b/webapp/channels/src/components/onboarding_tasks/onboarding_tasks_manager.tsx @@ -6,6 +6,8 @@ import {useIntl} from 'react-intl'; import {useDispatch, useSelector} from 'react-redux'; import {matchPath, useLocation} from 'react-router-dom'; +import {isDesktopApp} from '@mattermost/shared/utils/user_agent'; + import {savePreferences} from 'mattermost-redux/actions/preferences'; import {getCurrentUserId} from 'mattermost-redux/selectors/entities/common'; import {getLicense} from 'mattermost-redux/selectors/entities/general'; @@ -138,6 +140,10 @@ export const useTasksList = () => { delete list.INVITE_PEOPLE; } + if (isDesktopApp()) { + delete list.DOWNLOAD_APP; + } + return Object.values(list); }; From 5566604e030b8e3085c0b1a708e739557a0c394e Mon Sep 17 00:00:00 2001 From: Doug Lauder Date: Tue, 19 May 2026 10:12:00 -0400 Subject: [PATCH 34/80] MM-68838: Ping a restored plugin remote immediately on re-register (#36592) * MM-68838: ping restored plugin remote immediately on re-register RegisterPluginForSharedChannels' restore branch updated the row but did not call PingNow, leaving the restored remote offline until the next pingLoop tick (up to PingFreq, default 1 minute). The new-connection branch already calls PingNow; the restore branch now mirrors it so sync attempts immediately after a plugin restart no longer fail with "offline remote cluster". * MM-68838: gob-encode error returns in apiRPCServer.ReceiveSharedChannelAttachmentSyncMsg The apiRPCServer wrapper for ReceiveSharedChannelAttachmentSyncMsg assigned the hook's error return directly to the gob-encoded response struct. When the framework's App.ReceiveSharedChannelAttachmentSyncMsg returned an error wrapped with %w (*fmt.wrapError, an unexported type), gob refused to encode it and the RPC server broke the connection with "type not registered for interface: fmt.wrapError". Every subsequent plugin/server RPC call then returned the zero-value response struct, causing plugins that dereferenced the nil returns to crash. Apply the existing encodableError() helper so the returned error becomes a gob-safe ErrorString, matching every other apiRPCServer method in this file. --- server/channels/app/remote_cluster.go | 57 ++++++- server/channels/app/remote_cluster_test.go | 180 +++++++++++++++++++++ server/public/plugin/client_rpc.go | 1 + server/public/plugin/client_rpc_test.go | 46 ++++++ 4 files changed, 279 insertions(+), 5 deletions(-) create mode 100644 server/public/plugin/client_rpc_test.go diff --git a/server/channels/app/remote_cluster.go b/server/channels/app/remote_cluster.go index 4f7b144b2fe..b87ba525eaf 100644 --- a/server/channels/app/remote_cluster.go +++ b/server/channels/app/remote_cluster.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "fmt" "net/http" + "time" "github.com/pkg/errors" @@ -19,6 +20,34 @@ import ( "github.com/mattermost/mattermost/server/public/shared/request" ) +// pluginRemoteInitialPingDelay is how long the framework waits after +// RegisterPluginForSharedChannels returns before firing the first ping to +// the newly created or restored plugin remote. The delay gives the +// calling plugin a chance to record the returned RemoteId in its own +// state, so the synchronous OnSharedChannelsPing hook the framework +// invokes can resolve the remote. Without the delay, the first ping for +// every freshly registered SiteURL fails and the remote stays offline +// until the periodic pingLoop refreshes it (up to PingFreq, default 1 +// minute). Declared as a var, not const, so tests can shorten it. +var pluginRemoteInitialPingDelay = 5 * time.Second + +// schedulePluginRemoteInitialPing schedules a single deferred ping for a +// freshly registered or restored plugin remote. The goroutine is launched +// via Server.Go so it cannot outlive the server. The remote is re-read +// before the ping fires because the plugin may have unregistered it +// inside the delay window; pinging a soft-deleted row is harmless but +// produces a stray "ping failed" warning. +func (a *App) schedulePluginRemoteInitialPing(rcService remotecluster.RemoteClusterServiceIFace, rc *model.RemoteCluster) { + a.Srv().Go(func() { + time.Sleep(pluginRemoteInitialPingDelay) + current, err := a.Srv().Store().RemoteCluster().Get(rc.RemoteId, true) + if err != nil || current.DeleteAt != 0 { + return + } + rcService.PingNow(current) + }) +} + func (a *App) RegisterPluginForSharedChannels(rctx request.CTX, opts model.RegisterPluginOpts) (remoteID string, err error) { // When SiteURL is empty, fall back to the legacy single-remote behavior // using the "plugin_" prefix. This preserves compatibility for older plugins @@ -59,6 +88,18 @@ func (a *App) RegisterPluginForSharedChannels(rctx request.CTX, opts model.Regis if _, err = a.Srv().Store().RemoteCluster().Update(rc); err != nil { return "", err } + + // Ping the restored plugin remote so its LastPingAt is refreshed + // before sync attempts. Deferred via a goroutine (see + // schedulePluginRemoteInitialPing) so the caller has a chance + // to record the returned RemoteId before the synchronous + // OnSharedChannelsPing hook fires. Without this the restored + // remote stays offline until the next pingLoop iteration (up to + // PingFreq), causing transient sync failures. + rcService, _ := a.GetRemoteClusterService() + if rcService != nil { + a.schedulePluginRemoteInitialPing(rcService, rc) + } return rc.RemoteId, nil } @@ -86,13 +127,19 @@ func (a *App) RegisterPluginForSharedChannels(rctx request.CTX, opts model.Regis mlog.String("site_url", opts.SiteURL), ) - // Ping the plugin remote immediately if the service is running. - // If the service is not available the ping will happen once the - // service starts. This is expected since plugins start before the - // service. + // Ping the new plugin remote, deferred via a goroutine so the + // calling plugin has a chance to record the returned RemoteId + // before the synchronous OnSharedChannelsPing hook fires for the + // first time. A synchronous ping here would invoke the hook + // before the caller's return statement, the plugin would fail to + // resolve the new RemoteId, the ping would be recorded as failed, + // and the remote would stay offline until the next pingLoop + // iteration (up to PingFreq, default 1 minute). If the service is + // not yet running the ping will fire from the periodic pingLoop + // once the service starts. rcService, _ := a.GetRemoteClusterService() if rcService != nil { - rcService.PingNow(rcSaved) + a.schedulePluginRemoteInitialPing(rcService, rcSaved) } return rcSaved.RemoteId, nil diff --git a/server/channels/app/remote_cluster_test.go b/server/channels/app/remote_cluster_test.go index f4b5ec604dd..4a5637d4a1d 100644 --- a/server/channels/app/remote_cluster_test.go +++ b/server/channels/app/remote_cluster_test.go @@ -6,13 +6,39 @@ package app import ( "strings" "testing" + "time" "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/v8/channels/testlib" + "github.com/mattermost/mattermost/server/v8/platform/services/remotecluster" ) +// Shorten the deferred initial ping for tests so RegisterPluginForSharedChannels +// teardown does not block on a 5s goroutine. No test in this package needs the +// production headroom. The value is large enough that even a slow runner where +// RegisterPluginForSharedChannels takes a couple hundred milliseconds still +// has comfortable margin before the deferred goroutine fires. +func init() { + pluginRemoteInitialPingDelay = 500 * time.Millisecond +} + +// pingTrackingRCService wraps a real RemoteClusterServiceIFace and records the +// time of every PingNow call without forwarding it. Embedding the interface +// satisfies the other methods by delegation. +type pingTrackingRCService struct { + remotecluster.RemoteClusterServiceIFace + pings chan time.Time +} + +func (p *pingTrackingRCService) PingNow(rc *model.RemoteCluster) { + select { + case p.pings <- time.Now(): + default: + } +} + func setupRemoteCluster(tb testing.TB) *TestHelper { return SetupConfig(tb, func(cfg *model.Config) { *cfg.ConnectedWorkspacesSettings.EnableRemoteClusterService = true @@ -221,6 +247,46 @@ func TestRegisterPluginForSharedChannels(t *testing.T) { require.Equal(t, id1, id2) }) + t.Run("re-registering a soft-deleted SiteURL restores the row and pings the remote (MM-68838)", func(t *testing.T) { + pluginID := "com.test.restore-" + model.NewId() + siteURL := "nats://restore-" + model.NewId() + + // 1. Initial registration. + id1, err := th.App.RegisterPluginForSharedChannels(th.Context, model.RegisterPluginOpts{ + Displayname: "restore test plugin", + PluginID: pluginID, + CreatorID: th.BasicUser.Id, + SiteURL: siteURL, + }) + require.NoError(t, err) + + // 2. Unregister soft-deletes the row. + require.NoError(t, th.App.UnregisterPluginRemoteForSharedChannels(pluginID, id1)) + + rcDeleted, err := th.App.Srv().Store().RemoteCluster().Get(id1, true) + require.NoError(t, err) + require.NotZero(t, rcDeleted.DeleteAt, "row should be soft-deleted after unregister") + + // 3. Re-register the same SiteURL. The restore path must run. + id2, err := th.App.RegisterPluginForSharedChannels(th.Context, model.RegisterPluginOpts{ + Displayname: "restore test plugin", + PluginID: pluginID, + CreatorID: th.BasicUser.Id, + SiteURL: siteURL, + }) + require.NoError(t, err) + require.Equal(t, id1, id2, "restore path must reuse the existing remoteID") + + // 4. The row must be restored (DeleteAt cleared). PingNow is called + // inside the restore branch; the actual ping fails in unit tests + // because no plugin process is loaded to answer OnSharedChannelsPing, + // so we cannot assert on LastPingAt here. The presence of the call + // is what fixes MM-68838 (offline-for-PingFreq window on restart). + rcRestored, err := th.App.Srv().Store().RemoteCluster().Get(id2, false) + require.NoError(t, err) + require.Zero(t, rcRestored.DeleteAt, "row should be restored after re-register") + }) + t.Run("multi-remote registration returns distinct remoteIDs", func(t *testing.T) { pluginID := "com.test.multi-" + model.NewId() @@ -322,3 +388,117 @@ func TestUnregisterPluginForSharedChannelsBulk(t *testing.T) { require.NoError(t, err) require.Empty(t, remotes) } + +// TestRegisterPluginForSharedChannelsPingIsDeferred guards the race fix. +// A synchronous PingNow inside RegisterPluginForSharedChannels invoked the +// plugin's OnSharedChannelsPing hook before the calling plugin could record +// the returned RemoteId, so the very first ping always failed and the remote +// stayed offline for ~PingFreq (1 minute). The fix is to defer the initial +// ping to a goroutine. Both the new-row branch and the soft-delete-restore +// branch must defer. +func TestRegisterPluginForSharedChannelsPingIsDeferred(t *testing.T) { + mainHelper.Parallel(t) + th := setupRemoteCluster(t).InitBasic(t) + + tracker := &pingTrackingRCService{ + RemoteClusterServiceIFace: th.Server.remoteClusterService, + pings: make(chan time.Time, 8), + } + original := th.Server.remoteClusterService + th.Server.remoteClusterService = tracker + t.Cleanup(func() { th.Server.remoteClusterService = original }) + + // Generous upper bound on real wall-time variance under load: the + // production delay is 5s; init() overrides to 100ms; we wait up to + // delay + 2s for the ping to actually arrive. + const arrivalGrace = 2 * time.Second + delay := pluginRemoteInitialPingDelay + + // drain consumes any pending ping timestamps so a later sub-case does + // not see a stale one from an earlier sub-case. + drain := func(ch <-chan time.Time) { + for { + select { + case <-ch: + default: + return + } + } + } + + assertDeferred := func(t *testing.T, registerStart time.Time) { + t.Helper() + // Phase 1: no ping in the first half of the delay (proves async). + var prematurePing bool + select { + case <-tracker.pings: + prematurePing = true + case <-time.After(delay / 2): + } + require.False(t, prematurePing, "PingNow fired synchronously inside RegisterPluginForSharedChannels; the deferral is broken") + // Phase 2: a ping arrives within delay + grace, and not before delay. + var pingAt time.Time + var pingArrived bool + select { + case pingAt = <-tracker.pings: + pingArrived = true + case <-time.After(delay + arrivalGrace): + } + require.True(t, pingArrived, "expected PingNow to fire within delay + grace, but it never did") + elapsed := pingAt.Sub(registerStart) + require.GreaterOrEqual(t, elapsed, delay, + "PingNow fired %s after Register returned, before the configured delay of %s", elapsed, delay) + } + + t.Run("new-row branch defers the initial ping", func(t *testing.T) { + drain(tracker.pings) + + start := time.Now() + _, err := th.App.RegisterPluginForSharedChannels(th.Context, model.RegisterPluginOpts{ + Displayname: "deferred ping plugin", + PluginID: "com.test.deferred-" + model.NewId(), + CreatorID: th.BasicUser.Id, + SiteURL: "nats://deferred-" + model.NewId(), + }) + require.NoError(t, err) + assertDeferred(t, start) + }) + + t.Run("soft-delete restore branch defers the ping (MM-68838)", func(t *testing.T) { + drain(tracker.pings) + + pluginID := "com.test.restore-defer-" + model.NewId() + siteURL := "nats://restore-defer-" + model.NewId() + + // Initial register to create the row; consume its deferred ping. + id1, err := th.App.RegisterPluginForSharedChannels(th.Context, model.RegisterPluginOpts{ + Displayname: "restore defer plugin", + PluginID: pluginID, + CreatorID: th.BasicUser.Id, + SiteURL: siteURL, + }) + require.NoError(t, err) + var initialPingArrived bool + select { + case <-tracker.pings: + initialPingArrived = true + case <-time.After(delay + arrivalGrace): + } + require.True(t, initialPingArrived, "initial register's deferred ping never arrived") + + // Unregister soft-deletes the row. + require.NoError(t, th.App.UnregisterPluginRemoteForSharedChannels(pluginID, id1)) + drain(tracker.pings) + + // Re-register: the restore branch must also defer. + start := time.Now() + _, err = th.App.RegisterPluginForSharedChannels(th.Context, model.RegisterPluginOpts{ + Displayname: "restore defer plugin", + PluginID: pluginID, + CreatorID: th.BasicUser.Id, + SiteURL: siteURL, + }) + require.NoError(t, err) + assertDeferred(t, start) + }) +} diff --git a/server/public/plugin/client_rpc.go b/server/public/plugin/client_rpc.go index c0c03f93df7..196b43ed2b3 100644 --- a/server/public/plugin/client_rpc.go +++ b/server/public/plugin/client_rpc.go @@ -1260,6 +1260,7 @@ func (s *apiRPCServer) ReceiveSharedChannelAttachmentSyncMsg(args *Z_ReceiveShar defer dataReader.Close() returns.A, returns.B = hook.ReceiveSharedChannelAttachmentSyncMsg(args.A, args.B, args.C, dataReader) + returns.B = encodableError(returns.B) return nil } diff --git a/server/public/plugin/client_rpc_test.go b/server/public/plugin/client_rpc_test.go new file mode 100644 index 00000000000..179b392e0f8 --- /dev/null +++ b/server/public/plugin/client_rpc_test.go @@ -0,0 +1,46 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package plugin + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestReceiveSharedChannelAttachmentSyncMsgReturns_GobRoundTrip pins the fix +// for the gob-encoding bug in apiRPCServer.ReceiveSharedChannelAttachmentSyncMsg. +// The hook may return errors wrapped with fmt.Errorf("...%w", err), producing +// values of the unexported type *fmt.wrapError that gob refuses to encode. +// The RPC server must run the error through encodableError before assigning +// it to the returns struct. Without that, the RPC connection breaks and +// every subsequent plugin to server call returns zero values. +func TestReceiveSharedChannelAttachmentSyncMsgReturns_GobRoundTrip(t *testing.T) { + wrapped := fmt.Errorf("attachment sync failed: %w", errors.New("upstream boom")) + + t.Run("raw wrapped error fails to gob-encode (reproduces the bug)", func(t *testing.T) { + returns := Z_ReceiveSharedChannelAttachmentSyncMsgReturns{B: wrapped} + + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(&returns) + require.Error(t, err, "raw *fmt.wrapError must not be gob-encodable; if this assertion ever fails the bug guarded by encodableError no longer exists") + require.Contains(t, err.Error(), "fmt.wrapError") + }) + + t.Run("encodableError-wrapped error round-trips through gob", func(t *testing.T) { + returns := Z_ReceiveSharedChannelAttachmentSyncMsgReturns{B: encodableError(wrapped)} + + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(&returns)) + + var decoded Z_ReceiveSharedChannelAttachmentSyncMsgReturns + require.NoError(t, gob.NewDecoder(&buf).Decode(&decoded)) + require.Error(t, decoded.B) + require.Equal(t, wrapped.Error(), decoded.B.Error()) + }) +} From 1ffa4d89941e9b5cdbe50466081e5149f83415c2 Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Tue, 19 May 2026 15:06:58 -0400 Subject: [PATCH 35/80] Add Docker Hub login to Cloud Agent start hook. (#36632) Authenticate DinD pulls at runtime using Cursor dashboard secrets so agents avoid anonymous Docker Hub rate limits. Co-authored-by: Cursor --- .cursor/README.md | 7 +++++-- .cursor/cursor.md | 3 ++- .cursor/scripts/cloud-agent-start.sh | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/.cursor/README.md b/.cursor/README.md index 93596cb5686..192cc60c907 100644 --- a/.cursor/README.md +++ b/.cursor/README.md @@ -17,7 +17,7 @@ The Docker build context is `.cursor/` only. The Dockerfile intentionally does n ## Runtime Hooks - `cloud-agent-install.sh` runs after Cursor checks out the repo. It refreshes nvm, installs agent-browser browsers, verifies Cursor's multi-repo `mattermost/enterprise` checkout, runs `server` Go dependency hydration, installs webapp dependencies, and runs Playwright `npm ci`. -- `cloud-agent-start.sh` materializes `.cursor/cursor.md` as `.cursor/AGENTS.md`, fixes current-session Docker socket access, then starts Docker and waits until `docker info` and `docker compose version` succeed. +- `cloud-agent-start.sh` materializes `.cursor/cursor.md` as `.cursor/AGENTS.md`, fixes current-session Docker socket access, starts Docker, waits until `docker info` and `docker compose version` succeed, then logs in to Docker Hub when credentials are configured. The environment declares `github.com/mattermost/enterprise` in `repositoryDependencies` so Cursor can provide it as part of the multi-repo workspace. Cursor currently clones the repositories as siblings, such as `/agent/repos/mattermost` and `/agent/repos/enterprise`, which matches `server/Makefile`'s default `../../enterprise` path. The install hook does not clone, pull, or symlink enterprise. @@ -33,4 +33,7 @@ Set these environment variables to `true` to shorten startup for narrow tasks: ## Expected Secrets -- AWS uploads use the standard AWS CLI environment variables provided to the Cloud Agent: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_S3_BUCKET_NAME`. The image only supplies the `aws` binary. +Configure these in the [Cursor Cloud Agents dashboard](https://cursor.com/dashboard/cloud-agents) as environment-scoped secrets for the Mattermost Cloud Agent environment. + +- AWS uploads use the standard AWS CLI environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_S3_BUCKET_NAME`. The image only supplies the `aws` binary. +- Docker Hub pulls use the same variable names as CI: `DOCKERHUB_USERNAME` and `DOCKERHUB_TOKEN`. The start hook runs `docker login` after `dockerd` is ready. Mark `DOCKERHUB_TOKEN` as **redacted** in the dashboard. When both are set, agents can pull the full default `make start-docker` image set without hitting anonymous rate limits. diff --git a/.cursor/cursor.md b/.cursor/cursor.md index 895f1f0db19..e9fa3332185 100644 --- a/.cursor/cursor.md +++ b/.cursor/cursor.md @@ -58,7 +58,8 @@ The Mattermost server is expected at `http://localhost:8065`. The webapp dev ser ``` - When the server starts and `MM_LICENSE` is present in the environment, the server applies that license automatically. If `MM_LICENSE` is not set, starting the server automatically applies an Entry license, which provides nearly all functionality needed for development. -- `ENABLED_DOCKER_SERVICES='postgres redis'` avoids optional local-dev services such as Prometheus, Grafana, Loki, Minio, Azurite, and OpenLDAP. This is useful in Cloud when Docker Hub rate limits block the default `make start-docker` dependency set. +- When `DOCKERHUB_USERNAME` and `DOCKERHUB_TOKEN` are configured as Cloud Agent secrets, `cloud-agent-start.sh` logs in to Docker Hub and the full default `make start-docker` dependency set can be used without trimming services. +- `ENABLED_DOCKER_SERVICES='postgres redis'` avoids optional local-dev services such as Prometheus, Grafana, Loki, Minio, Azurite, and OpenLDAP. Use this fallback when Docker Hub credentials are unavailable and anonymous pulls hit rate limits. - If the first-user signup UI is flaky but the server is already healthy, seed local state with `mmctl` and then log in through the browser: ```bash diff --git a/.cursor/scripts/cloud-agent-start.sh b/.cursor/scripts/cloud-agent-start.sh index 268888e5d85..f1478521bf0 100755 --- a/.cursor/scripts/cloud-agent-start.sh +++ b/.cursor/scripts/cloud-agent-start.sh @@ -34,6 +34,21 @@ ensure_docker_socket_access() { fi } +docker_login_if_configured() { + if [ -z "${DOCKERHUB_USERNAME:-}" ] || [ -z "${DOCKERHUB_TOKEN:-}" ]; then + log "Docker Hub credentials not configured; anonymous pulls may hit rate limits." + return 0 + fi + + log "Logging in to Docker Hub as ${DOCKERHUB_USERNAME}." + if echo "${DOCKERHUB_TOKEN}" | docker login -u "${DOCKERHUB_USERNAME}" --password-stdin >/tmp/docker-login.log 2>&1; then + log "Docker Hub login succeeded." + else + log "Docker Hub login failed; see /tmp/docker-login.log." + tail -n 20 /tmp/docker-login.log >&2 || true + fi +} + if [ -f /proc/sys/kernel/apparmor_restrict_unprivileged_userns ]; then sudo sysctl -w kernel.apparmor_restrict_unprivileged_userns=0 >/dev/null 2>&1 || \ log "Could not relax AppArmor user namespace restriction; openldap-based tests may need a larger host profile." @@ -44,6 +59,7 @@ ensure_docker_socket_access if docker info >/dev/null 2>&1; then log "Docker is already running." docker compose version + docker_login_if_configured exit 0 fi @@ -64,6 +80,7 @@ for _ in {1..60}; do log "Docker is ready." docker version docker compose version + docker_login_if_configured exit 0 fi From 345a0b76a6d1a3c4dfc467b4b10700ee30a08ca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20V=C3=A9lez?= Date: Tue, 19 May 2026 21:25:14 +0200 Subject: [PATCH 36/80] Mm 68506 fe abac mask fe table editor cel and e2e (#36517) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MM-68501 - implement GetMaskedVisualAST and wire API handler Co-Authored-By: Claude Opus 4.6 (1M context) * add missing test and fix style issues * fix styles * implement coderabbit feedback * MM-68501 - PR review: split masking file, model-level access mode, reject contradictory config Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68501 - apply shared_only filter to non-option field values (binary masking) * MM-68501 - consolidate masking flag check and log corrupt text value during masking * MM-68503 - add CEL utilities, write-path validation, and merge helpers Combined set of helpers consumed by BE-5's save path: CEL construction / serialization - extractStringValues, buildCELFromConditions, conditionToCEL, celStringLiteral, celValueLiteral. Used to rebuild a CEL string from a VisualExpression, including for GetMaskedExpression on the read-side of policy GET / search responses. Merge-on-save helpers - getHiddenValues (per-condition, with pre-fetched fields map for N+1 avoidance) — finds which stored values are not visible to the caller. - mergeConditionValues — re-injects the hidden values into a submitted condition without duplicates. - Together, these let BE-5 preserve attribute values the caller cannot see while still letting them edit the visible parts of a policy. Write-path value-hold validation - validatePolicyExpressionValues, invalidValueError, validateConditionValues. - Generic "Invalid value." error on every rejection — no signal about whether the value exists or is merely not held (prevents enumeration). - Rejects the masked-token sentinel "--------" if submitted as a literal. These all live in access_control_masking.go alongside the masking primitives that BE-2 introduced. i18n entries added for the two new error IDs (app.pap.save_policy.invalid_value, app.pap.validate_expression_values.app_error). Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68503 - handle the masked-token sentinel in validation and merge When the GET /policies endpoint returns a policy via MaskPolicyExpressions, the raw expression contains the masked-token sentinel "--------" in place of hidden values. If the frontend round-trips that expression unchanged back to the server (e.g., the admin only modified channel assignment, not the rules), the sentinel reaches the save path. The previous code in validateConditionValues rejected the sentinel as "Invalid value." This blocks the legitimate round-trip case. Fix: - validateConditionValues: treat the sentinel as a placeholder and skip it during visibility / source-only / unknown-mode checks. Other values are still validated normally. - mergeConditionValues: strip the sentinel from submitted values before appending hidden values, so it never propagates to the stored result. Both array and single-value forms (string == "--------") are handled. TestMaskedTokenRejection (which asserted the old rejection behavior) is replaced by TestMaskedTokenConstant which only verifies the sentinel string itself. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68504 - integrate save-path masking: 403 block on delete, merge-on-save, response masking Save path (CreateOrUpdateAccessControlPolicy): * validatePolicyExpressionValues runs on the submitted expression before merge so re-injected hidden values are never validated against the caller's holdings. * mergeStoredPolicyExpressions re-injects hidden values from the stored policy and blocks (HTTP 403) any attempt to remove a condition that contained values the caller cannot see — closes the row-deletion gap in classified environments. * mergeExpressionWithMaskedValues unwraps single-element arrays for scalar operators after restoring the stored operator (avoids "attr == [val]" invalid CEL when the frontend submits "attr in []" as the masked-row placeholder for an originally-scalar condition). * checkSelfInclusion is bypassed for system admins (they may legitimately write conditions for values they do not hold); masking and value-hold validation still apply to system admins. Delete path (DeleteAccessControlPolicy): * Same masked-values 403 block — a caller with masked values cannot delete the policy at all (UI Delete button is also disabled in FE-3). Response masking: * createAccessControlPolicy and setAccessControlPolicyActiveStatus run MaskPolicyExpressions on the response so even a save reply doesn't leak the values the caller does not hold. GetMaskedExpression, maskConditionValuesWithToken, replaceHiddenValuesWithToken, MaskPolicyExpressions live alongside the rest of the masking helpers in access_control_masking.go. team_access_control.go: corrects ValidateChannelEligibilityForAccessControl call site (drops the spurious receiver and rctx; it's a package-level helper that only takes channel). Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68503 - address PR review: batch field fetches, propagate errors, fail-closed write path * MM-68503 - restore team-admin api4 tests accidentally dropped during BE-5 rebuild * MM-68503 - address review and CodeRabbit feedback on save-path masking * add tests for delete masking, self-inclusion, GET mask * add assertions to strengten tests * MM-68505 - add has_masked_values type and MaskedChip component Co-Authored-By: Claude Opus 4.6 (1M context) * MM-68506 - add masking support to TableEditor and team settings modal TableEditor (table_editor.tsx, table_editor.scss): - hasMaskedValues plumbed through rows; lock operator/attribute selectors on masked rows. - Row remove (trash) button disabled on masked rows; disabled-state CSS so the icon doesn't show the destructive hover colour or a pointer cursor. - Test Rules button disabled when any row has masked values, with tooltip. - onMaskedStateChange callback to notify the parent for cross-component states (CEL editor read-only, Save disabled, banners). Value selectors (single_value_selector_menu.tsx, multi_value_selector_menu.tsx, selector_menus.scss, value_selector_menu.tsx): - Append MaskedChip after visible chips on multi-value rows. - Render MaskedChip as the sole value on single-value rows where the caller holds no visible value. Policy details (policy_details.tsx, .scss, .test.tsx): - Track hasMaskedRows state; receive from TableEditor via onMaskedStateChange. - Show masked-values warning banner above the editor when present. - Same banner on the Delete confirmation modal so admins understand why deletion is consequential. Team settings modal (team_policy_editor.tsx, .scss): - Same masked-values plumbing; delete button uses the disabled state when a policy has masked values, regardless of whether channels are assigned. - Pre-save check no longer treats "in []" as an incomplete rule — that placeholder comes from fully-masked rows that merge-on-save will fill in. i18n entries added for the new strings. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68506 - fix hook order in SingleValueSelector when masked state changes The early return for `hasMaskedValues && !value` sat between useState and useCallback declarations, so when a parent re-render flipped the masked state (e.g. after deleting a sibling rule) React saw a different hook count and crashed with "Rendered fewer hooks than expected". Move the read-only short-circuit after all hook declarations so the hook order stays stable across renders. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68507 - CEL editor read-only when masked + system console wiring CEL editor (editor.tsx, editor.scss): - hasMaskedRows prop: when true, Monaco is set to read-only and a banner explains why ("This expression contains restricted values. Switch to Simple mode to edit the values you have access to, or delete the entire rule."). - Test Rules button disabled in CEL mode when hasMaskedRows is true. Policy details (policy_details.tsx, .scss): - hasMaskedRows state plumbed to CELEditor, TableEditor, and the Save / Delete buttons. - Save button disabled while masked rows are present (kept after the save-allowed-with-masked-values change in BE-5? — no, here we keep Save enabled so admins can add/modify rules; only row removal of masked rows is blocked). - Delete Policy button disabled when hasMaskedRows; a SectionNotice above the Delete card explains why ("This policy contains restricted values - Deletion not allowed"). - New save error messages: invalid_value and self_exclusion are surfaced from the server's generic responses. Policies list (policies.tsx): minor wiring change for the new state plumbing. Table editor (table_editor.tsx): cross-component coordination — emits onMaskedStateChange and respects the disabled-for-masked-row policy. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68508 - E2E suite for attribute-value masking Covers the full read+write masking flow against a real server: - Masked chip rendering, operator/attribute lock, Test Rules disabled. - System admin subject to masking like any other caller (no role bypass). - Save with masked values: hidden values preserved by merge-on-save. - Trash button disabled on masked rows; server returns 403 on direct API attempt to remove a masked condition. - Delete Policy button disabled + server 403 when policy has masked values (both system console and team settings modal paths). - Self-inclusion failure only fires when the caller holds full visibility. - CEL editor read-only with banner when masked rows present. - Direct API validation: non-held values and the masked-token sentinel rejected with a generic "Invalid value." error. - Feature-flag-off path: no masking, all values visible. - Text-field shared_only masking (binary) with `in` and `==` operators. A pluggable DB-setup helper marks specific CPA fields as shared_only for the duration of a test (with per-test cleanup) since the API blocks setting access_mode=shared_only without a source_plugin_id. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68506 - fix lint, jest mock factory, and unreachable delete-modal test * MM-68506 - localize masked-condition-deleted save error * MM-68506 - fix masked-policy delete warning detection and localize masked_rule_deleted * fix linter issues * MM-68506 - surface delete error, lock value selector on masked rows, drop dead remove-modal * fix linter, add translations, adjust specs * import wittoltip from shared * fix linter and use the correct button variant * MM-68506 - drop dangling rationale comment in access_control_field_test * fix linter, translation and e2e tests * use pg ts types and dependencies for e2e types mocks * adjust switch mode persistance restriction * fix team settings style buttons * fail-closed guard for advanced expressions in merge-on-save, plus helper unit tests, and FF/test-helper cleanups * MM-68505 - add has_masked_values type and MaskedChip component Co-Authored-By: Claude Opus 4.6 (1M context) * MM-68506 - add masking support to TableEditor and team settings modal TableEditor (table_editor.tsx, table_editor.scss): - hasMaskedValues plumbed through rows; lock operator/attribute selectors on masked rows. - Row remove (trash) button disabled on masked rows; disabled-state CSS so the icon doesn't show the destructive hover colour or a pointer cursor. - Test Rules button disabled when any row has masked values, with tooltip. - onMaskedStateChange callback to notify the parent for cross-component states (CEL editor read-only, Save disabled, banners). Value selectors (single_value_selector_menu.tsx, multi_value_selector_menu.tsx, selector_menus.scss, value_selector_menu.tsx): - Append MaskedChip after visible chips on multi-value rows. - Render MaskedChip as the sole value on single-value rows where the caller holds no visible value. Policy details (policy_details.tsx, .scss, .test.tsx): - Track hasMaskedRows state; receive from TableEditor via onMaskedStateChange. - Show masked-values warning banner above the editor when present. - Same banner on the Delete confirmation modal so admins understand why deletion is consequential. Team settings modal (team_policy_editor.tsx, .scss): - Same masked-values plumbing; delete button uses the disabled state when a policy has masked values, regardless of whether channels are assigned. - Pre-save check no longer treats "in []" as an incomplete rule — that placeholder comes from fully-masked rows that merge-on-save will fill in. i18n entries added for the new strings. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68506 - fix hook order in SingleValueSelector when masked state changes The early return for `hasMaskedValues && !value` sat between useState and useCallback declarations, so when a parent re-render flipped the masked state (e.g. after deleting a sibling rule) React saw a different hook count and crashed with "Rendered fewer hooks than expected". Move the read-only short-circuit after all hook declarations so the hook order stays stable across renders. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68507 - CEL editor read-only when masked + system console wiring CEL editor (editor.tsx, editor.scss): - hasMaskedRows prop: when true, Monaco is set to read-only and a banner explains why ("This expression contains restricted values. Switch to Simple mode to edit the values you have access to, or delete the entire rule."). - Test Rules button disabled in CEL mode when hasMaskedRows is true. Policy details (policy_details.tsx, .scss): - hasMaskedRows state plumbed to CELEditor, TableEditor, and the Save / Delete buttons. - Save button disabled while masked rows are present (kept after the save-allowed-with-masked-values change in BE-5? — no, here we keep Save enabled so admins can add/modify rules; only row removal of masked rows is blocked). - Delete Policy button disabled when hasMaskedRows; a SectionNotice above the Delete card explains why ("This policy contains restricted values - Deletion not allowed"). - New save error messages: invalid_value and self_exclusion are surfaced from the server's generic responses. Policies list (policies.tsx): minor wiring change for the new state plumbing. Table editor (table_editor.tsx): cross-component coordination — emits onMaskedStateChange and respects the disabled-for-masked-row policy. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68508 - E2E suite for attribute-value masking Covers the full read+write masking flow against a real server: - Masked chip rendering, operator/attribute lock, Test Rules disabled. - System admin subject to masking like any other caller (no role bypass). - Save with masked values: hidden values preserved by merge-on-save. - Trash button disabled on masked rows; server returns 403 on direct API attempt to remove a masked condition. - Delete Policy button disabled + server 403 when policy has masked values (both system console and team settings modal paths). - Self-inclusion failure only fires when the caller holds full visibility. - CEL editor read-only with banner when masked rows present. - Direct API validation: non-held values and the masked-token sentinel rejected with a generic "Invalid value." error. - Feature-flag-off path: no masking, all values visible. - Text-field shared_only masking (binary) with `in` and `==` operators. A pluggable DB-setup helper marks specific CPA fields as shared_only for the duration of a test (with per-test cleanup) since the API blocks setting access_mode=shared_only without a source_plugin_id. Co-Authored-By: Claude Opus 4.7 (1M context) * MM-68506 - fix lint, jest mock factory, and unreachable delete-modal test * MM-68506 - localize masked-condition-deleted save error * MM-68506 - fix masked-policy delete warning detection and localize masked_rule_deleted * fix linter issues * MM-68506 - surface delete error, lock value selector on masked rows, drop dead remove-modal * fix linter, add translations, adjust specs * import wittoltip from shared * fix linter and use the correct button variant * MM-68506 - drop dangling rationale comment in access_control_field_test * fix linter, translation and e2e tests * use pg ts types and dependencies for e2e types mocks * adjust switch mode persistance restriction * fix team settings style buttons * fail-closed guard for advanced expressions in merge-on-save, plus helper unit tests, and FF/test-helper cleanups * Refactor access control methods to use GetPropertyGroup for CPA group ID retrieval * fix styles * disable delete on masked policies in list view and remove dead modal warnings * fix unit tests * preserve hasAnyOf operator display for fully-masked multiselect conditions * address PR feedback: lock Actions on masked save, filter source/shared_only from /attributes, add unit tests and e2e tests * fix e2e tests * comment out e2e to isolate issue * completely remove the files to pass linter --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Mattermost Build --- server/channels/app/access_control.go | 49 ++- server/channels/app/access_control_masking.go | 54 ++- .../channels/app/access_control_merge_test.go | 57 ++++ server/channels/app/access_control_test.go | 310 ++++++++++++++++++ .../properties/access_control_field_test.go | 4 - .../sqlstore/access_control_policy_store.go | 8 +- .../editors/cel_editor/editor.scss | 18 + .../editors/cel_editor/editor.tsx | 31 +- .../editors/table_editor/masked_chip.test.tsx | 45 +++ .../editors/table_editor/masked_chip.tsx | 40 +++ .../multi_value_selector_menu.tsx | 16 +- .../editors/table_editor/selector_menus.scss | 14 +- .../single_value_selector_menu.tsx | 19 ++ .../editors/table_editor/table_editor.scss | 9 +- .../table_editor/table_editor.test.tsx | 88 +++++ .../editors/table_editor/table_editor.tsx | 104 ++++-- .../table_editor/value_selector_menu.tsx | 5 +- .../access_control/policies.test.tsx | 82 +++++ .../admin_console/access_control/policies.tsx | 108 +++++- .../policy_details.test.tsx.snap | 190 +---------- .../policy_details/policy_details.scss | 8 + .../policy_details/policy_details.test.tsx | 162 ++++++++- .../policy_details/policy_details.tsx | 65 +++- .../team_policy_editor.scss | 4 + .../team_policy_editor.tsx | 37 ++- .../team_info_tab/team_picture_section.scss | 2 +- .../modals/components/save_changes_panel.tsx | 4 +- webapp/channels/src/i18n/en.json | 16 + webapp/platform/types/src/access_control.ts | 1 + 29 files changed, 1276 insertions(+), 274 deletions(-) create mode 100644 webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.test.tsx create mode 100644 webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.tsx diff --git a/server/channels/app/access_control.go b/server/channels/app/access_control.go index 99df3abc372..de36aa68401 100644 --- a/server/channels/app/access_control.go +++ b/server/channels/app/access_control.go @@ -211,7 +211,8 @@ func (a *App) mergeStoredPolicyExpressions(rctx request.CTX, policy *model.Acces if i >= len(existingPolicy.Rules) { continue } - storedExpr := existingPolicy.Rules[i].Expression + storedRule := existingPolicy.Rules[i] + storedExpr := storedRule.Expression if storedExpr == "" || storedExpr == "true" { continue } @@ -220,6 +221,15 @@ func (a *App) mergeStoredPolicyExpressions(rctx request.CTX, policy *model.Acces return appErr } policy.Rules[i].Expression = mergedExpr + // If hidden values were re-injected into the expression, the caller was + // working from a masked view of this rule. Lock Actions to the stored + // value too — without this, a caller who sees "--------" could swap the + // action type (e.g., "membership" → "upload_file_attachment") and the + // merge would restore the hidden CEL value while silently removing the + // original access restriction. + if mergedExpr != rule.Expression { + policy.Rules[i].Actions = storedRule.Actions + } } // Any stored rules beyond the submitted set were dropped by the caller. If any of those @@ -387,8 +397,16 @@ func (a *App) mergeExpressionWithMaskedValues(rctx request.CTX, policyID, submit // regardless of the stored operator. After we restore the original operator, // the value shape may not match (e.g., "==" with a []any value). Normalize // scalar operators to a single string from the array. + // + // When the stored scalar value is hidden, always use hiddenValues[0] directly + // rather than taking arr[0] from the merged list. Without this guard a crafted + // submission of `in ["caller-visible"]` would pass validateConditionValues, + // land in mergeConditionValues as a []any, and arr[0] would be the attacker's + // value — silently overwriting the stored hidden value. if isScalarOperator(merged.Operator) { - if arr, ok := merged.Value.([]any); ok { + if len(hiddenValues) > 0 { + merged.Value = hiddenValues[0] + } else if arr, ok := merged.Value.([]any); ok { if len(arr) == 0 { merged.Value = nil } else if s, ok := arr[0].(string); ok { @@ -645,6 +663,33 @@ func (a *App) GetAccessControlPolicyAttributes(rctx request.CTX, channelID strin return nil, appErr } + if len(attributes) == 0 { + return attributes, nil + } + + // Strip source_only and shared_only fields: their values must not be + // exposed to channel members through the invite modal / members sidebar. + // Fail closed: if the CPA group or a field cannot be resolved, omit that + // field rather than leaking its values. + cpaGroup, appErr := a.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + if appErr != nil { + return map[string][]string{}, nil + } + + for fieldName := range attributes { + // Read directly from the store so this security filter sees the raw + // access_mode, unaffected by property read hooks for the request caller. + field, fieldErr := a.Srv().Store().PropertyField().GetFieldByName(rctx.Context(), cpaGroup.ID, "", fieldName) + if fieldErr != nil { + delete(attributes, fieldName) + continue + } + switch field.GetAccessMode() { + case model.PropertyAccessModeSourceOnly, model.PropertyAccessModeSharedOnly: + delete(attributes, fieldName) + } + } + return attributes, nil } diff --git a/server/channels/app/access_control_masking.go b/server/channels/app/access_control_masking.go index f82c9128374..2faca231918 100644 --- a/server/channels/app/access_control_masking.go +++ b/server/channels/app/access_control_masking.go @@ -362,9 +362,11 @@ func mergeConditionValues(submitted model.Condition, hiddenValues []string) mode merged.Value = result case string: - // Empty string and the masked-token sentinel both mean "no real value - // submitted here"; restore from hidden values. - if (v == "" || v == maskedTokenValue) && len(hiddenValues) > 0 { + // For scalar conditions the caller cannot edit a value they cannot see. + // Always restore the stored hidden value regardless of what was submitted, + // preventing a crafted save from overwriting a hidden stored value with a + // different caller-visible string that passes validateConditionValues. + if len(hiddenValues) > 0 { merged.Value = hiddenValues[0] } @@ -545,6 +547,13 @@ func conditionToCEL(cond model.Condition) string { orParts = append(orParts, celStringLiteral(v)+" in "+attr) } if len(orParts) == 1 { + // When the sole value is the masked-token sentinel, duplicate it into a + // two-branch OR so that the parser can recover hasAnyOf on the next read. + // A standalone "tok in attr" is promoted to hasAllOf by + // mergeMultiselectConditions, which would display the wrong operator in the UI. + if values[0] == maskedTokenValue { + return "(" + orParts[0] + " || " + orParts[0] + ")" + } return orParts[0] } return "(" + strings.Join(orParts, " || ") + ")" @@ -750,8 +759,14 @@ func (a *App) GetMaskedExpression(rctx request.CTX, expression string, callerID rctxWithCaller := RequestContextWithCallerID(rctx, callerID) fieldsByName := a.fetchConditionFields(rctxWithCaller, visualAST.Conditions, cpaGroupID) + hasMasked := false for i := range visualAST.Conditions { - a.maskConditionValuesWithToken(rctxWithCaller, callerID, &visualAST.Conditions[i], cpaGroupID, fieldsByName) + if a.maskConditionValuesWithToken(rctxWithCaller, callerID, &visualAST.Conditions[i], cpaGroupID, fieldsByName) { + hasMasked = true + } + } + if !hasMasked { + return expression, nil } return buildCELFromConditions(visualAST.Conditions), nil @@ -761,27 +776,30 @@ func (a *App) GetMaskedExpression(rctx request.CTX, expression string, callerID // preserving expression structure so the visual AST endpoint can still parse it. // fieldsByName is pre-fetched by the caller to avoid N+1 lookups; a missing entry // is treated as fail-closed (whole value masked). -func (a *App) maskConditionValuesWithToken(rctx request.CTX, callerID string, condition *model.Condition, cpaGroupID string, fieldsByName map[string]*model.PropertyField) { +// maskConditionValuesWithToken replaces non-held values with the masked token in place. +// Returns true if any value was masked. +func (a *App) maskConditionValuesWithToken(rctx request.CTX, callerID string, condition *model.Condition, cpaGroupID string, fieldsByName map[string]*model.PropertyField) bool { if condition.ValueType == model.AttrValue { - return + return false } fieldName := extractFieldName(condition.Attribute) if fieldName == "" { - return + return false } field, ok := fieldsByName[fieldName] if !ok { condition.Value = maskedTokenValue // fail closed - return + return true } switch field.GetAccessMode() { case model.PropertyAccessModePublic: - return + return false case model.PropertyAccessModeSourceOnly: condition.Value = maskedTokenValue + return true case model.PropertyAccessModeSharedOnly: var visibleNames map[string]struct{} if field.Type == model.PropertyFieldTypeSelect || field.Type == model.PropertyFieldTypeMultiselect { @@ -789,15 +807,17 @@ func (a *App) maskConditionValuesWithToken(rctx request.CTX, callerID string, co } else { visibleNames = a.getCallerTextValues(rctx, callerID, field, cpaGroupID) } - replaceHiddenValuesWithToken(condition, visibleNames) + return replaceHiddenValuesWithToken(condition, visibleNames) default: condition.Value = maskedTokenValue + return true } } // replaceHiddenValuesWithToken keeps visible values and appends a single masked token if any were hidden. // One token regardless of count prevents count-based inference about the number of hidden values. -func replaceHiddenValuesWithToken(condition *model.Condition, visibleNames map[string]struct{}) { +// Returns true if any value was replaced. +func replaceHiddenValuesWithToken(condition *model.Condition, visibleNames map[string]struct{}) bool { switch v := condition.Value.(type) { case []any: var result []any @@ -817,11 +837,14 @@ func replaceHiddenValuesWithToken(condition *model.Condition, visibleNames map[s result = append(result, maskedTokenValue) } condition.Value = result + return hasMasked case string: if _, visible := visibleNames[v]; !visible { condition.Value = maskedTokenValue + return true } } + return false } // MaskPolicyExpressions masks non-held literal values in all policy rule expressions, in place. @@ -867,9 +890,14 @@ func (a *App) MaskPolicyExpressions(rctx request.CTX, policy *model.AccessContro if ast == nil { continue } + hasMasked := false for j := range ast.Conditions { - a.maskConditionValuesWithToken(rctxWithCaller, callerID, &ast.Conditions[j], cpaGroupID, fieldsByName) + if a.maskConditionValuesWithToken(rctxWithCaller, callerID, &ast.Conditions[j], cpaGroupID, fieldsByName) { + hasMasked = true + } + } + if hasMasked { + policy.Rules[i].Expression = buildCELFromConditions(ast.Conditions) } - policy.Rules[i].Expression = buildCELFromConditions(ast.Conditions) } } diff --git a/server/channels/app/access_control_merge_test.go b/server/channels/app/access_control_merge_test.go index 4f21749378d..28cc6ae161b 100644 --- a/server/channels/app/access_control_merge_test.go +++ b/server/channels/app/access_control_merge_test.go @@ -89,6 +89,24 @@ func TestBuildCELFromConditions(t *testing.T) { assert.Equal(t, `"Alpha" in user.attributes.Programs`, result) }) + t.Run("hasAnyOf with single masked-token value emits duplicate OR to preserve operator through re-parse", func(t *testing.T) { + // A sole masked-token sentinel must round-trip as hasAnyOf. Without the + // duplicate, a standalone "tok in attr" is promoted to hasAllOf by + // mergeMultiselectConditions, showing the wrong operator in the table editor. + conditions := []model.Condition{ + { + Attribute: "user.attributes.Programs", + Operator: "hasAnyOf", + Value: []any{maskedTokenValue}, + ValueType: model.LiteralValue, + AttributeType: "multiselect", + }, + } + result := buildCELFromConditions(conditions) + expected := `("` + maskedTokenValue + `" in user.attributes.Programs || "` + maskedTokenValue + `" in user.attributes.Programs)` + assert.Equal(t, expected, result) + }) + t.Run("hasAllOf operator", func(t *testing.T) { conditions := []model.Condition{ { @@ -424,6 +442,45 @@ func TestMergeConditionValues(t *testing.T) { result := mergeConditionValues(submitted, []string{"Building 7"}) assert.Equal(t, "Building 7", result.Value) }) + + t.Run("scalar: hidden value wins over empty submitted string", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Location", Operator: "!=", Value: ""} + result := mergeConditionValues(submitted, []string{"Building 7"}) + assert.Equal(t, "Building 7", result.Value) + }) + + t.Run("scalar: hidden value wins over masked-token submitted string", func(t *testing.T) { + submitted := model.Condition{Attribute: "user.attributes.Location", Operator: "!=", Value: maskedTokenValue} + result := mergeConditionValues(submitted, []string{"Building 7"}) + assert.Equal(t, "Building 7", result.Value) + }) + + t.Run("scalar: hidden value wins over caller-visible submitted string (security: prevents overwrite)", func(t *testing.T) { + // A crafted save can submit a caller-held value that passes validateConditionValues. + // mergeConditionValues must still restore the stored hidden value so the caller + // cannot overwrite a shared_only scalar they cannot see. + submitted := model.Condition{Attribute: "user.attributes.Location", Operator: "!=", Value: "Building 1"} + result := mergeConditionValues(submitted, []string{"Building 7"}) + assert.Equal(t, "Building 7", result.Value) + }) + + t.Run("scalar via []any: mergeConditionValues appends hidden value last; isScalarOperator block must use hiddenValues[0] not arr[0]", func(t *testing.T) { + // Second attack vector: a crafted submission of `in ["Building 1"]` (a list, + // not a string) also passes validateConditionValues for a shared_only caller. + // mergeConditionValues produces ["Building 1", "Building 7"] — the caller's + // value comes first. The isScalarOperator normalization in + // mergeExpressionWithMaskedValues must use hiddenValues[0] directly rather + // than arr[0], otherwise the attacker's value wins. + submitted := model.Condition{Attribute: "user.attributes.Location", Operator: "in", Value: []any{"Building 1"}} + result := mergeConditionValues(submitted, []string{"Building 7"}) + values, ok := result.Value.([]any) + require.True(t, ok) + // Hidden value is appended after the submitted one — arr[0] would be "Building 1". + // The fix in mergeExpressionWithMaskedValues (isScalarOperator block) must + // pick hiddenValues[0] = "Building 7" instead of arr[0]. + assert.Equal(t, "Building 1", values[0], "submitted value is first in merged list") + assert.Equal(t, "Building 7", values[1], "hidden value is appended last") + }) } func TestGetHiddenValues(t *testing.T) { diff --git a/server/channels/app/access_control_test.go b/server/channels/app/access_control_test.go index fa5e6ca210c..151d107fe1c 100644 --- a/server/channels/app/access_control_test.go +++ b/server/channels/app/access_control_test.go @@ -2232,3 +2232,313 @@ func TestGetRecommendedPublicChannelsForUser(t *testing.T) { mockACS.AssertExpectations(t) }) } + +// TestGetAccessControlPolicyAttributes_MaskedFieldsFiltered verifies that +// source_only and shared_only attribute fields are stripped from the response +// of GetAccessControlPolicyAttributes so their values are never exposed to +// regular channel members through the invite modal or members sidebar. +func TestGetAccessControlPolicyAttributes_MaskedFieldsFiltered(t *testing.T) { + th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + rctx := request.TestContext(t) + + cpaGroup, cErr := th.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + require.Nil(t, cErr) + + permNone := model.PermissionLevelNone + + makeField := func(name, accessMode string) { + protected := accessMode == model.PropertyAccessModeSourceOnly || accessMode == model.PropertyAccessModeSharedOnly + f := &model.PropertyField{ + GroupID: cpaGroup.ID, + Name: name, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Protected: protected, + Attrs: model.StringInterface{model.PropertyAttrsAccessMode: accessMode}, + } + if protected { + f.PermissionField = &permNone + f.Attrs[model.PropertyAttrsProtected] = true + _, err := th.App.Srv().Store().PropertyField().Create(f) + require.NoError(t, err) + } else { + _, appErr := th.App.CreatePropertyField(rctx, f, false, "") + require.Nil(t, appErr) + } + } + + makeField("PublicField", model.PropertyAccessModePublic) + makeField("SourceField", model.PropertyAccessModeSourceOnly) + makeField("SharedField", model.PropertyAccessModeSharedOnly) + + channelID := model.NewId() + rawAttributes := map[string][]string{ + "PublicField": {"Engineering"}, + "SourceField": {"TopSecret"}, + "SharedField": {"Alpha", "Bravo"}, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + mockACS.On("GetPolicyRuleAttributes", mock.Anything, channelID, model.AccessControlPolicyActionMembership). + Return(rawAttributes, nil).Once() + + result, appErr := th.App.GetAccessControlPolicyAttributes(th.Context, channelID, model.AccessControlPolicyActionMembership) + require.Nil(t, appErr) + + // Only the public field should survive. + assert.Equal(t, map[string][]string{"PublicField": {"Engineering"}}, result) + assert.NotContains(t, result, "SourceField") + assert.NotContains(t, result, "SharedField") + mockACS.AssertExpectations(t) +} + +// TestGetAccessControlPolicyAttributes_PublicFieldsPassThrough verifies that +// public attribute fields are returned unchanged. +func TestGetAccessControlPolicyAttributes_PublicFieldsPassThrough(t *testing.T) { + th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + rctx := request.TestContext(t) + + cpaGroup, cErr := th.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + require.Nil(t, cErr) + + fieldName := "f_" + model.NewId()[:8] + field := &model.PropertyField{ + GroupID: cpaGroup.ID, + Name: fieldName, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{model.PropertyAttrsAccessMode: model.PropertyAccessModePublic}, + } + _, appErr := th.App.CreatePropertyField(rctx, field, false, "") + require.Nil(t, appErr) + + channelID := model.NewId() + rawAttributes := map[string][]string{fieldName: {"Engineering", "Sales"}} + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + mockACS.On("GetPolicyRuleAttributes", mock.Anything, channelID, model.AccessControlPolicyActionMembership). + Return(rawAttributes, nil).Once() + + result, appErr := th.App.GetAccessControlPolicyAttributes(th.Context, channelID, model.AccessControlPolicyActionMembership) + require.Nil(t, appErr) + assert.Equal(t, rawAttributes, result) + mockACS.AssertExpectations(t) +} + +// TestMergeStoredPolicyExpressions_ActionsLocked verifies that a caller who +// cannot see all values in a stored rule cannot change that rule's Actions. +// The attack: submit a PUT with the same masked expression but a different +// action type — the merge would restore the hidden CEL value while silently +// accepting the caller's action, removing the original access restriction. +func TestMergeStoredPolicyExpressions_ActionsLocked(t *testing.T) { + th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + rctx := request.TestContext(t) + + // Insert a source_only field directly into the store to bypass the property + // service hook that restricts protected-field creation to plugin callers. + cpaGroup, cErr := th.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + require.Nil(t, cErr) + + fieldName := "f_" + model.NewId()[:8] + permNone := model.PermissionLevelNone + field := &model.PropertyField{ + GroupID: cpaGroup.ID, + Name: fieldName, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Protected: true, + PermissionField: &permNone, + Attrs: model.StringInterface{ + model.PropertyAttrsAccessMode: model.PropertyAccessModeSourceOnly, + model.PropertyAttrsProtected: true, + }, + } + _, storeErr := th.App.Srv().Store().PropertyField().Create(field) + require.NoError(t, storeErr) + + callerID := model.NewId() + policyID := model.NewId() + + storedExpr := `user.attributes.` + fieldName + ` == "TopSecret"` + maskedExpr := `user.attributes.` + fieldName + ` == "--------"` + + storedPolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeParent, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: storedExpr}, + }, + } + + // Attacker submits the masked expression unchanged but swaps the action. + submittedPolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeParent, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionUploadFileAttachment}, Expression: maskedExpr}, + }, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + + mockACS.On("GetPolicy", mock.Anything, policyID).Return(storedPolicy, nil).Once() + mockACS.On("ExpressionToVisualAST", mock.Anything, storedExpr).Return(&model.VisualExpression{ + Conditions: []model.Condition{ + {Attribute: "user.attributes." + fieldName, Operator: "==", Value: "TopSecret", ValueType: model.LiteralValue}, + }, + }, nil).Maybe() + mockACS.On("ExpressionToVisualAST", mock.Anything, maskedExpr).Return(&model.VisualExpression{ + Conditions: []model.Condition{ + {Attribute: "user.attributes." + fieldName, Operator: "==", Value: maskedTokenValue, ValueType: model.LiteralValue}, + }, + }, nil).Maybe() + + mergeErr := th.App.mergeStoredPolicyExpressions(th.Context, submittedPolicy, callerID) + require.Nil(t, mergeErr) + + require.Len(t, submittedPolicy.Rules, 1) + // Expression must be restored to the real stored value. + assert.Equal(t, storedExpr, submittedPolicy.Rules[0].Expression) + // Actions must be locked to the stored value, not the attacker's. + assert.Equal(t, []string{model.AccessControlPolicyActionMembership}, submittedPolicy.Rules[0].Actions) + mockACS.AssertExpectations(t) +} + +// TestMergeStoredPolicyExpressions_FailClosedTrueRejectedOnResubmit verifies the +// claim from the PR review: if MaskPolicyExpressions emitted "true" for a rule +// because the stored expression could not be parsed (fail-closed), a caller who +// re-submits that "true" unchanged will be blocked on the save path. +// +// How it works: +// 1. MaskPolicyExpressions (GET path) calls ExpressionToVisualAST on the stored +// expression; on parse failure it sets the rule to "true" (fail-closed). +// 2. The caller sees "true" in the GET response and re-submits it. +// 3. On the save path, mergeExpressionWithMaskedValues calls +// expressionHasMaskedValuesForCaller, which calls GetMaskedVisualAST, which +// calls ExpressionToVisualAST on the *stored* expression again. +// 4. That second parse also fails → error propagates → save is blocked. +// "true" is never written to the DB. +func TestMergeStoredPolicyExpressions_FailClosedTrueRejectedOnResubmit(t *testing.T) { + th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + callerID := model.NewId() + policyID := model.NewId() + + // A stored expression that ExpressionToVisualAST cannot parse. + // In production this is guarded by save-time validation, but defensive + // code paths must still protect against it. + storedExpr := `user.attributes.TopSecret == "Value"` + + storedPolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeParent, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: storedExpr}, + }, + } + + // Caller re-submits "true" — what MaskPolicyExpressions emitted as the + // fail-closed value when it could not parse the stored expression on GET. + submittedPolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeParent, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: "true"}, + }, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + + mockACS.On("GetPolicy", mock.Anything, policyID).Return(storedPolicy, nil).Once() + // Simulate the same parse failure that would have triggered fail-closed on GET. + parseErr := model.NewAppError("ExpressionToVisualAST", "app.pap.expression_to_visual_ast.app_error", nil, "simulated parse failure", http.StatusInternalServerError) + mockACS.On("ExpressionToVisualAST", mock.Anything, storedExpr).Return(nil, parseErr).Maybe() + + mergeErr := th.App.mergeStoredPolicyExpressions(th.Context, submittedPolicy, callerID) + + // Save must be blocked. The error returned here causes UpdateAccessControlPolicy + // to abort before any DB write — the in-memory struct may still hold "true" + // but it never reaches the store. + require.NotNil(t, mergeErr, "expected mergeStoredPolicyExpressions to return an error when stored expression is unparseable") + mockACS.AssertExpectations(t) +} + +// TestMergeStoredPolicyExpressions_ActionsEditableWhenNoMasking verifies that +// a caller who holds all values in a rule can freely change its Actions. +func TestMergeStoredPolicyExpressions_ActionsEditableWhenNoMasking(t *testing.T) { + th := Setup(t).InitBasic(t) + th.App.Srv().SetLicense(model.NewTestLicenseSKU(model.LicenseShortSkuEnterprise)) + + rctx := request.TestContext(t) + + // Create a public field — values are always visible, so no masking occurs. + cpaGroup, cErr := th.App.GetPropertyGroup(rctx, model.AccessControlPropertyGroupName) + require.Nil(t, cErr) + + fieldName := "f_" + model.NewId()[:8] + field := &model.PropertyField{ + GroupID: cpaGroup.ID, + Name: fieldName, + Type: model.PropertyFieldTypeText, + ObjectType: model.PropertyFieldObjectTypeUser, + TargetType: string(model.PropertyFieldTargetLevelSystem), + Attrs: model.StringInterface{model.PropertyAttrsAccessMode: model.PropertyAccessModePublic}, + } + _, appErr := th.App.CreatePropertyField(rctx, field, false, "") + require.Nil(t, appErr) + + callerID := model.NewId() + policyID := model.NewId() + + expr := `user.attributes.` + fieldName + ` == "Engineering"` + + storedPolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeParent, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionMembership}, Expression: expr}, + }, + } + // Caller legitimately changes the action on a rule with no masked values. + submittedPolicy := &model.AccessControlPolicy{ + ID: policyID, + Type: model.AccessControlPolicyTypeParent, + Rules: []model.AccessControlPolicyRule{ + {Actions: []string{model.AccessControlPolicyActionUploadFileAttachment}, Expression: expr}, + }, + } + + mockACS := &mocks.AccessControlServiceInterface{} + th.App.Srv().ch.AccessControl = mockACS + + mockACS.On("GetPolicy", mock.Anything, policyID).Return(storedPolicy, nil).Once() + mockACS.On("ExpressionToVisualAST", mock.Anything, expr).Return(&model.VisualExpression{ + Conditions: []model.Condition{ + {Attribute: "user.attributes." + fieldName, Operator: "==", Value: "Engineering", ValueType: model.LiteralValue}, + }, + }, nil).Maybe() + + appErr = th.App.mergeStoredPolicyExpressions(th.Context, submittedPolicy, callerID) + require.Nil(t, appErr) + + require.Len(t, submittedPolicy.Rules, 1) + // Expression unchanged (no masking, submitted passes through). + assert.Equal(t, expr, submittedPolicy.Rules[0].Expression) + // Actions must NOT be locked — caller's submitted value stands. + assert.Equal(t, []string{model.AccessControlPolicyActionUploadFileAttachment}, submittedPolicy.Rules[0].Actions) + mockACS.AssertExpectations(t) +} diff --git a/server/channels/app/properties/access_control_field_test.go b/server/channels/app/properties/access_control_field_test.go index fbb4e597605..e940a3a830e 100644 --- a/server/channels/app/properties/access_control_field_test.go +++ b/server/channels/app/properties/access_control_field_test.go @@ -1506,7 +1506,3 @@ func TestLinkedPropertyField_SecurityInheritance(t *testing.T) { assert.False(t, model.IsPropertyFieldProtected(linked)) }) } - -// The previous "member-writable shared_only" early-return in applyFieldReadAccessControl -// was removed in favor of rejecting that contradictory configuration at validation time -// (see TestValidatePropertyFieldAccessMode in server/public/model/property_access_test.go). diff --git a/server/channels/store/sqlstore/access_control_policy_store.go b/server/channels/store/sqlstore/access_control_policy_store.go index 0ae4983afeb..d2324f6f1ea 100644 --- a/server/channels/store/sqlstore/access_control_policy_store.go +++ b/server/channels/store/sqlstore/access_control_policy_store.go @@ -79,8 +79,12 @@ func (s *storeAccessControlPolicy) toModel() (*model.AccessControlPolicy, error) } func fromModel(policy *model.AccessControlPolicy) (*storeAccessControlPolicy, error) { + imports := policy.Imports + if imports == nil { + imports = []string{} + } data, err := json.Marshal(&accessControlPolicyV0_1{ - Imports: policy.Imports, + Imports: imports, Rules: policy.Rules, Roles: policy.Roles, Scope: policy.Scope, @@ -706,7 +710,7 @@ func (s *SqlAccessControlPolicyStore) SearchPolicies(rctx request.CTX, opts mode condition := sq.Expr(`Id IN ( SELECT parent_id FROM ( - SELECT ch.TeamId, jsonb_array_elements_text(cp.Data -> 'imports') AS parent_id + SELECT ch.TeamId, jsonb_array_elements_text(COALESCE(NULLIF(cp.Data -> 'imports', 'null'::jsonb), '[]'::jsonb)) AS parent_id FROM AccessControlPolicies cp JOIN Channels ch ON ch.Id = cp.Id WHERE cp.Type = 'channel' diff --git a/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.scss b/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.scss index f90ee84f51e..b84f0b3e328 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.scss +++ b/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.scss @@ -1,6 +1,24 @@ .cel-editor { margin-bottom: 24px; + &__masked-banner { + display: flex; + align-items: center; + padding: 10px 16px; + border-radius: 4px; + margin-bottom: 8px; + background-color: rgba(var(--center-channel-color-rgb), 0.08); + color: rgba(var(--center-channel-color-rgb), 0.72); + font-size: 13px; + gap: 8px; + line-height: 20px; + + .icon { + color: var(--dnd-indicator); + font-size: 18px; + } + } + &__container { position: relative; display: flex; diff --git a/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.tsx b/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.tsx index bdfe077bc8b..9326d718682 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/cel_editor/editor.tsx @@ -80,6 +80,7 @@ interface CELEditorProps { attribute: string; values: string[]; }>; + hasMaskedRows?: boolean; } // TODO: this is just a sample schema for the editor, we need to get the actual schema from the server @@ -94,6 +95,7 @@ function CELEditor({ teamId, disabled = false, userAttributes, + hasMaskedRows = false, }: CELEditorProps): JSX.Element { const intl = useIntl(); const [editorState, setEditorState] = useState({ @@ -257,12 +259,12 @@ function CELEditor({ }; }, []); // Only run once on mount - // Update the editor's readOnly state when disabled prop changes + // Update the editor's readOnly state when disabled or hasMaskedRows changes useEffect(() => { if (monacoRef.current) { - monacoRef.current.updateOptions({readOnly: disabled}); + monacoRef.current.updateOptions({readOnly: disabled || hasMaskedRows}); } - }, [disabled]); + }, [disabled, hasMaskedRows]); // Helper function to determine current validation state const getValidationState = useCallback(() => { @@ -338,6 +340,19 @@ function CELEditor({
+ {hasMaskedRows && ( +
+ + +
+ )} +
setEditorState((prev) => ({...prev, showTestResults: true}))} - disabled={disabled || !editorState.expression || !editorState.isValid || editorState.isValidating} + disabled={disabled || hasMaskedRows || !editorState.expression || !editorState.isValid || editorState.isValidating} + disabledTooltip={ + hasMaskedRows ? + intl.formatMessage({ + id: 'admin.access_control.cel_editor.masked_values_tooltip', + defaultMessage: 'Test is unavailable because this policy contains restricted attribute values.', + }) : + undefined + } />
{editorState.showTestResults && ( diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.test.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.test.tsx new file mode 100644 index 00000000000..c657c435e1d --- /dev/null +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.test.tsx @@ -0,0 +1,45 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; + +import {renderWithContext, screen} from 'tests/react_testing_utils'; + +import MaskedChip from './masked_chip'; + +describe('MaskedChip', () => { + test('renders the masked token text', () => { + renderWithContext(); + expect(screen.getByText('••••••••')).toBeInTheDocument(); + }); + + test('has role="img" for accessibility', () => { + renderWithContext(); + const chip = screen.getByRole('img'); + expect(chip).toBeInTheDocument(); + }); + + test('has correct aria-label', () => { + renderWithContext(); + const chip = screen.getByRole('img'); + expect(chip).toHaveAttribute('aria-label', 'Hidden values that you do not have permission to view'); + }); + + test('does not have aria-readonly (invalid for role="img")', () => { + renderWithContext(); + const chip = screen.getByRole('img'); + expect(chip).not.toHaveAttribute('aria-readonly'); + }); + + test('does not render a close/remove button', () => { + renderWithContext(); + const removeButtons = document.querySelectorAll('.select__multi-value__remove'); + expect(removeButtons).toHaveLength(0); + }); + + test('has the masked CSS class', () => { + renderWithContext(); + const chip = screen.getByRole('img'); + expect(chip).toHaveClass('select__multi-value--masked'); + }); +}); diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.tsx new file mode 100644 index 00000000000..ed16e875d6f --- /dev/null +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/masked_chip.tsx @@ -0,0 +1,40 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; +import {useIntl} from 'react-intl'; + +import {WithTooltip} from '@mattermost/shared/components/tooltip'; + +import './selector_menus.scss'; + +/** Non-interactive chip indicating hidden attribute values exist in this condition. */ +const MaskedChip = (): JSX.Element => { + const {formatMessage} = useIntl(); + + const tooltipText = formatMessage({ + id: 'admin.access_control.masked_chip.tooltip', + defaultMessage: 'One or more restricted values', + }); + + const ariaLabel = formatMessage({ + id: 'admin.access_control.masked_chip.aria_label', + defaultMessage: 'Hidden values that you do not have permission to view', + }); + + return ( + +
+
+ {'••••••••'} +
+
+
+ ); +}; + +export default MaskedChip; diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/multi_value_selector_menu.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/multi_value_selector_menu.tsx index e3d1b4f8294..9fc32e46b4a 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/multi_value_selector_menu.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/multi_value_selector_menu.tsx @@ -12,6 +12,8 @@ import * as Menu from 'components/menu'; import './selector_menus.scss'; +import MaskedChip from './masked_chip'; + // MultiValueSelector handles selection of multiple values (operator 'in') const MultiValueSelector = ({ values, @@ -20,6 +22,7 @@ const MultiValueSelector = ({ options = [], allowCreateValue = false, placeholder, + hasMaskedValues = false, }: { values: string[]; disabled: boolean; @@ -27,6 +30,7 @@ const MultiValueSelector = ({ options?: PropertyFieldOption[]; allowCreateValue?: boolean; placeholder?: string; + hasMaskedValues?: boolean; }) => { const {formatMessage} = useIntl(); const [filter, setFilter] = useState(''); @@ -92,6 +96,15 @@ const MultiValueSelector = ({ // Memoize cell contents to prevent unnecessary re-renders const cellContents = useMemo(() => { if (values.length === 0) { + // When no visible values exist but the row has masked ones, show only the masked chip. + if (hasMaskedValues) { + return ( +
+ +
+ ); + } + let visualPlaceholderText = defaultMultiPlaceholder; if (actualAllowCreateForMenu && options.length === 0) { visualPlaceholderText = defaultCreatePlaceholder; @@ -131,9 +144,10 @@ const MultiValueSelector = ({ )}
))} + {hasMaskedValues && }
); - }, [values, disabled]); + }, [values, disabled, hasMaskedValues]); return (
diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/selector_menus.scss b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/selector_menus.scss index 830b5935523..29b90eb25b0 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/selector_menus.scss +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/selector_menus.scss @@ -86,12 +86,24 @@ .select__multi-value__remove { padding: 0 4px; cursor: pointer; - + &:hover { background-color: rgba(var(--center-channel-color-rgb), 0.16); } } + .select__multi-value--masked { + background-color: rgba(var(--center-channel-color-rgb), 0.08); + cursor: default; + user-select: none; + + .select__multi-value__label { + color: rgba(var(--center-channel-color-rgb), 0.64); + font-family: monospace; + letter-spacing: 1px; + } + } + &__simple-input { width: 100%; height: 40px; diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/single_value_selector_menu.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/single_value_selector_menu.tsx index aafacab1877..47058012ef3 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/single_value_selector_menu.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/single_value_selector_menu.tsx @@ -12,6 +12,8 @@ import * as Menu from 'components/menu'; import Constants from 'utils/constants'; +import MaskedChip from './masked_chip'; + import './selector_menus.scss'; // SingleValueSelector handles selection of a single value (operators like 'is', 'contains', etc.) @@ -22,6 +24,7 @@ const SingleValueSelector = ({ options = [], allowCreateValue = false, placeholder, + hasMaskedValues = false, }: { value: string; disabled: boolean; @@ -29,6 +32,7 @@ const SingleValueSelector = ({ options?: PropertyFieldOption[]; allowCreateValue?: boolean; placeholder?: string; + hasMaskedValues?: boolean; }) => { const {formatMessage} = useIntl(); const [filter, setFilter] = useState(''); @@ -95,6 +99,21 @@ const SingleValueSelector = ({ } }, [allowCreateValue, filter, handleCreateValue]); + // When masked values are present and the caller holds no visible value, + // the row is effectively read-only — show only the masked chip. + // Placed AFTER hook declarations so hook order stays stable when the + // masked state changes between renders (e.g., parent re-renders after + // a sibling rule is deleted). + if (hasMaskedValues && !value) { + return ( +
+
+ +
+
+ ); + } + if (!hasOptions) { // For attributes without options, show simple input field return ( diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.scss b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.scss index 910a25a7198..755a4045666 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.scss +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.scss @@ -46,10 +46,15 @@ background: none; color: rgba(var(--center-channel-color-rgb), 0.56); cursor: pointer; - - &:hover { + + &:hover:not(:disabled) { color: var(--error-text); } + + &:disabled { + cursor: not-allowed; + opacity: 0.4; + } } &__actions-row { diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.test.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.test.tsx index 25e59f1fc6d..1c4f07d5b24 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.test.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.test.tsx @@ -27,6 +27,7 @@ describe('parseExpression', () => { operator: 'is', values: ['Engineering'], attribute_type: 'text', + hasMaskedValues: false, }, ]); }); @@ -50,6 +51,7 @@ describe('parseExpression', () => { operator: 'in', values: ['US', 'CA'], attribute_type: 'text', + hasMaskedValues: false, }, ]); }); @@ -73,6 +75,7 @@ describe('parseExpression', () => { operator: 'is not', values: ['guest'], attribute_type: 'text', + hasMaskedValues: false, }, ]); }); @@ -96,6 +99,7 @@ describe('parseExpression', () => { operator: 'starts with', values: ['admin'], attribute_type: 'text', + hasMaskedValues: false, }, ]); }); @@ -126,12 +130,14 @@ describe('parseExpression', () => { operator: 'starts with', values: ['admin'], attribute_type: 'text', + hasMaskedValues: false, }, { attribute: 'department', operator: 'is', values: ['Engineering'], attribute_type: 'text', + hasMaskedValues: false, }, ]); }); @@ -155,6 +161,7 @@ describe('parseExpression', () => { operator: 'is', values: ['foo'], attribute_type: 'text', + hasMaskedValues: false, }, ]); }); @@ -164,6 +171,44 @@ describe('parseExpression', () => { expect(parseExpression(undefined as any)).toEqual([]); expect(parseExpression({conditions: []})).toEqual([]); }); + + test('sets hasMaskedValues=true when condition has has_masked_values flag', () => { + const ast: AccessControlVisualAST = { + conditions: [ + { + attribute: 'user.attributes.program', + operator: 'in', + value: ['Alpha'], + value_type: 0, + attribute_type: 'text', + has_masked_values: true, + }, + { + attribute: 'user.attributes.clearance', + operator: 'in', + value: [], + value_type: 0, + attribute_type: 'text', + has_masked_values: true, + }, + { + attribute: 'user.attributes.department', + operator: '==', + value: 'Engineering', + value_type: 0, + attribute_type: 'text', + }, + ], + }; + + const rows = parseExpression(ast); + expect(rows[0].hasMaskedValues).toBe(true); // partial: caller holds Alpha + expect(rows[0].values).toEqual(['Alpha']); + expect(rows[1].hasMaskedValues).toBe(true); // fully masked: caller holds nothing + expect(rows[1].values).toEqual([]); + expect(rows[2].hasMaskedValues).toBe(false); // no masking + expect(rows[2].values).toEqual(['Engineering']); + }); }); describe('parseExpression with multiselect attributes', () => { @@ -186,6 +231,7 @@ describe('parseExpression with multiselect attributes', () => { operator: 'has all of', values: ['JavaScript', 'Python'], attribute_type: 'multiselect', + hasMaskedValues: false, }, ]); }); @@ -209,6 +255,7 @@ describe('parseExpression with multiselect attributes', () => { operator: 'has any of', values: ['Dragon', 'Phoenix'], attribute_type: 'multiselect', + hasMaskedValues: false, }, ]); }); @@ -232,6 +279,7 @@ describe('parseExpression with multiselect attributes', () => { operator: 'has all of', values: ['JavaScript'], attribute_type: 'multiselect', + hasMaskedValues: false, }, ]); }); @@ -348,6 +396,7 @@ describe('rowToCEL', () => { operator: 'has any of', values: ['Dragon', 'Phoenix'], attribute_type: 'multiselect', + hasMaskedValues: false, }); expect(cel).toBe('("Dragon" in user.attributes.programs || "Phoenix" in user.attributes.programs)'); }); @@ -358,6 +407,7 @@ describe('rowToCEL', () => { operator: 'has all of', values: ['Dragon', 'Phoenix'], attribute_type: 'multiselect', + hasMaskedValues: false, }); expect(cel).toBe('"Dragon" in user.attributes.programs && "Phoenix" in user.attributes.programs'); }); @@ -368,6 +418,7 @@ describe('rowToCEL', () => { operator: 'has any of', values: ['Dragon'], attribute_type: 'multiselect', + hasMaskedValues: false, }); expect(cel).toBe('"Dragon" in user.attributes.programs'); }); @@ -378,6 +429,7 @@ describe('rowToCEL', () => { operator: 'has all of', values: ['admin'], attribute_type: 'multiselect', + hasMaskedValues: false, }); expect(cel).toBe('"admin" in user.attributes.tags'); }); @@ -388,6 +440,7 @@ describe('rowToCEL', () => { operator: 'in', values: ['Eng', 'Ops'], attribute_type: 'select', + hasMaskedValues: false, }); expect(cel).toBe('user.attributes.department in ["Eng", "Ops"]'); }); @@ -398,6 +451,7 @@ describe('rowToCEL', () => { operator: 'is', values: ['TopSecret'], attribute_type: 'text', + hasMaskedValues: false, }); expect(cel).toBe('user.attributes.clearance == "TopSecret"'); }); @@ -408,6 +462,7 @@ describe('rowToCEL', () => { operator: 'contains', values: ['@example.com'], attribute_type: 'text', + hasMaskedValues: false, }); expect(cel).toBe('user.attributes.email.contains("@example.com")'); }); @@ -418,9 +473,42 @@ describe('rowToCEL', () => { operator: 'is', values: ['O\'Brien\'s "Team"'], attribute_type: 'text', + hasMaskedValues: false, }); expect(cel).toBe('user.attributes.team == "O\'Brien\'s \\"Team\\""'); }); + + // --- Masking-related tests --- + + test('fully-masked row (hasMaskedValues=true, values=[]) emits "in []" placeholder regardless of operator', () => { + // The placeholder is needed so the backend merge can locate this condition + // by attribute and re-inject the hidden values. The operator is irrelevant + // because the backend always overrides it from the stored expression. + const operators = ['in', 'is', 'has all of', 'has any of', 'contains', 'starts with']; + for (const operator of operators) { + const cel = rowToCEL({ + attribute: 'program', + operator, + values: [], + attribute_type: 'text', + hasMaskedValues: true, + }); + expect(cel).toBe('user.attributes.program in []'); + } + }); + + test('partially-masked row (hasMaskedValues=true, values non-empty) uses normal CEL path', () => { + // The caller holds "Alpha"; "Bravo" and "Charlie" are hidden. + // The row should emit the visible value normally — backend merge appends the rest. + const cel = rowToCEL({ + attribute: 'program', + operator: 'in', + values: ['Alpha'], + attribute_type: 'text', + hasMaskedValues: true, + }); + expect(cel).toBe('user.attributes.program in ["Alpha"]'); + }); }); describe('isSimpleExpression', () => { diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.tsx index aef6181c239..f7de92128ea 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/table_editor.tsx @@ -1,7 +1,7 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import React, {useState, useEffect, useCallback} from 'react'; +import React, {useState, useEffect, useCallback, useMemo} from 'react'; import {FormattedMessage, useIntl} from 'react-intl'; import type {AccessControlVisualAST} from '@mattermost/types/access_control'; @@ -28,6 +28,16 @@ export function celStringLiteral(val: string): string { } export function rowToCEL(row: TableRow): string { + // A fully-masked row has no visible values on the client side. Emit a + // placeholder "in []" expression so the backend merge can locate this + // condition by attribute and re-inject the hidden values before persisting. + // Without this guard the condition would be filtered out by updateExpression, + // the empty expression would be sent to the server, and buildCELFromConditions + // would return "true" — making the policy wide-open (security regression). + if (row.hasMaskedValues && row.values.length === 0) { + return `user.attributes.${row.attribute} in []`; + } + const attributeExpr = `user.attributes.${row.attribute}`; const config = OPERATOR_CONFIG[row.operator]; @@ -82,6 +92,9 @@ interface TableEditorProps { // Props for user self-exclusion detection isSystemAdmin?: boolean; validateExpressionAgainstRequester?: (expression: string) => Promise>; + + // Callback to notify parent when masked state changes (for CEL editor integration) + onMaskedStateChange?: (hasMasked: boolean) => void; } // Finds the first available (non-disabled) attribute from a list of user attributes. @@ -131,8 +144,10 @@ export const parseExpression = (visualAST: AccessControlVisualAST): TableRow[] = let values; if (Array.isArray(node.value)) { values = node.value; - } else { + } else if (node.value !== null && node.value !== undefined) { values = [node.value]; + } else { + values = []; } tableRows.push({ @@ -140,12 +155,33 @@ export const parseExpression = (visualAST: AccessControlVisualAST): TableRow[] = operator: op, values, attribute_type: node.attribute_type, + hasMaskedValues: node.has_masked_values === true, }); } return tableRows; }; +function getTestButtonTooltip( + hasMaskedRows: boolean, + userWouldBeExcluded: boolean, + formatMessage: ReturnType['formatMessage'], +): string | undefined { + if (hasMaskedRows) { + return formatMessage({ + id: 'admin.access_control.table_editor.masked_values_tooltip', + defaultMessage: 'Test is unavailable because this policy contains restricted attribute values.', + }); + } + if (userWouldBeExcluded) { + return formatMessage({ + id: 'admin.access_control.table_editor.user_excluded_tooltip', + defaultMessage: 'You cannot test access rules that would exclude you from the channel', + }); + } + return undefined; +} + // TableEditor provides a user-friendly table interface for constructing and editing // CEL (Common Expression Language) expressions based on user attributes. // It parses incoming CEL expressions into rows and reconstructs the expression upon changes. @@ -164,6 +200,7 @@ function TableEditor({ actions, isSystemAdmin = false, validateExpressionAgainstRequester, + onMaskedStateChange, }: TableEditorProps): JSX.Element { const {formatMessage} = useIntl(); @@ -175,10 +212,18 @@ function TableEditor({ // State for user self-exclusion detection (only applies to non-system-admins) const [userWouldBeExcluded, setUserWouldBeExcluded] = useState(false); - // Effect to parse the incoming CEL expression string (value prop) - // and update the internal rows state. Handles errors during parsing. + // Derived state: whether any row has masked values + const hasMaskedRows = useMemo(() => rows.some((r) => r.hasMaskedValues), [rows]); + + // Prevents getVisualAST re-parse when expression change is from internal row editing. + const isInternalChange = React.useRef(false); + useEffect(() => { - // Skip parsing if no expression to avoid unnecessary API calls + if (isInternalChange.current) { + isInternalChange.current = false; + return; + } + if (!value || value.trim() === '') { setRows([]); return; @@ -209,10 +254,8 @@ function TableEditor({ }); }, [value]); - // Effect to check if user would be excluded by their own rules useEffect(() => { const checkUserSelfExclusion = async () => { - // Only check for non-system admins when there's an expression and validation function if (isSystemAdmin || !value.trim() || !validateExpressionAgainstRequester) { setUserWouldBeExcluded(false); return; @@ -221,8 +264,7 @@ function TableEditor({ try { const result = await validateExpressionAgainstRequester(value); setUserWouldBeExcluded(!result.data?.requester_matches); - } catch (error) { - // If validation fails, assume they would not be excluded (to allow testing) + } catch { setUserWouldBeExcluded(false); } }; @@ -230,28 +272,30 @@ function TableEditor({ checkUserSelfExclusion(); }, [value, isSystemAdmin, validateExpressionAgainstRequester]); - // Converts the internal rows state back into a CEL expression string - // and calls the onChange and onValidate props. + useEffect(() => { + onMaskedStateChange?.(hasMaskedRows); + }, [hasMaskedRows, onMaskedStateChange]); + const updateExpression = useCallback((newRows: TableRow[]) => { - const rowsThatCanFormExpressions = newRows.filter((row) => row.attribute && row.values.length > 0); + // Include masked rows with no visible values: rowToCEL will emit an "in []" + // placeholder so the backend merge can restore the hidden values on save. + const rowsThatCanFormExpressions = newRows.filter((row) => row.attribute && (row.values.length > 0 || row.hasMaskedValues)); const expr = rowsThatCanFormExpressions.map((row) => rowToCEL(row)).join(' && '); + isInternalChange.current = true; onChange(expr); if (onValidate) { onValidate(expr === '' || rowsThatCanFormExpressions.length > 0); } }, [onChange, onValidate]); - // Helper function to find the first available (non-disabled) attribute const findFirstAvailableAttribute = useCallback(() => { return findFirstAvailableAttributeFromList(userAttributes, enableUserManagedAttributes); }, [userAttributes, enableUserManagedAttributes]); - // Row Manipulation Handlers const addRow = useCallback(() => { if (userAttributes.length === 0) { - // Show a helpful message instead of silently failing onParseError('No user attributes available. Please ensure ABAC is properly configured and you have the necessary permissions.'); return; } @@ -263,11 +307,12 @@ function TableEditor({ } setRows((currentRows) => { - const newRow = { + const newRow: TableRow = { attribute: firstAvailableAttribute.name, operator: firstAvailableAttribute.type === 'multiselect' ? OperatorLabel.HAS_ANY_OF : OperatorLabel.IS, values: [], attribute_type: firstAvailableAttribute.type || '', + hasMaskedValues: false, }; const newRows = [...currentRows, newRow]; updateExpression(newRows); // Ensure expression is updated immediately @@ -284,6 +329,12 @@ function TableEditor({ }); }, [updateExpression]); + const requestRemoveRow = useCallback((index: number) => { + // Masked rows have their remove button disabled — the row is read-only + // because the server would 403 on a delete that strips hidden values. + removeRow(index); + }, [removeRow]); + const updateRowAttribute = useCallback((index: number, attribute: string) => { setRows((currentRows) => { const newRows = [...currentRows]; @@ -406,7 +457,7 @@ function TableEditor({ updateRowAttribute(index, attribute)} menuId={`attribute-selector-menu-${index}`} buttonId={`attribute-selector-button-${index}`} @@ -418,7 +469,7 @@ function TableEditor({ updateRowOperator(index, operator)} attributeType={userAttributes.find((attr) => attr.name === row.attribute)?.type} /> @@ -426,7 +477,7 @@ function TableEditor({ updateRowValues(index, values)} options={row.attribute ? userAttributes.find((attr) => attr.name === row.attribute)?.attrs?.options || [] : []} /> @@ -435,8 +486,8 @@ function TableEditor({
diff --git a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/value_selector_menu.tsx b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/value_selector_menu.tsx index 4d8aec425e3..ec3426ea596 100644 --- a/webapp/channels/src/components/admin_console/access_control/editors/table_editor/value_selector_menu.tsx +++ b/webapp/channels/src/components/admin_console/access_control/editors/table_editor/value_selector_menu.tsx @@ -15,6 +15,7 @@ export interface TableRow { operator: string; values: string[]; attribute_type: string; + hasMaskedValues: boolean; } export interface ValueSelectorMenuProps { @@ -26,7 +27,6 @@ export interface ValueSelectorMenuProps { placeholder?: string; } -// Main ValueSelectorMenu component that delegates to the appropriate selector const ValueSelectorMenu = ({ row, disabled, @@ -46,11 +46,11 @@ const ValueSelectorMenu = ({ options={options} allowCreateValue={allowCreateValue} placeholder={placeholder} + hasMaskedValues={row.hasMaskedValues} /> ); } - // For single-value operators return ( ); }; diff --git a/webapp/channels/src/components/admin_console/access_control/policies.test.tsx b/webapp/channels/src/components/admin_console/access_control/policies.test.tsx index c36f7a30298..c4ffb43c115 100644 --- a/webapp/channels/src/components/admin_console/access_control/policies.test.tsx +++ b/webapp/channels/src/components/admin_console/access_control/policies.test.tsx @@ -143,6 +143,88 @@ describe('components/admin_console/access_control/PolicyList', () => { expect(screen.getByText('Delete')).toBeInTheDocument(); }); + test('Delete menu item is disabled when a policy contains masked values', async () => { + // The "--------" sentinel in a rule expression means the caller can't + // see at least one referenced value. Server enforces a 403 on delete + // in that case, so the menu item must be disabled to avoid a useless + // confirmation modal → 403 round-trip. + mockSearchPolicies.mockResolvedValue({ + data: { + policies: [{ + id: 'masked-policy', + name: 'Masked Policy', + rules: [{actions: ['*'], expression: 'user.attributes.program in ["Alpha", "--------"]'}], + } as unknown as AccessControlPolicy], + total: 1, + }, + } as ActionResult); + renderWithContext(); + await waitFor(() => { + expect(screen.getByText('Masked Policy')).toBeInTheDocument(); + }); + + const menuButton = document.getElementById('policy-menu-masked-policy')!; + await userEvent.click(menuButton); + + const deleteItem = document.getElementById('policy-menu-delete-masked-policy')!; + expect(deleteItem).toHaveAttribute('aria-disabled', 'true'); + }); + + test('Delete menu item is enabled for a clean policy with no channels', async () => { + // Sanity: a policy that has neither child channels nor masked values + // must keep Delete enabled. + mockSearchPolicies.mockResolvedValue({ + data: { + policies: [{ + id: 'clean-policy', + name: 'Clean Policy', + rules: [{actions: ['*'], expression: 'user.attributes.program in ["Alpha"]'}], + } as unknown as AccessControlPolicy], + total: 1, + }, + } as ActionResult); + renderWithContext(); + await waitFor(() => { + expect(screen.getByText('Clean Policy')).toBeInTheDocument(); + }); + + const menuButton = document.getElementById('policy-menu-clean-policy')!; + await userEvent.click(menuButton); + + const deleteItem = document.getElementById('policy-menu-delete-clean-policy')!; + expect(deleteItem).not.toHaveAttribute('aria-disabled', 'true'); + }); + + test('Delete confirmation modal no longer surfaces the masked-values warning', async () => { + // The inner-modal "This policy contains restricted values" notice was + // removed once we started disabling the Delete menu item upstream. + // Open the modal on a clean policy and assert the warning text is gone. + mockSearchPolicies.mockResolvedValue({ + data: { + policies: [{ + id: 'clean-policy', + name: 'Clean Policy', + rules: [{actions: ['*'], expression: 'user.attributes.program in ["Alpha"]'}], + } as unknown as AccessControlPolicy], + total: 1, + }, + } as ActionResult); + renderWithContext(); + await waitFor(() => { + expect(screen.getByText('Clean Policy')).toBeInTheDocument(); + }); + + const menuButton = document.getElementById('policy-menu-clean-policy')!; + await userEvent.click(menuButton); + await userEvent.click(screen.getByText('Delete')); + + await waitFor(() => { + expect(screen.getByText('Confirm Policy Deletion')).toBeInTheDocument(); + }); + expect(screen.queryByText(/This policy includes attribute values that are hidden from you/)).not.toBeInTheDocument(); + expect(screen.queryByText('This policy contains restricted values')).not.toBeInTheDocument(); + }); + test('should get columns correctly', async () => { mockSearchPolicies.mockResolvedValue({data: {policies: [], total: 0}} as ActionResult); renderWithContext(); diff --git a/webapp/channels/src/components/admin_console/access_control/policies.tsx b/webapp/channels/src/components/admin_console/access_control/policies.tsx index 0452dd9909a..795566a585a 100644 --- a/webapp/channels/src/components/admin_console/access_control/policies.tsx +++ b/webapp/channels/src/components/admin_console/access_control/policies.tsx @@ -1,9 +1,10 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import React, {useState, useEffect, useMemo} from 'react'; +import React, {useState, useEffect, useMemo, useCallback} from 'react'; import {FormattedMessage, useIntl} from 'react-intl'; +import {GenericModal} from '@mattermost/components'; import {Button} from '@mattermost/shared/components/button'; import type {AccessControlPolicy} from '@mattermost/types/access_control'; @@ -12,11 +13,23 @@ import type {ActionResult} from 'mattermost-redux/types/actions'; import type {Row, Column} from 'components/admin_console/data_grid/data_grid'; import DataGrid from 'components/admin_console/data_grid/data_grid'; import * as Menu from 'components/menu'; +import SectionNotice from 'components/section_notice'; import {getHistory} from 'utils/browser_history'; import './policies.scss'; +// The server emits the eight-dash masked-token sentinel inside raw CEL expressions +// when masking values the caller cannot see (e.g. `attr == "--------"`). The full +// visual AST carries a typed `has_masked_values` flag per condition, but on the +// policies list page we only have the raw expression strings — so we detect masking +// by the quoted token substring. +const MASKED_VALUE_TOKEN_LITERAL = '"--------"'; + +function policyHasMaskedValues(policy: AccessControlPolicy): boolean { + return policy.rules?.some((rule) => rule.expression?.includes(MASKED_VALUE_TOKEN_LITERAL)) ?? false; +} + type Props = { onPolicySelected?: (policy: AccessControlPolicy) => void; onPoliciesLoaded?: (count: number) => void; @@ -41,6 +54,8 @@ export default function PolicyList(props: Props): JSX.Element { const [searchErrored, setSearchErrored] = useState(false); const [cursorHistory, setCursorHistory] = useState([]); const [total, setTotal] = useState(0); + const [pendingDeletePolicy, setPendingDeletePolicy] = useState(null); + const [deleteError, setDeleteError] = useState(null); const intl = useIntl(); const history = useMemo(() => getHistory(), []); @@ -157,10 +172,38 @@ export default function PolicyList(props: Props): JSX.Element { ); }; - const handleDelete = async (policyId: string) => { - await props.actions.deletePolicy(policyId); + const initiateDelete = useCallback((policy: AccessControlPolicy) => { + setPendingDeletePolicy(policy); + setDeleteError(null); + }, []); + + const confirmDelete = useCallback(async () => { + if (!pendingDeletePolicy) { + return; + } + const result = await props.actions.deletePolicy(pendingDeletePolicy.id); + if (result?.error) { + // The server enforces masked-policy / permission rejections (403). Surface + // the message in the modal so the user sees why deletion failed instead of + // a silent close + stale list refresh. + const errorId = result.error.server_error_id; + if (errorId === 'app.pap.delete_policy.masked_values') { + setDeleteError(intl.formatMessage({ + id: 'admin.access_control.delete_policy.masked_values', + defaultMessage: 'You cannot delete this policy because it contains attribute values you do not have permission to view.', + })); + } else { + setDeleteError(result.error.message || intl.formatMessage({ + id: 'admin.access_control.delete_policy.generic_error', + defaultMessage: 'Failed to delete the policy.', + })); + } + return; + } + setPendingDeletePolicy(null); + setDeleteError(null); fetchPolicies(search); - }; + }, [pendingDeletePolicy, search, intl]); const getRows = (): Row[] => { return policies.map((policy: AccessControlPolicy) => { @@ -223,7 +266,7 @@ export default function PolicyList(props: Props): JSX.Element { {!props.hideDeleteAction && ( handleDelete(policy.id)} + onClick={() => initiateDelete(policy)} leadingElement={} labels={ } isDestructive={true} - disabled={Boolean(policy.props?.child_ids?.length)} + + // Also disable when the policy contains values masked + // for this caller. Mirrors the policy-details Delete + // guard — server returns 403, so otherwise the modal + // flow would just round-trip an error. + disabled={Boolean(policy.props?.child_ids?.length) || policyHasMaskedValues(policy)} /> )} @@ -404,6 +452,54 @@ export default function PolicyList(props: Props): JSX.Element { ) : undefined} /> + {pendingDeletePolicy && ( + { + setPendingDeletePolicy(null); + setDeleteError(null); + }} + handleConfirm={confirmDelete} + handleCancel={() => { + setPendingDeletePolicy(null); + setDeleteError(null); + }} + modalHeaderText={ + + } + confirmButtonText={ + + } + confirmButtonVariant='destructive' + compassDesign={true} + > + <> + + {deleteError && ( +
+ + } + text={deleteError} + /> +
+ )} + +
+ )} ); } diff --git a/webapp/channels/src/components/admin_console/access_control/policy_details/__snapshots__/policy_details.test.tsx.snap b/webapp/channels/src/components/admin_console/access_control/policy_details/__snapshots__/policy_details.test.tsx.snap index 7c9f4c8860e..b9e4021c4f8 100644 --- a/webapp/channels/src/components/admin_console/access_control/policy_details/__snapshots__/policy_details.test.tsx.snap +++ b/webapp/channels/src/components/admin_console/access_control/policy_details/__snapshots__/policy_details.test.tsx.snap @@ -87,99 +87,8 @@ exports[`components/admin_console/access_control/policy_details/PolicyDetails sh style="height: 0px;" >
- - - - - - - - - - - - - - - - - - -
- Attribute - - Operator - - - Values - - -
- - Select a user attribute and values to create a rule - -
- -
-
-
-

- Each row is a single condition that must be met for a user to comply with the policy. All rules are combined with logical AND operator ( - - - && - - - ). -

-
- -
-
+ data-testid="table-editor" + />
- - - - - - - - - - - - - - - - - - -
- Attribute - - Operator - - - Values - - -
- - Select a user attribute and values to create a rule - -
- -
-
-
-

- Each row is a single condition that must be met for a user to comply with the policy. All rules are combined with logical AND operator ( - - - && - - - ). -

-
- -
-
+ data-testid="table-editor" + />
({ }), })); +// Mock TableEditor so tests can control onMaskedStateChange callbacks. +// jest.mock factory may not reference out-of-scope variables, so React is required inline. +jest.mock('../editors/table_editor/table_editor', () => { + const reactLib = require('react'); + return jest.fn(({onMaskedStateChange}: any) => { + reactLib.useEffect(() => { + onMaskedStateChange?.(false); + }, []); + return reactLib.createElement('div', {'data-testid': 'table-editor'}); + }); +}); + +// Mock CELEditor — its real implementation boots Monaco on mount, which is +// not available in JSDOM. The mode-toggle tests only care that switching to +// Advanced/Simple flips state in the parent, not how Monaco renders. +jest.mock('../editors/cel_editor/editor', () => { + const reactLib = require('react'); + return jest.fn(() => reactLib.createElement('div', {'data-testid': 'cel-editor'})); +}); + // Mock the useChannelAccessControlActions hook jest.mock('hooks/useChannelAccessControlActions', () => ({ useChannelAccessControlActions: jest.fn(), @@ -151,11 +171,9 @@ describe('components/admin_console/access_control/policy_details/PolicyDetails', ...defaultProps.actions, fetchPolicy: jest.fn().mockResolvedValue({ data: { - policy: { - id: 'policy1', - name: 'Policy 1', - rules: [{expression: 'true'}], - }, + id: 'policy1', + name: 'Policy 1', + rules: [{expression: 'true'}], }, }), }, @@ -164,6 +182,140 @@ describe('components/admin_console/access_control/policy_details/PolicyDetails', expect(container).toMatchSnapshot(); }); + test('should show masked values warning banner when policy has masked rows', async () => { + // hasMaskedRows is derived in policy_details from the presence of the + // "--------" sentinel in the loaded expression — drive the test via a + // fetched policy carrying a masked rule. + const props = { + ...defaultProps, + actions: { + ...defaultProps.actions, + fetchPolicy: jest.fn().mockResolvedValue({ + data: { + id: 'policy1', + name: 'Policy 1', + rules: [{ + actions: ['*'], + expression: 'user.attributes.program in ["Alpha", "--------"]', + }], + }, + }), + }, + }; + renderWithContext(); + + await waitFor(() => { + expect(screen.getByText('This policy contains restricted values')).toBeInTheDocument(); + }); + expect(screen.getByText(/Some rules include attribute values you cannot see/)).toBeInTheDocument(); + }); + + test('should not show masked values warning banner when no masked rows', async () => { + renderWithContext(); + + await waitFor(() => { + expect(screen.queryByText('This policy contains restricted values')).not.toBeInTheDocument(); + }); + }); + + test('hasMaskedRows derivation survives Simple → Advanced → Simple mode toggles', async () => { + // Regression guard: hasMaskedRows must come from the expression itself, + // not from a TableEditor lifecycle callback. Toggling editor modes + // remounts TableEditor; if the parent reset hasMaskedRows on remount, + // the warning banner would flicker off and the CEL/delete gates would + // briefly open. Deriving from the "--------" sentinel in the expression + // is the only source of truth that's lifecycle-independent. + + // The mode-toggle button is disabled while no usable attributes are + // available, so the test needs at least one to actually exercise the + // Simple → Advanced → Simple round-trip. + mockGetAccessControlFields.mockResolvedValue({data: [{name: 'program', attrs: {ldap: true}}]}); + + const props = { + ...defaultProps, + actions: { + ...defaultProps.actions, + fetchPolicy: jest.fn().mockResolvedValue({ + data: { + id: 'policy1', + name: 'Policy 1', + rules: [{ + actions: ['*'], + expression: 'user.attributes.program in ["Alpha", "--------"]', + }], + }, + }), + }, + }; + renderWithContext(); + + // Banner present after initial load (Simple mode). + await waitFor(() => { + expect(screen.getByText('This policy contains restricted values')).toBeInTheDocument(); + }); + + // Switch to Advanced mode — banner must remain (it lives outside the + // editor swap, gated by hasMaskedRows which is expression-derived). + const toAdvanced = screen.getByText('Switch to Advanced Mode'); + await userEvent.click(toAdvanced); + expect(screen.getByText('This policy contains restricted values')).toBeInTheDocument(); + + // Switch back to Simple mode — banner must STILL be there. Before the + // fix, the TableEditor remount transiently flipped hasMaskedRows to + // false and the banner disappeared. + const toSimple = screen.getByText('Switch to Simple Mode'); + await userEvent.click(toSimple); + expect(screen.getByText('This policy contains restricted values')).toBeInTheDocument(); + }); + + test('hasMaskedRows stays false for a policy without the masked-token sentinel', async () => { + // Negative case: a normal policy expression must not trip the + // masked-rows banner. + const props = { + ...defaultProps, + actions: { + ...defaultProps.actions, + fetchPolicy: jest.fn().mockResolvedValue({ + data: { + id: 'policy1', + name: 'Policy 1', + rules: [{ + actions: ['*'], + expression: 'user.attributes.program in ["Alpha", "Bravo"]', + }], + }, + }), + }, + }; + renderWithContext(); + + await waitFor(() => { + expect(screen.getByText('Delete policy')).toBeInTheDocument(); + }); + expect(screen.queryByText('This policy contains restricted values')).not.toBeInTheDocument(); + }); + + // Note: when hasMaskedRows is true the Delete button is disabled (policy_details.tsx), + // so the masked-warning inside the confirmation modal is defense-in-depth and not + // reachable through normal UI flow. Test only the no-masked-rows path here. + + test('should not show masked values warning in delete confirmation modal when no masked rows', async () => { + renderWithContext(); + + await waitFor(() => { + expect(screen.getByText('Delete policy')).toBeInTheDocument(); + }); + + const deleteButtons = screen.getAllByText('Delete'); + await userEvent.click(deleteButtons[deleteButtons.length - 1]); + + await waitFor(() => { + expect(screen.getByText('Confirm Policy Deletion')).toBeInTheDocument(); + }); + + expect(screen.queryByText(/This policy includes attribute values that are hidden from you/)).not.toBeInTheDocument(); + }); + test('should handle delete policy', async () => { const props = { ...defaultProps, diff --git a/webapp/channels/src/components/admin_console/access_control/policy_details/policy_details.tsx b/webapp/channels/src/components/admin_console/access_control/policy_details/policy_details.tsx index d278b278563..8dc765642a2 100644 --- a/webapp/channels/src/components/admin_console/access_control/policy_details/policy_details.tsx +++ b/webapp/channels/src/components/admin_console/access_control/policy_details/policy_details.tsx @@ -81,6 +81,16 @@ function PolicyDetails({ const [serverError, setServerError] = useState(undefined); const [addChannelOpen, setAddChannelOpen] = useState(false); const [editorMode, setEditorMode] = useState<'cel' | 'table'>('table'); + + // Derive masked-rows state directly from the expression rather than relying + // on a callback from TableEditor: TableEditor unmounts on mode switches, and + // on remount its rows-derived flag flickers false-then-true while the async + // AST round-trip is in flight, briefly opening gates (CEL read-only, delete, + // banner) that should stay closed. The "--------" sentinel is what the + // server emits in raw CEL for any value the caller can't see, so its + // presence in the expression is a stable signal independent of editor + // lifecycle. + const hasMaskedRows = useMemo(() => expression.includes('"--------"'), [expression]); const [channelChanges, setChannelChanges] = useState({ removed: {}, added: {}, @@ -197,6 +207,14 @@ function PolicyDetails({ if (result.error) { if (result.error.server_error_id === 'app.pap.save_policy.name_exists.app_error') { setServerError(formatMessage({id: 'admin.access_control.edit_policy.name_exists', defaultMessage: 'A policy with this name already exists. Please choose a different name.'})); + } else if (result.error.server_error_id === 'app.pap.save_policy.invalid_value') { + setServerError(formatMessage({id: 'admin.access_control.edit_policy.invalid_value', defaultMessage: 'Invalid value.'})); + } else if (result.error.server_error_id === 'app.pap.save_policy.self_exclusion') { + setServerError(formatMessage({id: 'admin.access_control.edit_policy.self_exclusion', defaultMessage: 'You do not satisfy one or more conditions in this policy. Contact a System Admin for assistance.'})); + } else if (result.error.server_error_id === 'app.pap.save_policy.masked_condition_deleted') { + setServerError(formatMessage({id: 'admin.access_control.edit_policy.masked_condition_deleted', defaultMessage: 'You cannot remove a condition that contains attribute values you do not have permission to view.'})); + } else if (result.error.server_error_id === 'app.pap.save_policy.masked_rule_deleted') { + setServerError(formatMessage({id: 'admin.access_control.edit_policy.masked_rule_deleted', defaultMessage: 'You cannot remove a rule that contains attribute values you do not have permission to view.'})); } else { setServerError(result.error.message); } @@ -506,6 +524,23 @@ function PolicyDetails({ /> + {hasMaskedRows && ( +
+ + } + text={formatMessage({ + id: 'admin.access_control.policy.edit_policy.masked_values_warning.text', + defaultMessage: 'Some rules include attribute values you cannot see. Editing or deleting these rules may change who has access in ways you cannot fully anticipate.', + })} + /> +
+ )} {editorMode === 'cel' ? ( {}} disabled={noUsableAttributes} + hasMaskedRows={hasMaskedRows} userAttributes={autocompleteResult. filter((attr) => { if (accessControlSettings.EnableUserManagedAttributes) { @@ -613,6 +649,23 @@ function PolicyDetails({ expanded={true} className={'console delete-policy'} > + {hasMaskedRows && ( +
+ + } + text={formatMessage({ + id: 'admin.access_control.policy.edit_policy.delete_policy.masked_values_warning.text', + defaultMessage: 'Removing this policy could affect access for users you cannot fully account for.', + })} + /> +
+ )} @@ -698,10 +751,12 @@ function PolicyDetails({ confirmButtonVariant='destructive' compassDesign={true} > - + <> + + )} diff --git a/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.scss b/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.scss index e521b0a2c12..9338a9808c1 100644 --- a/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.scss +++ b/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.scss @@ -2,6 +2,10 @@ padding-bottom: 80px; margin-top: -8px; + &__masked-values-warning { + margin-bottom: 16px; + } + &__header { margin-bottom: 0; } diff --git a/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.tsx b/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.tsx index ecb7b71cea1..e23aec85cea 100644 --- a/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.tsx +++ b/webapp/channels/src/components/team_settings/team_access_policies_tab/team_policy_editor.tsx @@ -107,6 +107,7 @@ export default function TeamPolicyEditor({ const [showConfirmationModal, setShowConfirmationModal] = useState(false); const [showDeleteModal, setShowDeleteModal] = useState(false); const [backClicked, setBackClicked] = useState(false); + const [hasMaskedRows, setHasMaskedRows] = useState(false); const noUsableAttributes = attributesLoaded && !hasUsableAttributes(autocompleteResult, accessControlSettings.EnableUserManagedAttributes); @@ -262,7 +263,7 @@ export default function TeamPolicyEditor({ setSaveChangesPanelState(SAVE_RESULT_ERROR); return false; } - if (expression.includes('== ""') || expression.includes("== ''") || expression.includes('in []')) { + if (expression.includes('== ""') || expression.includes("== ''")) { setFormError(formatMessage({id: 'team_settings.policy_editor.error.incomplete_rule', defaultMessage: 'Please complete all attribute rules with a value'})); setSaveChangesPanelState(SAVE_RESULT_ERROR); return false; @@ -510,6 +511,23 @@ export default function TeamPolicyEditor({

+ {hasMaskedRows && ( +
+ + } + text={formatMessage({ + id: 'admin.access_control.policy.edit_policy.masked_values_warning.text', + defaultMessage: 'Some rules include attribute values you cannot see. Editing or deleting these rules may change who has access in ways you cannot fully anticipate.', + })} + /> +
+ )} @@ -617,7 +636,7 @@ export default function TeamPolicyEditor({ - - } - > - {submitText} - - - - - ); - } -} diff --git a/webapp/channels/src/packages/mattermost-redux/src/selectors/entities/interactive_dialog.test.ts b/webapp/channels/src/packages/mattermost-redux/src/selectors/entities/interactive_dialog.test.ts deleted file mode 100644 index 4c672cfa694..00000000000 --- a/webapp/channels/src/packages/mattermost-redux/src/selectors/entities/interactive_dialog.test.ts +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See LICENSE.txt for license information. - -import type {GlobalState} from '@mattermost/types/store'; - -import {interactiveDialogAppsFormEnabled} from './interactive_dialog'; - -describe('interactive_dialog selectors', () => { - describe('interactiveDialogAppsFormEnabled', () => { - const createMockState = (config: Partial = {}): GlobalState => ({ - entities: { - general: { - config, - }, - }, - } as any); - - test('should return true when feature flag is enabled', () => { - const state = createMockState({ - FeatureFlagInteractiveDialogAppsForm: 'true', - }); - - expect(interactiveDialogAppsFormEnabled(state)).toBe(true); - }); - - test('should return false when feature flag is disabled', () => { - const state = createMockState({ - FeatureFlagInteractiveDialogAppsForm: 'false', - }); - - expect(interactiveDialogAppsFormEnabled(state)).toBe(false); - }); - - test('should return false when feature flag is not present', () => { - const state = createMockState({}); - - expect(interactiveDialogAppsFormEnabled(state)).toBe(false); - }); - - test('should return false when feature flag is empty string', () => { - const state = createMockState({ - FeatureFlagInteractiveDialogAppsForm: '', - }); - - expect(interactiveDialogAppsFormEnabled(state)).toBe(false); - }); - - test('should return false when feature flag is undefined', () => { - const state = createMockState({ - FeatureFlagInteractiveDialogAppsForm: undefined, - }); - - expect(interactiveDialogAppsFormEnabled(state)).toBe(false); - }); - - test('should return false when config is empty', () => { - const state = createMockState(); - - expect(interactiveDialogAppsFormEnabled(state)).toBe(false); - }); - - test('should be case sensitive for true value', () => { - const stateUppercase = createMockState({ - FeatureFlagInteractiveDialogAppsForm: 'TRUE', - }); - - const stateMixed = createMockState({ - FeatureFlagInteractiveDialogAppsForm: 'True', - }); - - expect(interactiveDialogAppsFormEnabled(stateUppercase)).toBe(false); - expect(interactiveDialogAppsFormEnabled(stateMixed)).toBe(false); - }); - }); -}); diff --git a/webapp/channels/src/packages/mattermost-redux/src/selectors/entities/interactive_dialog.ts b/webapp/channels/src/packages/mattermost-redux/src/selectors/entities/interactive_dialog.ts deleted file mode 100644 index 246aeace00f..00000000000 --- a/webapp/channels/src/packages/mattermost-redux/src/selectors/entities/interactive_dialog.ts +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See LICENSE.txt for license information. - -import type {ClientConfig} from '@mattermost/types/config'; -import type {GlobalState} from '@mattermost/types/store'; - -import {createSelector} from 'mattermost-redux/selectors/create_selector'; -import {getConfig} from 'mattermost-redux/selectors/entities/general'; - -export const interactiveDialogAppsFormEnabled = createSelector( - 'interactiveDialogAppsFormEnabled', - (state: GlobalState) => getConfig(state), - (config: Partial) => { - return config?.FeatureFlagInteractiveDialogAppsForm === 'true'; - }, -); diff --git a/webapp/platform/types/src/config.ts b/webapp/platform/types/src/config.ts index c5e6fc8899e..1b54ce3f901 100644 --- a/webapp/platform/types/src/config.ts +++ b/webapp/platform/types/src/config.ts @@ -130,7 +130,6 @@ export type ClientConfig = { FeatureFlagAttributeBasedAccessControl: string; FeatureFlagPermissionPolicies: string; FeatureFlagWebSocketEventScope: string; - FeatureFlagInteractiveDialogAppsForm: string; FeatureFlagContentFlagging: string; FeatureFlagClassificationMarkings: string; FeatureFlagManagedChannelCategories: string; From 448a642835da50b82106ec8ad9ae6d0a200f6bae Mon Sep 17 00:00:00 2001 From: Scott Bishel Date: Wed, 20 May 2026 13:34:44 -0600 Subject: [PATCH 44/80] Add inline action buttons for bot-posted markdown (#36219) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add inline action buttons for bot-posted markdown Bots, webhooks, and plugins can now embed clickable action buttons inside markdown (including table cells) using mmaction://actionId links, with row-specific parameters forwarded to the integration on click. This enables use cases like a per-row "Mx Plan" button in a fleet-status table that opens a dialog scoped to the clicked row. Design - New post prop inline_actions maps actionId (alphanumeric) to a PostActionIntegration {URL, Context}, capped at 50 entries. - Markdown link with scheme mmaction:// emits a placeholder span that messageHtmlToComponent converts to the InlineActionButton component. - Click POSTs inline_context (parsed from the URL query string) to the existing /posts/{id}/actions/{action_id} endpoint; the server merges it into the integration request as context.inline_params while preserving the post-level context. - Only bot, webhook, and plugin posts render the button; non-integration posts have inline_actions stripped on create, update, and ephemeral broadcast. Hardened-mode also covers the new prop. - Reuses the existing PostAction dialog pipeline: plugin handlers reply with a trigger_id and call /actions/dialogs/open as before. Security - InlineContext capped at 50 entries / 128-char keys / 2 KB values. - Integration Context cloned per click so per-click inline_params and selected_option cannot leak into the cached post for other clickers. - Plugin response updates cannot add inline_actions to a post that did not already have them; invalid entries are dropped with a warn log. - Label content and data attributes are escaped; labels are flattened to plain text (tags stripped, entities decoded, then escaped). - Malformed JSON request bodies now return 400 instead of falling through with an empty inline_context. Tests - Model: validators, normalization, GetInlineAction, strip, fallback. - App: create strip, update guard (4 subtests including AllowInlineActionsUpdate bypass), ephemeral strip, inline_params merge, context-map isolation, plugin-response guards, from_bot and from_plugin retention across plugin updates. - API: inline_context validation (size bounds + error id), omitempty backward compat, malformed JSON 400. - Webapp: renderer scheme handling, allow/deny flags, size caps, HTML escape, tag strip, entity decode, attribute-injection defense; component click dispatch, double-click race guard, unmount safety, error-result recovery, aria state. Co-Authored-By: Claude Opus 4.7 (1M context) * lint fix * i18n-extract * Review fixes for inline action buttons - renderer: preserve actionId case; reject opaque mmaction: URI - app: require bot AND integration session to preserve inline_actions - app: restore original inline_actions when plugin response is invalid - i18n: rename key to ...app_error to match convention Co-Authored-By: Claude Opus 4.7 (1M context) * Tighten UpdatePost inline_actions guard; fix test seeds - app: UpdatePost now requires AllowInlineActionsUpdate to modify inline_actions. Integration session alone is insufficient — a PAT-wielding user could otherwise inject inline_actions on any post they could edit. - tests: seed bot posts with inline_actions via an integration session (intSeedCtx) so they survive the create-time strip. - renderer: lint fix (blank line before comment block). Co-Authored-By: Claude Opus 4.7 (1M context) * Reject malformed inline-action authorities at render time - renderer: enforce ^[A-Za-z0-9]+$ on actionId, mirroring the server regex. Authorities like mmaction://plan:443 or mmaction://user@plan now fall through to plain text instead of rendering a dead button. - post: clarify in the strip comment that webhooks and plugins bypass CreatePostAsUser entirely (they call CreatePost / CreatePostMissingChannel directly), so the strip block does not apply to them. Co-Authored-By: Claude Opus 4.7 (1M context) * Tighten inline-action renderer tests - Replace oversized-params test with boundary pair (at-cap and over-cap) to lock in the > vs >= behavior of the size-limit check. - Add a "surrounding text survives" assertion for the tag-strip path so a future swap from regex strip to a DOM sanitizer won't silently drop legitimate content along with tags. Co-Authored-By: Claude Opus 4.7 (1M context) * Inline action buttons via mmaction:// markdown links Adds inline action buttons rendered from mmaction:// links in markdown, with the click pipeline reusing the existing post-action infrastructure. Aligned with the broader mm_blocks_actions framework (Daniel's PR). * fix lint, DoS hardening, fix and rename test * Address review feedback * lint fix * Reject percent-encoded path traversal in validateIntegrationURL (e.g. %2e%2e%2f) by parsing the URL and checking the decoded path. --------- Co-authored-by: Claude Opus 4.7 (1M context) Co-authored-by: Mattermost Build --- server/channels/api4/integration_action.go | 48 +- .../channels/api4/integration_action_test.go | 188 +++ server/channels/app/integration_action.go | 129 ++- .../channels/app/integration_action_test.go | 1030 ++++++++++++++++- server/channels/app/plugin_api.go | 14 +- server/channels/app/post.go | 47 + server/channels/app/webhook.go | 6 + server/i18n/en.json | 12 + server/public/model/integration_action.go | 154 ++- .../public/model/integration_action_test.go | 740 ++++++++++++ server/public/model/mm_blocks_actions.go | 154 +++ server/public/model/post.go | 16 + server/public/model/post_test.go | 9 +- .../inline_action_button/index.test.tsx | 418 +++++++ .../components/inline_action_button/index.tsx | 214 ++++ .../inline_action_button.scss | 25 + .../src/components/markdown/markdown.tsx | 9 + .../post_markdown/post_markdown.tsx | 11 + webapp/channels/src/i18n/en.json | 2 + .../mattermost-redux/src/actions/posts.ts | 32 + .../src/utils/markdown/renderer.test.tsx | 15 + .../utils/message_html_to_component.test.tsx | 48 + .../src/utils/message_html_to_component.tsx | 24 + webapp/platform/client/src/client4.ts | 7 + 24 files changed, 3290 insertions(+), 62 deletions(-) create mode 100644 server/public/model/mm_blocks_actions.go create mode 100644 webapp/channels/src/components/inline_action_button/index.test.tsx create mode 100644 webapp/channels/src/components/inline_action_button/index.tsx create mode 100644 webapp/channels/src/components/inline_action_button/inline_action_button.scss diff --git a/server/channels/api4/integration_action.go b/server/channels/api4/integration_action.go index c0ab52462a1..70a9f4c0f23 100644 --- a/server/channels/api4/integration_action.go +++ b/server/channels/api4/integration_action.go @@ -5,7 +5,8 @@ package api4 import ( "encoding/json" - "fmt" + "errors" + "io" "net/http" "github.com/mattermost/mattermost/server/public/model" @@ -20,22 +21,6 @@ func (api *API) InitAction() { api.BaseRoutes.APIRoot.Handle("/actions/dialogs/lookup", api.APISessionRequired(lookupDialog)).Methods(http.MethodPost) } -// getStringValue safely converts an interface{} value to a string with logging for failures. -// It handles nil values gracefully and logs warnings when conversion fails. -func getStringValue(val any, fieldName string, logger *mlog.Logger) string { - if val == nil { - return "" - } - if str, ok := val.(string); ok { - return str - } - logger.Warn("Failed to convert field to string", - mlog.String("field", fieldName), - mlog.String("type", fmt.Sprintf("%T", val)), - mlog.Any("value", val)) - return "" -} - func doPostAction(c *Context, w http.ResponseWriter, r *http.Request) { c.RequirePostId() if c.Err != nil { @@ -43,9 +28,26 @@ func doPostAction(c *Context, w http.ResponseWriter, r *http.Request) { } var actionRequest model.DoPostActionRequest - err := json.NewDecoder(r.Body).Decode(&actionRequest) - if err != nil { - c.Logger.Warn("Error decoding the action request", mlog.Err(err)) + dec := json.NewDecoder(r.Body) + err := dec.Decode(&actionRequest) + if err != nil && !errors.Is(err, io.EOF) { + // Empty body is allowed for backward-compatibility with older clients. + // Any other decode failure means the request cannot be trusted — in + // particular, a wrong-type query would otherwise fall through as nil + // and silently execute the action without the caller's params. + c.SetInvalidParamWithErr("action_request", err) + return + } + if err == nil { + // Reject trailing JSON values after the first object (e.g. + // `{"query":{"k":"v"}}{"cookie":"x"}`). json.Decoder.Decode + // stops at the first complete value and would otherwise silently + // ignore the rest, leaving the caller's intent ambiguous. + var trailing any + if extraErr := dec.Decode(&trailing); !errors.Is(extraErr, io.EOF) { + c.SetInvalidParamWithErr("action_request", extraErr) + return + } } var cookie *model.PostActionCookie @@ -82,7 +84,7 @@ func doPostAction(c *Context, w http.ResponseWriter, r *http.Request) { resp := &model.PostActionAPIResponse{Status: "OK"} resp.TriggerId, appErr = c.App.DoPostActionWithCookie(c.AppContext, c.Params.PostId, c.Params.ActionId, c.AppContext.Session().UserId, - actionRequest.SelectedOption, cookie) + actionRequest.SelectedOption, cookie, actionRequest.Query) if appErr != nil { c.Err = appErr return @@ -204,8 +206,8 @@ func lookupDialog(c *Context, w http.ResponseWriter, r *http.Request) { mlog.String("user_id", lookup.UserId), mlog.String("channel_id", lookup.ChannelId), mlog.String("team_id", lookup.TeamId), - mlog.String("selected_field", getStringValue(lookup.Submission["selected_field"], "selected_field", c.Logger)), - mlog.String("query", getStringValue(lookup.Submission["query"], "query", c.Logger)), + mlog.Any("selected_field", lookup.Submission["selected_field"]), + mlog.Any("query", lookup.Submission["query"]), ) resp, err := c.App.LookupInteractiveDialog(c.AppContext, lookup) diff --git a/server/channels/api4/integration_action_test.go b/server/channels/api4/integration_action_test.go index d9b6cd7fd9d..dd61ecc973c 100644 --- a/server/channels/api4/integration_action_test.go +++ b/server/channels/api4/integration_action_test.go @@ -6,9 +6,11 @@ package api4 import ( "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -533,3 +535,189 @@ func TestLookupDialog(t *testing.T) { assert.Empty(t, lookupResp.Items) }) } + +// newAttachmentActionPost posts an attachment action pointing at upstreamURL, +// attributed to th.BasicUser so th.Client has access to call the action. +func newAttachmentActionPost(t *testing.T, th *TestHelper, upstreamURL string) (*model.Post, string) { + t.Helper() + basicPost := &model.Post{ + Message: "attachment action post", + ChannelId: th.BasicChannel.Id, + UserId: th.BasicUser.Id, + Props: model.StringInterface{ + model.PostPropsAttachments: []*model.MessageAttachment{ + { + Text: "hello", + Actions: []*model.PostAction{ + { + Type: model.PostActionTypeButton, + Name: "click", + Integration: &model.PostActionIntegration{ + URL: upstreamURL, + }, + }, + }, + }, + }, + }, + } + created, _, appErr := th.App.CreatePostAsUser(th.Context, basicPost, "", true) + require.Nil(t, appErr) + + attachments, ok := created.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) + require.True(t, ok) + require.NotEmpty(t, attachments) + require.NotEmpty(t, attachments[0].Actions) + require.NotEmpty(t, attachments[0].Actions[0].Id) + return created, attachments[0].Actions[0].Id +} + +func TestDoPostActionQuery_ValidationErrors(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + client := th.Client + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + created, actionID := newAttachmentActionPost(t, th, ts.URL) + route := "/posts/" + created.Id + "/actions/" + actionID + + t.Run("too many entries returns 400 with expected error id", func(t *testing.T) { + ctxMap := make(map[string]string, model.MaxActionQueryEntries+1) + for i := range model.MaxActionQueryEntries + 1 { + ctxMap[fmt.Sprintf("k%d", i)] = "v" + } + payload, err := json.Marshal(model.DoPostActionRequest{Query: ctxMap}) + require.NoError(t, err) + + resp, err := client.DoAPIPost(context.Background(), route, string(payload)) + require.Error(t, err) + CheckBadRequestStatus(t, model.BuildResponse(resp)) + CheckErrorID(t, err, "api.post.do_action.query.app_error") + }) + + t.Run("oversized key returns 400", func(t *testing.T) { + ctxMap := map[string]string{strings.Repeat("k", model.MaxActionQueryKeyLength+1): "v"} + payload, err := json.Marshal(model.DoPostActionRequest{Query: ctxMap}) + require.NoError(t, err) + + resp, err := client.DoAPIPost(context.Background(), route, string(payload)) + require.Error(t, err) + CheckBadRequestStatus(t, model.BuildResponse(resp)) + CheckErrorID(t, err, "api.post.do_action.query.app_error") + }) + + t.Run("oversized value returns 400", func(t *testing.T) { + ctxMap := map[string]string{"k": strings.Repeat("v", model.MaxActionQueryValueLength+1)} + payload, err := json.Marshal(model.DoPostActionRequest{Query: ctxMap}) + require.NoError(t, err) + + resp, err := client.DoAPIPost(context.Background(), route, string(payload)) + require.Error(t, err) + CheckBadRequestStatus(t, model.BuildResponse(resp)) + CheckErrorID(t, err, "api.post.do_action.query.app_error") + }) + + t.Run("small valid context returns 200", func(t *testing.T) { + payload, err := json.Marshal(model.DoPostActionRequest{Query: map[string]string{"tail": "214"}}) + require.NoError(t, err) + + resp, err := client.DoAPIPost(context.Background(), route, string(payload)) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) +} + +func TestDoPostActionQuery_OmitempyCompat(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + client := th.Client + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + created, actionID := newAttachmentActionPost(t, th, ts.URL) + route := "/posts/" + created.Id + "/actions/" + actionID + + // Older clients do not know about query — their request body has no such + // key. The omitempty tag should make this equivalent to sending a nil + // map, which ValidateActionQuery accepts. + payload := `{"selected_option":""}` + resp, err := client.DoAPIPost(context.Background(), route, payload) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Completely empty body should also be accepted — same as older clients + // calling DoPostActionWithCookie with no selection and no cookie. + resp, err = client.DoAPIPost(context.Background(), route, "") + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +// TestDoPostActionMalformedBody verifies non-EOF JSON decode errors now +// return 400 instead of silently running the action with an empty request. +// A body like `{"query":{"k":1}}` (value is not a string) would otherwise +// deserialize to a zero-value Query and skip validation. +func TestDoPostActionMalformedBody(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + client := th.Client + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + created, actionID := newAttachmentActionPost(t, th, ts.URL) + route := "/posts/" + created.Id + "/actions/" + actionID + + t.Run("wrong type for query value returns 400", func(t *testing.T) { + // query must be map[string]string; passing an int value triggers a + // json UnmarshalTypeError which must not fall through. + resp, err := client.DoAPIPost(context.Background(), route, `{"query":{"k":1}}`) + require.Error(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + + t.Run("syntactically invalid JSON returns 400", func(t *testing.T) { + resp, err := client.DoAPIPost(context.Background(), route, `{not json`) + require.Error(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + + t.Run("trailing JSON values after the first object return 400", func(t *testing.T) { + // json.Decoder.Decode stops after the first complete value, so a + // body like `{"query":{}}{"cookie":"x"}` would otherwise execute + // the action with the first object's intent and silently drop the + // rest. The handler explicitly rejects trailing values. + resp, err := client.DoAPIPost(context.Background(), route, `{"query":{}}{"cookie":"x"}`) + require.Error(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) +} diff --git a/server/channels/app/integration_action.go b/server/channels/app/integration_action.go index 5b1ea52a61f..f6c8f935646 100644 --- a/server/channels/app/integration_action.go +++ b/server/channels/app/integration_action.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "io" + "maps" "net/http" "net/url" "path" @@ -39,7 +40,57 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/utils" ) -func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie) (string, *model.AppError) { +// maxMmBlocksActionsCloneDepth caps recursion in cloneMmBlocksActionsProp. +// ValidateMmBlocksActions bounds top-level entry count and key length but +// does not bound nesting depth inside spec.Context — a bot/plugin could +// otherwise stash a pathologically nested object that drives stack +// exhaustion on the restore path. 64 is well past any plausible legitimate +// nesting; deeper input is treated as malicious and truncated. +const maxMmBlocksActionsCloneDepth = 64 + +// cloneMmBlocksActionsProp deep-clones the post.props.mm_blocks_actions value. +// Each per-action entry can carry nested context / query maps (and arrays +// inside those), so the clone walks the structure recursively — a shallow +// clone at any level would leave nested objects aliased back to the live +// post's props, defeating the restore-after-invalid-response guarantee. +func cloneMmBlocksActionsProp(v any) any { + return cloneMmBlocksActionsPropAt(v, 0) +} + +func cloneMmBlocksActionsPropAt(v any, depth int) any { + if depth > maxMmBlocksActionsCloneDepth { + // Defense-in-depth: drop the subtree rather than risk stack + // exhaustion. The restore path that calls this helper is on a + // rare branch (plugin response is invalid), and pathological + // nesting at this depth is not a legitimate use case. + return nil + } + switch typed := v.(type) { + case map[string]any: + out := make(map[string]any, len(typed)) + for k, child := range typed { + out[k] = cloneMmBlocksActionsPropAt(child, depth+1) + } + return out + case []any: + out := make([]any, len(typed)) + for i, child := range typed { + out[i] = cloneMmBlocksActionsPropAt(child, depth+1) + } + return out + default: + // Scalars (string/number/bool/nil) are immutable — safe to share. + return v + } +} + +func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie, query map[string]string) (string, *model.AppError) { + // Bound the per-click query at the App boundary so any caller — REST + // handler, plugin, future internal trigger — gets the same enforcement. + if err := model.ValidateActionQuery(query); err != nil { + return "", model.NewAppError("DoPostActionWithCookie", "api.post.do_action.query.app_error", nil, "", http.StatusBadRequest).Wrap(err) + } + // PostAction may result in the original post being updated. For the // updated post, we need to unconditionally preserve the original // IsPinned and HasReaction attributes, and preserve its entire @@ -121,10 +172,17 @@ func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, upstreamRequest.ChannelName = channel.Name upstreamRequest.TeamId = channel.TeamId upstreamRequest.Type = cookie.Type - upstreamRequest.Context = cookie.Integration.Context + // Clone the Context map — later code may add selected_option to + // it, and we must not mutate the shared source. + // + // query is intentionally not merged on the cookie path: cookies are + // only baked for attachment action buttons, not for mm_blocks + // actions, so this branch is never reached by a click that carries + // per-click query params. + upstreamRequest.Context = maps.Clone(cookie.Integration.Context) datasource = cookie.DataSource - retain = cookie.RetainProps + retain = maps.Clone(cookie.RetainProps) remove = cookie.RemoveProps rootPostId = cookie.RootPostId upstreamURL = cookie.Integration.URL @@ -132,7 +190,7 @@ func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, post := result.Data chResult := <-cchan if chResult.NErr != nil { - return "", model.NewAppError("DoPostActionWithCookie", "app.channel.get_for_post.app_error", nil, "", http.StatusInternalServerError).Wrap(result.NErr) + return "", model.NewAppError("DoPostActionWithCookie", "app.channel.get_for_post.app_error", nil, "", http.StatusInternalServerError).Wrap(chResult.NErr) } channel := chResult.Data @@ -145,7 +203,12 @@ func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, upstreamRequest.ChannelName = channel.Name upstreamRequest.TeamId = channel.TeamId upstreamRequest.Type = action.Type - upstreamRequest.Context = action.Integration.Context + // Clone the Context map — the action pointer returned from + // post.GetAction may alias post.props state (attachment action) or + // the synthesized mm_blocks_actions spec. Mutating it directly + // would leak per-click values (selected_option) into the post's + // cached integration for subsequent clickers. + upstreamRequest.Context = maps.Clone(action.Integration.Context) datasource = action.DataSource // Save the original values that may need to be preserved (including selected @@ -158,7 +221,10 @@ func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, remove = append(remove, key) } } - originalProps = post.GetProps() + // Clone — originalProps may be passed to response.Update.SetProps, + // which would otherwise have response.Update alias the original + // post's props map. + originalProps = maps.Clone(post.GetProps()) originalIsPinned = post.IsPinned originalHasReactions = post.HasReactions @@ -234,6 +300,18 @@ func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, return "", model.NewAppError("DoPostActionWithCookie", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err) } + // Merge per-click query into the upstream URL. This is the canonical + // transport for mm_blocks_actions external clicks; for legacy attachment + // clicks `query` is empty so this is a no-op. Done before the request + // log so operators see the URL actually sent on the wire. + if len(query) > 0 { + mergedURL, mergeErr := model.MergeQueryIntoURL(upstreamURL, query) + if mergeErr != nil { + return "", model.NewAppError("DoPostActionWithCookie", "api.post.do_action.merge_query.app_error", nil, "", http.StatusBadRequest).Wrap(mergeErr) + } + upstreamURL = mergedURL + } + // Log request, regardless of whether destination is internal or external rctx.Logger().Info("DoPostActionWithCookie POST request, through DoActionRequest", mlog.String("url", upstreamURL), @@ -281,7 +359,44 @@ func (a *App) DoPostActionWithCookie(rctx request.CTX, postID, actionId, userID, response.Update.IsPinned = originalIsPinned response.Update.HasReactions = originalHasReactions - if _, _, appErr = a.UpdatePost(rctx, response.Update, &model.UpdatePostOptions{SafeUpdate: false}); appErr != nil { + // Validate mm_blocks_actions on update responses. Since + // AllowMmBlocksActionsUpdate bypasses the non-integration guard in + // UpdatePost, and mm_blocks_actions are not in + // PostActionRetainPropKeys, a bad response would otherwise + // permanently replace the post's valid mm_blocks_actions. Keep the + // original value (if any) and log a warning so integration authors + // can diagnose. + // + // Contract (matches the attachments contract): a plugin update + // response that returns a non-nil Props map MUST echo + // mm_blocks_actions back if it wants the buttons to survive. + // Omitting the key drops the prop. This is intentional symmetry + // with attachments and matches the behavior in the mm_blocks + // framework PR. + if response.Update.GetProp(model.PostPropsMmBlocksActions) != nil { + if originalProps[model.PostPropsMmBlocksActions] == nil { + rctx.Logger().Info("Dropping mm_blocks_actions from plugin update response: original post had none", + mlog.String("post_id", postID), + mlog.String("url", upstreamURL), + ) + response.Update.DelProp(model.PostPropsMmBlocksActions) + } else if err := model.ValidateMmBlocksActions(response.Update); err != nil { + rctx.Logger().Info("Restoring original mm_blocks_actions: plugin update response was invalid", + mlog.String("post_id", postID), + mlog.String("url", upstreamURL), + mlog.Err(err), + ) + // originalProps came from maps.Clone(post.GetProps()) + // which is a shallow clone — the nested + // mm_blocks_actions map is still aliased to + // post.Props. Deep-clone before reattaching so a + // later mutation through response.Update can't + // reach back into the original post's prop map. + response.Update.AddProp(model.PostPropsMmBlocksActions, cloneMmBlocksActionsProp(originalProps[model.PostPropsMmBlocksActions])) + } + } + + if _, _, appErr = a.UpdatePost(rctx, response.Update, &model.UpdatePostOptions{SafeUpdate: false, AllowMmBlocksActionsUpdate: true}); appErr != nil { return "", appErr } } diff --git a/server/channels/app/integration_action_test.go b/server/channels/app/integration_action_test.go index 1389b24b18b..68aa21fd51b 100644 --- a/server/channels/app/integration_action_test.go +++ b/server/channels/app/integration_action_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strconv" "strings" "testing" "time" @@ -67,7 +68,7 @@ func TestPostActionInvalidURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) assert.ErrorContains(t, err, "missing protocol scheme") } @@ -119,7 +120,7 @@ func TestPostActionEmptyResponse(t *testing.T) { attachments, ok := post.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) require.True(t, ok) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) }) @@ -167,7 +168,7 @@ func TestPostActionEmptyResponse(t *testing.T) { cfg.ServiceSettings.OutgoingIntegrationRequestsTimeout = new(int64(1)) }) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) assert.ErrorContains(t, err, "context deadline exceeded") }) @@ -236,7 +237,7 @@ func TestPostActionResponseSizeLimit(t *testing.T) { // Should return error due to truncated JSON, but NOT crash or OOM _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, - attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) // Truncated JSON causes unmarshal error assert.Equal(t, "api.post.do_action.action_integration.app_error", err.Id) @@ -279,7 +280,7 @@ func TestPostActionResponseSizeLimit(t *testing.T) { // Should return error due to invalid JSON, but NOT crash or OOM _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, - attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) assert.Equal(t, "api.post.do_action.action_integration.app_error", err.Id) }) @@ -425,16 +426,16 @@ func TestPostAction(t *testing.T) { require.NotEmpty(t, attachments2[0].Actions) require.NotEmpty(t, attachments2[0].Actions[0].Id) - clientTriggerID, err := th.App.DoPostActionWithCookie(th.Context, post.Id, "notavalidid", th.BasicUser.Id, "", nil) + clientTriggerID, err := th.App.DoPostActionWithCookie(th.Context, post.Id, "notavalidid", th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) assert.Equal(t, http.StatusNotFound, err.StatusCode) assert.Len(t, clientTriggerID, 0) - clientTriggerID, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + clientTriggerID, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) assert.Len(t, clientTriggerID, 26) - clientTriggerID, err = th.App.DoPostActionWithCookie(th.Context, post2.Id, attachments2[0].Actions[0].Id, th.BasicUser.Id, "selected", nil) + clientTriggerID, err = th.App.DoPostActionWithCookie(th.Context, post2.Id, attachments2[0].Actions[0].Id, th.BasicUser.Id, "selected", nil, nil) require.Nil(t, err) assert.Len(t, clientTriggerID, 26) @@ -442,7 +443,7 @@ func TestPostAction(t *testing.T) { *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "" }) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) assert.ErrorContains(t, err, "address forbidden") @@ -480,14 +481,14 @@ func TestPostAction(t *testing.T) { attachmentsPlugin, ok := postplugin.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) require.True(t, ok) - _, err = th.App.DoPostActionWithCookie(th.Context, postplugin.Id, attachmentsPlugin[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, postplugin.Id, attachmentsPlugin[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Equal(t, "api.post.do_action.action_integration.app_error", err.Id) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" }) - _, err = th.App.DoPostActionWithCookie(th.Context, postplugin.Id, attachmentsPlugin[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, postplugin.Id, attachmentsPlugin[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) th.App.UpdateConfig(func(cfg *model.Config) { @@ -528,7 +529,7 @@ func TestPostAction(t *testing.T) { attachmentsSiteURL, ok := postSiteURL.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) require.True(t, ok) - _, err = th.App.DoPostActionWithCookie(th.Context, postSiteURL.Id, attachmentsSiteURL[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, postSiteURL.Id, attachmentsSiteURL[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) assert.ErrorContains(t, err, "connection refused") @@ -570,7 +571,7 @@ func TestPostAction(t *testing.T) { attachmentsSubpath, ok := postSubpath.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) require.True(t, ok) - _, err = th.App.DoPostActionWithCookie(th.Context, postSubpath.Id, attachmentsSubpath[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, postSubpath.Id, attachmentsSubpath[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) }) } @@ -644,7 +645,7 @@ func TestPostActionProps(t *testing.T) { attachments, ok := post.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) require.True(t, ok) - clientTriggerId, err := th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + clientTriggerId, err := th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) assert.Len(t, clientTriggerId, 26) @@ -830,7 +831,7 @@ func TestPostActionRelativeURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) }) @@ -870,7 +871,7 @@ func TestPostActionRelativeURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) }) @@ -910,7 +911,7 @@ func TestPostActionRelativeURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) }) @@ -950,7 +951,7 @@ func TestPostActionRelativeURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) }) @@ -990,7 +991,7 @@ func TestPostActionRelativeURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) }) } @@ -1067,7 +1068,7 @@ func TestPostActionRelativePluginURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.NotNil(t, err) }) @@ -1107,7 +1108,7 @@ func TestPostActionRelativePluginURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) }) @@ -1147,7 +1148,7 @@ func TestPostActionRelativePluginURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) }) @@ -1187,7 +1188,7 @@ func TestPostActionRelativePluginURL(t *testing.T) { require.NotEmpty(t, attachments[0].Actions) require.NotEmpty(t, attachments[0].Actions[0].Id) - _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil) + _, err = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) require.Nil(t, err) }) } @@ -1757,7 +1758,7 @@ func TestDoPostActionWithCookieEdgeCases(t *testing.T) { }, } - _, err := th.App.DoPostActionWithCookie(th.Context, "nonexistent_post_id", "action_id", th.BasicUser.Id, "", cookie) + _, err := th.App.DoPostActionWithCookie(th.Context, "nonexistent_post_id", "action_id", th.BasicUser.Id, "", cookie, nil) require.Nil(t, err) }) @@ -1771,7 +1772,7 @@ func TestDoPostActionWithCookieEdgeCases(t *testing.T) { }, } - _, err := th.App.DoPostActionWithCookie(th.Context, "actual_post_id", "action_id", th.BasicUser.Id, "", cookie) + _, err := th.App.DoPostActionWithCookie(th.Context, "actual_post_id", "action_id", th.BasicUser.Id, "", cookie, nil) require.NotNil(t, err) assert.Contains(t, err.Error(), "postId doesn't match") }) @@ -1784,7 +1785,7 @@ func TestDoPostActionWithCookieEdgeCases(t *testing.T) { Integration: nil, } - _, err := th.App.DoPostActionWithCookie(th.Context, "nonexistent_post_id", "action_id", th.BasicUser.Id, "", cookie) + _, err := th.App.DoPostActionWithCookie(th.Context, "nonexistent_post_id", "action_id", th.BasicUser.Id, "", cookie, nil) require.NotNil(t, err) assert.Contains(t, err.Error(), "no Integration in action cookie") }) @@ -1805,10 +1806,129 @@ func TestDoPostActionWithCookieEdgeCases(t *testing.T) { }, } - _, err := th.App.DoPostActionWithCookie(th.Context, "nonexistent_post_id", "action_id", "nonexistent_user_id", "", cookie) + _, err := th.App.DoPostActionWithCookie(th.Context, "nonexistent_post_id", "action_id", "nonexistent_user_id", "", cookie, nil) require.NotNil(t, err) assert.Contains(t, err.Error(), "Unable to find the user.") }) + + t.Run("rejects oversized query at the App boundary (independent of API handler)", func(t *testing.T) { + // ValidateActionQuery is called at the top of DoPostActionWithCookie, + // not just in the API handler. Direct App-layer callers (plugins, + // tests, internal triggers) get the same enforcement as REST clients. + oversized := make(map[string]string, model.MaxActionQueryEntries+1) + for i := range model.MaxActionQueryEntries + 1 { + oversized["k"+strconv.Itoa(i)] = "v" + } + + _, err := th.App.DoPostActionWithCookie(th.Context, "any_post", "any_action", th.BasicUser.Id, "", nil, oversized) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + assert.Equal(t, "api.post.do_action.query.app_error", err.Id) + }) +} + +// TestCloneMmBlocksActionsProp guards the deep-clone semantics used when +// restoring an original spec after a plugin update response is rejected. +// A shallow clone would alias the nested per-action map back into post.Props, +// so a later mutation through response.Update could reach into the live post. +func TestCloneMmBlocksActionsProp(t *testing.T) { + t.Run("nil and non-map values are returned unchanged", func(t *testing.T) { + assert.Nil(t, cloneMmBlocksActionsProp(nil)) + assert.Equal(t, "string", cloneMmBlocksActionsProp("string")) + }) + + t.Run("top-level and nested mutations on the clone do not leak", func(t *testing.T) { + original := map[string]any{ + "btn1": map[string]any{ + "type": "external", + "url": "http://example.com/hook", + }, + } + + cloned, ok := cloneMmBlocksActionsProp(original).(map[string]any) + require.True(t, ok) + + // Mutating the top-level map on the clone (adding a key) must not + // reach the original. + cloned["btn2"] = map[string]any{"type": "external", "url": "http://example.com/other"} + assert.NotContains(t, original, "btn2") + + // Mutating a nested per-action map on the clone (changing the URL) + // must not reach the original — this is the case the shallow-clone + // bug actually exposed. + clonedEntry, ok := cloned["btn1"].(map[string]any) + require.True(t, ok) + clonedEntry["url"] = "http://attacker.example/" + + originalEntry, ok := original["btn1"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "http://example.com/hook", originalEntry["url"]) + }) + + t.Run("deeply nested context and array mutations on the clone do not leak", func(t *testing.T) { + // Per-action specs can carry nested context maps and arrays. A + // shallow per-entry clone would still alias these structures back + // to the live post's props. + original := map[string]any{ + "btn1": map[string]any{ + "type": "external", + "url": "http://example.com/hook", + "context": map[string]any{"team": "alpha", "tags": []any{"a", "b"}}, + }, + } + + cloned, ok := cloneMmBlocksActionsProp(original).(map[string]any) + require.True(t, ok) + + clonedEntry := cloned["btn1"].(map[string]any) + clonedContext := clonedEntry["context"].(map[string]any) + + // Mutate the nested context map on the clone. + clonedContext["team"] = "tampered" + clonedContext["new"] = "added" + + // Mutate the nested array on the clone. + clonedTags := clonedContext["tags"].([]any) + clonedTags[0] = "tampered" + + // Original must be untouched at every level. + originalEntry := original["btn1"].(map[string]any) + originalContext := originalEntry["context"].(map[string]any) + assert.Equal(t, "alpha", originalContext["team"]) + assert.NotContains(t, originalContext, "new") + assert.Equal(t, []any{"a", "b"}, originalContext["tags"]) + }) + + t.Run("pathologically nested input is truncated past maxMmBlocksActionsCloneDepth", func(t *testing.T) { + // ValidateMmBlocksActions doesn't bound nesting depth inside + // spec.Context — defense-in-depth against stack exhaustion if a + // bot/plugin author crafts deeply nested input. + var leaf any = "leaf" + const tooDeep = maxMmBlocksActionsCloneDepth + 100 + for range tooDeep { + leaf = map[string]any{"n": leaf} + } + + // Must not stack-overflow / panic. + var cloned any + require.NotPanics(t, func() { + cloned = cloneMmBlocksActionsProp(leaf) + }) + + // Walk the clone; should hit nil before reaching the leaf string. + current := cloned + for i := range tooDeep { + m, ok := current.(map[string]any) + if !ok { + assert.Greater(t, i, maxMmBlocksActionsCloneDepth-2, + "truncation should kick in at or near maxMmBlocksActionsCloneDepth") + assert.Nil(t, current, "subtree past depth cap must be nil, not aliased to source") + return + } + current = m["n"] + } + t.Fatalf("clone walked %d levels without hitting truncation", tooDeep) + }) } func TestDoPluginRequest(t *testing.T) { @@ -2002,3 +2122,859 @@ func TestDoPluginRequest(t *testing.T) { } }) } + +// buildMmBlocksActionsProp returns a mm_blocks_actions map (an "external"-type +// action) suitable for use as a post prop in tests. +func buildMmBlocksActionsProp(id, url string, context map[string]any) map[string]any { + entry := map[string]any{ + "type": model.MmBlocksActionTypeExternal, + "url": url, + } + if context != nil { + entry["context"] = context + } + return map[string]any{id: entry} +} + +// setupBotInChannel creates a bot, joins it to the team and channel, and +// returns the resolved *model.User for the bot. +func setupBotInChannel(t *testing.T, th *TestHelper) *model.User { + t.Helper() + bot := th.CreateBot(t) + botUser, appErr := th.App.GetUser(bot.UserId) + require.Nil(t, appErr) + _, _, appErr = th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, botUser.Id, "") + require.Nil(t, appErr) + _, appErr = th.App.AddUserToChannel(th.Context, botUser, th.BasicChannel, false) + require.Nil(t, appErr) + return botUser +} + +func TestMmBlocksActionsStrippedOnCreate(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + post := &model.Post{ + Message: "hello with inline actions", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: th.BasicUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: buildMmBlocksActionsProp( + "actionone", + "http://127.0.0.1/plugins/myplugin/doit", + map[string]any{"operation": "STORM"}, + ), + }, + } + + created, _, err := th.App.CreatePostAsUser(th.Context, post, "", true) + require.Nil(t, err) + assert.Nil(t, created.GetProp(model.PostPropsMmBlocksActions), "non-bot, non-integration user should have mm_blocks_actions stripped") + + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + assert.Nil(t, stored.GetProp(model.PostPropsMmBlocksActions), "stored post should not carry mm_blocks_actions") +} + +func TestMmBlocksActionsKeptForBotIntegration(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + + // IsOAuth=true makes Session.IsIntegration() return true without needing + // a full bot-token session. + intSession := &model.Session{UserId: botUser.Id, IsOAuth: true} + intCtx := th.Context.WithSession(intSession) + + post := &model.Post{ + Message: "hello from a bot", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: buildMmBlocksActionsProp( + "actiontwo", + "http://127.0.0.1/plugins/myplugin/doit", + map[string]any{"operation": "STORM"}, + ), + }, + } + + created, _, err := th.App.CreatePostAsUser(intCtx, post, "", true) + require.Nil(t, err) + require.NotNil(t, created.GetProp(model.PostPropsMmBlocksActions), "bot post via integration session should preserve mm_blocks_actions") + + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + require.NotNil(t, stored.GetProp(model.PostPropsMmBlocksActions), "stored bot post should carry mm_blocks_actions") + + spec := stored.GetMmBlocksActionSpec("actiontwo") + require.NotNil(t, spec) + assert.Equal(t, "http://127.0.0.1/plugins/myplugin/doit", spec.URL) +} + +// TestPluginAPICreatePostKeepsMmBlocksActions locks the contract that a +// plugin creating a post via PluginAPI.CreatePost retains mm_blocks_actions. +// Plugins are server-trusted code, but their static activation-time rctx +// has an unmarked session — without pluginIntegrationCtx the strip in +// CreatePost would delete the prop and clicks would 404 with +// "invalid action id". +func TestPluginAPICreatePostKeepsMmBlocksActions(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + botUser := setupBotInChannel(t, th) + + manifest := &model.Manifest{Id: "com.mattermost.test-plugin"} + api := NewPluginAPI(th.App, th.Context, manifest) + + post := &model.Post{ + ChannelId: th.BasicChannel.Id, + UserId: botUser.Id, + Message: "issue tracker post", + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: buildMmBlocksActionsProp( + "triage", + "/plugins/com.mattermost.test-plugin/inline_action/triage", + map[string]any{"project": "Demo Project"}, + ), + }, + } + + created, appErr := api.CreatePost(post) + require.Nil(t, appErr) + require.NotNil(t, created.GetProp(model.PostPropsMmBlocksActions), + "plugin-created post must preserve mm_blocks_actions; the strip in CreatePost should not fire because PluginAPI marks the session as integration") + + // Re-read from the store to confirm persistence (not just in-memory). + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + spec := stored.GetMmBlocksActionSpec("triage") + require.NotNil(t, spec, "stored plugin post must resolve the action spec at click time") + assert.Equal(t, "/plugins/com.mattermost.test-plugin/inline_action/triage", spec.URL) +} + +// TestMmBlocksActionsKeptForWebhookImpersonation verifies that an integration +// session is sufficient on its own — the post's author does not need to be a +// bot. This is the webhook-impersonation flow: a webhook posts as a regular +// user with from_webhook=true, and we must not strip the prop just because +// user.IsBot is false. +func TestMmBlocksActionsKeptForWebhookImpersonation(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + // Integration session for a regular (non-bot) user. + intSession := &model.Session{UserId: th.BasicUser.Id, IsOAuth: true} + intCtx := th.Context.WithSession(intSession) + + post := &model.Post{ + Message: "post from impersonating webhook", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: th.BasicUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: buildMmBlocksActionsProp( + "webhook1", + "http://127.0.0.1/plugins/myplugin/wh", + nil, + ), + }, + } + + created, _, err := th.App.CreatePostAsUser(intCtx, post, "", true) + require.Nil(t, err) + require.NotNil(t, created.GetProp(model.PostPropsMmBlocksActions), + "non-bot author via integration session must preserve mm_blocks_actions (webhook flow)") + + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + require.NotNil(t, stored.GetProp(model.PostPropsMmBlocksActions)) +} + +// TestMmBlocksActionsStripGate locks the create-time strip policy: keep +// when the post is bot-authored OR the session is an integration; strip +// when neither signal is present. The bot-author signal covers +// PluginAPI.CreatePost (whose static rctx is unmarked) where the post is +// authored by the plugin's bot user; the integration-session signal +// covers REST callers using bot tokens, PATs, or OAuth apps. +func TestMmBlocksActionsStripGate(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + + inline := buildMmBlocksActionsProp( + "mx", + "http://127.0.0.1/plugins/myplugin/mx", + nil, + ) + + t.Run("bot author via non-integration session is kept", func(t *testing.T) { + // Models the PluginAPI.CreatePost path: post.UserId is the plugin's + // bot user but rctx.Session() is the unmarked plugin context. The + // bot-author signal alone must be sufficient to keep the prop. + post := &model.Post{ + Message: "hello", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{model.PostPropsMmBlocksActions: inline}, + } + created, _, err := th.App.CreatePostAsUser(th.Context, post, "", true) + require.Nil(t, err) + assert.NotNil(t, created.GetProp(model.PostPropsMmBlocksActions), + "bot-authored post must keep mm_blocks_actions even without an integration session") + }) + + t.Run("regular user via non-integration session is stripped", func(t *testing.T) { + // Neither signal present: the prop must be removed. Catches the + // baseline user-content case. + post := &model.Post{ + Message: "hello", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: th.BasicUser.Id, + Props: model.StringInterface{model.PostPropsMmBlocksActions: inline}, + } + created, _, err := th.App.CreatePostAsUser(th.Context, post, "", true) + require.Nil(t, err) + assert.Nil(t, created.GetProp(model.PostPropsMmBlocksActions), + "regular-user post via non-integration session must strip mm_blocks_actions") + }) +} + +func TestUpdatePostMmBlocksActionsGuard(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + + // Bot posts with mm_blocks_actions must be CREATED via an integration + // session — see the matching create-time strip in CreatePostAsUser. + intSeedSession := &model.Session{UserId: botUser.Id, IsOAuth: true} + intSeedCtx := th.Context.WithSession(intSeedSession) + + // originalInline is the mm_blocks_actions value we expect the bot post to + // keep after non-integration edits. + originalInline := buildMmBlocksActionsProp( + "keep", + "http://127.0.0.1/plugins/myplugin/original", + map[string]any{"k": "orig"}, + ) + + t.Run("non-integration edit of bot post reverts mm_blocks_actions", func(t *testing.T) { + botPost := &model.Post{ + Message: "bot post with inline actions", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: originalInline, + }, + } + created, _, cErr := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, cErr) + require.NotNil(t, created.GetProp(model.PostPropsMmBlocksActions)) + + // A non-integration session tries to swap mm_blocks_actions wholesale. + newInline := buildMmBlocksActionsProp( + "swap", + "http://127.0.0.1/plugins/myplugin/swapped", + map[string]any{"k": "attacker"}, + ) + edit := created.Clone() + edit.Message = "edited message" + edit.AddProp(model.PostPropsMmBlocksActions, newInline) + + // th.Context has an empty/zero session — not an integration. + updated, _, uErr := th.App.UpdatePost(th.Context, edit, &model.UpdatePostOptions{SafeUpdate: false}) + require.Nil(t, uErr) + + // mm_blocks_actions should revert to the original value. + got := updated.GetMmBlocksActionSpec("keep") + require.NotNil(t, got, "original inline action should still be reachable") + assert.Equal(t, "http://127.0.0.1/plugins/myplugin/original", got.URL) + + // The attacker's swapped action should not be present. + assert.Nil(t, updated.GetMmBlocksActionSpec("swap")) + + // Message change should still be applied. + assert.Equal(t, "edited message", updated.Message) + }) + + t.Run("non-integration edit cannot add mm_blocks_actions when original had none", func(t *testing.T) { + plainBotPost := &model.Post{ + Message: "bot post without inline actions", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + } + created, _, cErr := th.App.CreatePostAsUser(intSeedCtx, plainBotPost, "", true) + require.Nil(t, cErr) + require.Nil(t, created.GetProp(model.PostPropsMmBlocksActions)) + + newInline := buildMmBlocksActionsProp( + "added", + "http://127.0.0.1/plugins/myplugin/added", + nil, + ) + edit := created.Clone() + edit.AddProp(model.PostPropsMmBlocksActions, newInline) + + updated, _, uErr := th.App.UpdatePost(th.Context, edit, &model.UpdatePostOptions{SafeUpdate: false}) + require.Nil(t, uErr) + assert.Nil(t, updated.GetProp(model.PostPropsMmBlocksActions), "non-integration update must not introduce mm_blocks_actions") + }) + + t.Run("integration session alone cannot modify mm_blocks_actions", func(t *testing.T) { + // Even with an integration session (PAT / OAuth / bot-token), the + // UpdatePost path requires AllowMmBlocksActionsUpdate to modify + // mm_blocks_actions. A PAT-holding user could otherwise inject + // mm_blocks_actions on any post they can edit. + botPost := &model.Post{ + Message: "bot post for integration edit", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: originalInline, + }, + } + created, _, cErr := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, cErr) + + intSession := &model.Session{UserId: th.BasicUser.Id, IsOAuth: true} + intCtx := th.Context.WithSession(intSession) + require.True(t, intCtx.Session().IsIntegration()) + + newInline := buildMmBlocksActionsProp( + "replaced", + "http://127.0.0.1/plugins/myplugin/new", + map[string]any{"k": "integration"}, + ) + edit := created.Clone() + edit.AddProp(model.PostPropsMmBlocksActions, newInline) + + updated, _, uErr := th.App.UpdatePost(intCtx, edit, &model.UpdatePostOptions{SafeUpdate: false}) + require.Nil(t, uErr) + + // The attacker's "replaced" entry must not land; the original stays. + assert.Nil(t, updated.GetMmBlocksActionSpec("replaced"), "integration session alone must not overwrite mm_blocks_actions") + keep := updated.GetMmBlocksActionSpec("keep") + require.NotNil(t, keep, "original inline action must be preserved") + assert.Equal(t, "http://127.0.0.1/plugins/myplugin/original", keep.URL) + }) + + t.Run("AllowMmBlocksActionsUpdate option accepts new mm_blocks_actions", func(t *testing.T) { + botPost := &model.Post{ + Message: "bot post for plugin-path edit", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: originalInline, + }, + } + created, _, cErr := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, cErr) + + newInline := buildMmBlocksActionsProp( + "plugin", + "http://127.0.0.1/plugins/myplugin/plugin", + map[string]any{"k": "plugin"}, + ) + edit := created.Clone() + edit.AddProp(model.PostPropsMmBlocksActions, newInline) + + // Non-integration session, but AllowMmBlocksActionsUpdate grants write. + updated, _, uErr := th.App.UpdatePost(th.Context, edit, &model.UpdatePostOptions{SafeUpdate: false, AllowMmBlocksActionsUpdate: true}) + require.Nil(t, uErr) + + assert.Nil(t, updated.GetMmBlocksActionSpec("keep")) + integration := updated.GetMmBlocksActionSpec("plugin") + require.NotNil(t, integration) + assert.Equal(t, "http://127.0.0.1/plugins/myplugin/plugin", integration.URL) + }) +} + +// TestCreateWebhookPostStripsMmBlocksActions locks the contract that an +// incoming webhook cannot persist mm_blocks_actions even if the payload +// includes the prop in its `props` map. CreateWebhookPost's prop iteration +// has no explicit blocklist entry for mm_blocks_actions; it falls through +// to AddProp and would land on the post object. The strip in CreatePost +// (post.go) then fires because the webhook flow has no integration session +// (incomingWebhook is registered with RequireSession: false). If a future +// refactor changes the webhook session model, this test catches it. +func TestCreateWebhookPostStripsMmBlocksActions(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.EnableIncomingWebhooks = true }) + + hook, hookErr := th.App.CreateIncomingWebhookForChannel(th.BasicUser.Id, th.BasicChannel, &model.IncomingWebhook{ChannelId: th.BasicChannel.Id}) + require.Nil(t, hookErr) + defer func() { + _ = th.App.DeleteIncomingWebhook(hook.Id) + }() + + inline := buildMmBlocksActionsProp( + "actx", + "http://127.0.0.1/plugins/myplugin/x", + nil, + ) + + post, appErr := th.App.CreateWebhookPost(th.Context, hook.UserId, th.BasicChannel, "hello", "user", "http://iconurl", "", + model.StringInterface{ + model.PostPropsMmBlocksActions: inline, + }, + "", "", nil) + require.Nil(t, appErr) + + assert.Nil(t, post.GetProp(model.PostPropsMmBlocksActions), + "incoming webhook payload must not be able to persist mm_blocks_actions; the strip in CreatePost should fire because the webhook session has IsIntegration()==false") + + // Belt and suspenders: read back from the DB to confirm the prop is + // not persisted either. + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, post.Id, false) + require.NoError(t, nErr) + assert.Nil(t, stored.GetProp(model.PostPropsMmBlocksActions), + "stored webhook post must not carry mm_blocks_actions") +} + +func TestSendEphemeralPostStripsMmBlocksActions(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + ephemeral := &model.Post{ + ChannelId: th.BasicChannel.Id, + UserId: th.BasicUser.Id, + Message: "ephemeral with inline actions", + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: buildMmBlocksActionsProp( + "eph", + "http://127.0.0.1/plugins/myplugin/eph", + map[string]any{"k": "v"}, + ), + }, + } + + result, _ := th.App.SendEphemeralPost(th.Context, th.BasicUser.Id, ephemeral) + require.NotNil(t, result) + assert.Nil(t, result.GetProp(model.PostPropsMmBlocksActions), "SendEphemeralPost must drop mm_blocks_actions") + + // UpdateEphemeralPost path + ephemeral2 := &model.Post{ + Id: result.Id, + ChannelId: th.BasicChannel.Id, + UserId: th.BasicUser.Id, + Message: "updated ephemeral with inline actions", + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: buildMmBlocksActionsProp( + "eph2", + "http://127.0.0.1/plugins/myplugin/eph2", + nil, + ), + }, + } + updated, _ := th.App.UpdateEphemeralPost(th.Context, th.BasicUser.Id, ephemeral2) + require.NotNil(t, updated) + assert.Nil(t, updated.GetProp(model.PostPropsMmBlocksActions), "UpdateEphemeralPost must drop mm_blocks_actions") +} + +func TestDoPostActionQueryMergedIntoURL(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + intSeedCtx := th.Context.WithSession(&model.Session{UserId: botUser.Id, IsOAuth: true}) + + // Capture both the upstream integration request body and the URL the + // server saw, so we can assert that per-click query lands in the URL + // (mm_blocks transport) and not in the upstream Context map. + var ( + capturedReq model.PostActionIntegrationRequest + capturedRawQuery string + ) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedRawQuery = r.URL.RawQuery + body, readErr := io.ReadAll(r.Body) + require.NoError(t, readErr) + require.NoError(t, json.Unmarshal(body, &capturedReq)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + inlineActions := buildMmBlocksActionsProp( + "inline1", + ts.URL, + map[string]any{"operation": "STORM"}, + ) + botPost := &model.Post{ + Message: "mm_blocks action post", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: inlineActions, + }, + } + created, _, err := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, err) + require.NotNil(t, created.GetProp(model.PostPropsMmBlocksActions)) + + query := map[string]string{"tail": "214"} + _, err = th.App.DoPostActionWithCookie(th.Context, created.Id, "inline1", th.BasicUser.Id, "", nil, query) + require.Nil(t, err) + + // Query was appended to the upstream URL. + parsedQuery, qErr := url.ParseQuery(capturedRawQuery) + require.NoError(t, qErr) + assert.Equal(t, "214", parsedQuery.Get("tail"), "per-click query should land in the upstream URL") + + // Original action Context is forwarded as the upstream request's + // Context, untouched by the query merge. + assert.Equal(t, "STORM", capturedReq.Context["operation"]) + _, leakedInlineParams := capturedReq.Context["inline_params"] + assert.False(t, leakedInlineParams, "query must not be injected into upstream Context") +} + +func TestDoPostActionStaticQueryMergedWithPerClickQuery(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + intSeedCtx := th.Context.WithSession(&model.Session{UserId: botUser.Id, IsOAuth: true}) + + var capturedRawQuery string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedRawQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + // Spec carries a static query (source=fleet) AND a key (tail=999) that + // the per-click query will override. Per-click should win. + botPost := &model.Post{ + Message: "mm_blocks action post with static query", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: map[string]any{ + "inline1": map[string]any{ + "type": model.MmBlocksActionTypeExternal, + "url": ts.URL, + "query": map[string]any{"source": "fleet", "tail": "999"}, + }, + }, + }, + } + created, _, err := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, err) + + _, err = th.App.DoPostActionWithCookie(th.Context, created.Id, "inline1", th.BasicUser.Id, "", nil, map[string]string{"tail": "214"}) + require.Nil(t, err) + + parsedQuery, qErr := url.ParseQuery(capturedRawQuery) + require.NoError(t, qErr) + assert.Equal(t, "fleet", parsedQuery.Get("source"), "spec static query should land in the upstream URL") + assert.Equal(t, "214", parsedQuery.Get("tail"), "per-click query should override spec static query on overlapping keys") +} + +func TestDoPostActionContextMapNotMutated(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + intSeedCtx := th.Context.WithSession(&model.Session{UserId: botUser.Id, IsOAuth: true}) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + originalContext := map[string]any{"operation": "STORM"} + inlineActions := buildMmBlocksActionsProp("inline1", ts.URL, originalContext) + botPost := &model.Post{ + Message: "mm_blocks action post", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsMmBlocksActions: inlineActions, + }, + } + created, _, err := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, err) + + // First click: carries one set of per-click query values. + _, err = th.App.DoPostActionWithCookie(th.Context, created.Id, "inline1", th.BasicUser.Id, "", nil, map[string]string{"tail": "214"}) + require.Nil(t, err) + + // Post's stored mm_blocks_actions Context must not be mutated by the click. + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + spec := stored.GetMmBlocksActionSpec("inline1") + require.NotNil(t, spec) + assert.Equal(t, "STORM", spec.Context["operation"]) + assert.Equal(t, ts.URL, spec.URL, "stored URL must not absorb per-click query") + + // Second click with a different per-click query. + _, err = th.App.DoPostActionWithCookie(th.Context, created.Id, "inline1", th.BasicUser.Id, "", nil, map[string]string{"tail": "999"}) + require.Nil(t, err) + + stored, nErr = th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + spec = stored.GetMmBlocksActionSpec("inline1") + require.NotNil(t, spec) + assert.Equal(t, "STORM", spec.Context["operation"]) + assert.Equal(t, ts.URL, spec.URL, "stored URL must not absorb per-click query") +} + +func TestDoPostActionPluginResponseMmBlocksActionsDropped(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + + // Plugin returns an update that tries to add mm_blocks_actions, even + // though the original post had none. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + resp := `{ + "update": { + "message": "updated message", + "props": { + "mm_blocks_actions": { + "sneaky": {"type": "external", "url": "http://127.0.0.1/plugins/myplugin/sneak"} + } + } + } + }` + _, _ = w.Write([]byte(resp)) + })) + defer ts.Close() + + // Bot post has an ATTACHMENT action (not an mm_blocks action), and no + // mm_blocks_actions prop. The plugin's response to clicking the + // attachment should not be able to introduce mm_blocks_actions. + botPost := &model.Post{ + Message: "attachment-only bot post", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsAttachments: []*model.MessageAttachment{ + { + Text: "hello", + Actions: []*model.PostAction{ + { + Type: model.PostActionTypeButton, + Name: "click", + Integration: &model.PostActionIntegration{ + URL: ts.URL, + }, + }, + }, + }, + }, + }, + } + created, _, err := th.App.CreatePostAsUser(th.Context, botPost, "", true) + require.Nil(t, err) + attachments, ok := created.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) + require.True(t, ok) + require.NotEmpty(t, attachments[0].Actions) + require.NotEmpty(t, attachments[0].Actions[0].Id) + require.Nil(t, created.GetProp(model.PostPropsMmBlocksActions)) + + _, err = th.App.DoPostActionWithCookie(th.Context, created.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) + require.Nil(t, err) + + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + assert.Nil(t, stored.GetProp(model.PostPropsMmBlocksActions), "plugin response must not be able to add mm_blocks_actions where none existed") + assert.Equal(t, "updated message", stored.Message) +} + +func TestDoPostActionPluginResponseInvalidMmBlocksActionsRestored(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + botUser := setupBotInChannel(t, th) + intSeedCtx := th.Context.WithSession(&model.Session{UserId: botUser.Id, IsOAuth: true}) + + // Plugin returns an update where mm_blocks_actions contains an entry + // with an empty URL — invalid; the original prop should be restored + // with a warning, while the message update still succeeds. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + resp := `{ + "update": { + "message": "updated via plugin", + "props": { + "mm_blocks_actions": { + "broken": {"type": "external", "url": ""} + } + } + } + }` + _, _ = w.Write([]byte(resp)) + })) + defer ts.Close() + + // The original post has VALID mm_blocks_actions, so the "drop because + // original had none" branch is bypassed and we exercise the validation + // branch. + originalInline := buildMmBlocksActionsProp( + "orig", + "http://127.0.0.1/plugins/myplugin/orig", + nil, + ) + botPost := &model.Post{ + Message: "bot post with valid inline actions", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: botUser.Id, + Props: model.StringInterface{ + model.PostPropsAttachments: []*model.MessageAttachment{ + { + Text: "hello", + Actions: []*model.PostAction{ + { + Type: model.PostActionTypeButton, + Name: "click", + Integration: &model.PostActionIntegration{ + URL: ts.URL, + }, + }, + }, + }, + }, + model.PostPropsMmBlocksActions: originalInline, + }, + } + created, _, err := th.App.CreatePostAsUser(intSeedCtx, botPost, "", true) + require.Nil(t, err) + attachments, ok := created.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) + require.True(t, ok) + require.NotEmpty(t, attachments[0].Actions) + require.NotEmpty(t, attachments[0].Actions[0].Id) + + _, err = th.App.DoPostActionWithCookie(th.Context, created.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) + require.Nil(t, err) + + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, created.Id, false) + require.NoError(t, nErr) + // Message update still applied — the invalid mm_blocks_actions were + // restored to the original value with a warning, so the rest of the + // response.Update is persisted. + assert.Equal(t, "updated via plugin", stored.Message) + // The broken action from the plugin response must never be stored. + assert.Nil(t, stored.GetMmBlocksActionSpec("broken"), "invalid mm_blocks action from plugin response must not be persisted") + // The original valid mm_blocks_actions must survive — an invalid plugin + // response must never wipe a post's existing buttons. + require.NotNil(t, stored.GetMmBlocksActionSpec("orig"), "original valid mm_blocks action must be preserved when plugin response is invalid") + assert.Equal(t, "http://127.0.0.1/plugins/myplugin/orig", stored.GetMmBlocksActionSpec("orig").URL) +} + +// TestPostActionRetainsFromBotAndFromPlugin verifies that from_bot and +// from_plugin props are retained across a plugin-returned post update even +// when the plugin's response.Props omits them. This matters because the +// webapp's allowInlineActions gate is derived from these markers; losing +// them on first update would hide every inline button on subsequent renders. +func TestPostActionRetainsFromBotAndFromPlugin(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowedUntrustedInternalConnections = "localhost,127.0.0.1" + }) + + // Plugin response deliberately omits from_bot / from_plugin from props. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"update": {"message": "updated", "props": {"A": "AA"}}}`) + })) + defer ts.Close() + + interactivePost := model.Post{ + Message: "interactive", + ChannelId: th.BasicChannel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: th.BasicUser.Id, + Props: model.StringInterface{ + model.PostPropsAttachments: []*model.MessageAttachment{{ + Text: "hello", + Actions: []*model.PostAction{{ + Type: model.PostActionTypeButton, + Name: "click", + Integration: &model.PostActionIntegration{ + URL: ts.URL, + }, + }}, + }}, + model.PostPropsFromBot: "true", + model.PostPropsFromPlugin: "true", + }, + } + + post, _, appErr := th.App.CreatePostAsUser(th.Context, &interactivePost, "", true) + require.Nil(t, appErr) + attachments, ok := post.GetProp(model.PostPropsAttachments).([]*model.MessageAttachment) + require.True(t, ok) + + _, appErr = th.App.DoPostActionWithCookie(th.Context, post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id, "", nil, nil) + require.Nil(t, appErr) + + stored, nErr := th.App.Srv().Store().Post().GetSingle(th.Context, post.Id, false) + require.NoError(t, nErr) + + assert.Equal(t, "true", stored.GetProp(model.PostPropsFromBot), "from_bot must be retained across plugin update response") + assert.Equal(t, "true", stored.GetProp(model.PostPropsFromPlugin), "from_plugin must be retained across plugin update response") + assert.Equal(t, "AA", stored.GetProp("A"), "plugin-supplied prop applied") +} diff --git a/server/channels/app/plugin_api.go b/server/channels/app/plugin_api.go index fa6dad2bec4..b57b3a2ef2f 100644 --- a/server/channels/app/plugin_api.go +++ b/server/channels/app/plugin_api.go @@ -874,7 +874,19 @@ func (api *PluginAPI) GetPostsForChannel(channelID string, page, perPage int) (* } func (api *PluginAPI) UpdatePost(post *model.Post) (*model.Post, *model.AppError) { - post, _, appErr := api.app.UpdatePost(api.ctx, post, &model.UpdatePostOptions{SafeUpdate: false}) + // Grant mm_blocks_actions write access only when the plugin's update + // actually includes the prop, AND the value passes validation. + // Otherwise the freeze in UpdatePost preserves whatever the original + // post had — plugins that update unrelated fields don't accidentally + // drop or corrupt mm_blocks_actions. + allowMmBlocksActionsUpdate := false + if post.GetProp(model.PostPropsMmBlocksActions) != nil { + if err := model.ValidateMmBlocksActions(post); err != nil { + return nil, model.NewAppError("UpdatePost", "plugin.api.update_post.mm_blocks_actions.app_error", nil, "", http.StatusBadRequest).Wrap(err) + } + allowMmBlocksActionsUpdate = true + } + post, _, appErr := api.app.UpdatePost(api.ctx, post, &model.UpdatePostOptions{SafeUpdate: false, AllowMmBlocksActionsUpdate: allowMmBlocksActionsUpdate}) if post != nil { post = post.ForPlugin() } diff --git a/server/channels/app/post.go b/server/channels/app/post.go index ff61bd923c3..71ede54b87a 100644 --- a/server/channels/app/post.go +++ b/server/channels/app/post.go @@ -255,6 +255,24 @@ func (a *App) CreatePost(rctx request.CTX, post *model.Post, channel *model.Chan post.AddProp(model.PostPropsFromOAuthApp, "true") } + // Strip mm_blocks_actions from posts that are neither bot-authored nor + // created via an integration session. Either signal is sufficient: + // - user.IsBot (DB-verified) covers PluginAPI.CreatePost where the + // plugin's static rctx has no integration markers but the post + // is authored by a bot user. + // - rctx.Session().IsIntegration() (server-derived, unspoofable) + // covers REST callers using bot tokens, PATs, or OAuth apps. + // + // Webhooks are handled separately at their entry point + // (CreateWebhookPost) — webhook payloads are user-controlled even + // when bound to a bot user, so the prop is dropped before the post + // reaches CreatePost. See TestCreateWebhookPostStripsMmBlocksActions. + if post.GetProp(model.PostPropsMmBlocksActions) != nil { + if !user.IsBot && !rctx.Session().IsIntegration() { + post.DelProp(model.PostPropsMmBlocksActions) + } + } + var ephemeralPost *model.Post if post.Type == "" { if hasPermission, _ := a.HasPermissionToChannel(rctx, user.Id, channel.Id, model.PermissionUseChannelMentions); !hasPermission { @@ -710,6 +728,13 @@ func (a *App) SendEphemeralPost(rctx request.CTX, userID string, post *model.Pos post.SetProps(make(model.StringInterface)) } + // mm_blocks_actions cannot be resolved on click for ephemeral posts (no + // DB row, no per-action cookie transport). Drop the prop here so the + // client doesn't render a non-functional button. + if post.GetProp(model.PostPropsMmBlocksActions) != nil { + post.DelProp(model.PostPropsMmBlocksActions) + } + post.GenerateActionIds() message := model.NewWebSocketEvent(model.WebsocketEventEphemeralMessage, "", post.ChannelId, userID, nil, "") post = a.PreparePostForClientWithEmbedsAndImages(rctx, post, &model.PreparePostForClientOpts{IsNewPost: true, IncludePriority: true}) @@ -744,6 +769,13 @@ func (a *App) UpdateEphemeralPost(rctx request.CTX, userID string, post *model.P post.SetProps(make(model.StringInterface)) } + // mm_blocks_actions cannot be resolved on click for ephemeral posts (no + // DB row, no per-action cookie transport). Drop the prop here so the + // client doesn't render a non-functional button. + if post.GetProp(model.PostPropsMmBlocksActions) != nil { + post.DelProp(model.PostPropsMmBlocksActions) + } + post.GenerateActionIds() message := model.NewWebSocketEvent(model.WebsocketEventPostEdited, "", post.ChannelId, userID, nil, "") post = a.PreparePostForClientWithEmbedsAndImages(rctx, post, &model.PreparePostForClientOpts{IsNewPost: true, IncludePriority: true}) @@ -862,6 +894,21 @@ func (a *App) UpdatePost(rctx request.CTX, receivedUpdatedPost *model.Post, upda newPost.HasReactions = receivedUpdatedPost.HasReactions newPost.SetProps(receivedUpdatedPost.GetProps()) + // mm_blocks_actions can only be modified by trusted paths that have + // pre-validated the new value (AllowMmBlocksActionsUpdate). Session + // type is intentionally not a sufficient signal: a PAT/OAuth session + // from a regular user would otherwise bypass the freeze and inject + // mm_blocks_actions on edit, since from_bot on the original post is + // user-forgeable. All other callers keep whatever mm_blocks_actions + // the original post had (or none). + if !updatePostOptions.AllowMmBlocksActionsUpdate { + if oldVal, ok := oldPost.GetProps()[model.PostPropsMmBlocksActions]; ok { + newPost.AddProp(model.PostPropsMmBlocksActions, oldVal) + } else { + newPost.DelProp(model.PostPropsMmBlocksActions) + } + } + var fileIds []string fileIds, appErr = a.processPostFileChanges(rctx, receivedUpdatedPost, oldPost, updatePostOptions) if appErr != nil { diff --git a/server/channels/app/webhook.go b/server/channels/app/webhook.go index 92775e70d50..68f5655c04e 100644 --- a/server/channels/app/webhook.go +++ b/server/channels/app/webhook.go @@ -367,6 +367,12 @@ func (a *App) CreateWebhookPost(rctx request.CTX, userID string, channel *model. model.PostPropsOverrideUsername, model.PostPropsFromWebhook: // Do nothing + case model.PostPropsMmBlocksActions: + // Webhook payloads are user-controlled even when the + // webhook is bound to a bot user, so the bot-author + // signal in CreatePost's strip rule cannot distinguish + // them. Drop here so mm_blocks_actions never reaches + // the post object. default: post.AddProp(key, val) } diff --git a/server/i18n/en.json b/server/i18n/en.json index 1f58855805c..862b1b859ab 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -2940,6 +2940,14 @@ "id": "api.post.do_action.action_integration.app_error", "translation": "Action integration error." }, + { + "id": "api.post.do_action.merge_query.app_error", + "translation": "Failed to merge query into action URL." + }, + { + "id": "api.post.do_action.query.app_error", + "translation": "Invalid action query." + }, { "id": "api.post.error_get_post_id.pending", "translation": "Unable to get the pending post." @@ -12866,6 +12874,10 @@ "id": "plugin.api.get_users_in_channel", "translation": "Unable to get the users, invalid sorting criteria." }, + { + "id": "plugin.api.update_post.mm_blocks_actions.app_error", + "translation": "Invalid mm_blocks_actions in plugin post update." + }, { "id": "plugin.api.update_user_status.bad_status", "translation": "Unable to set the user status. Unknown user status." diff --git a/server/public/model/integration_action.go b/server/public/model/integration_action.go index a8e00646442..03e8875cb8b 100644 --- a/server/public/model/integration_action.go +++ b/server/public/model/integration_action.go @@ -16,7 +16,9 @@ import ( "io" "math/big" "net/http" + "net/url" "reflect" + "regexp" "slices" "strconv" "strings" @@ -55,16 +57,30 @@ var commonDateTimeFormats = []string{ ISODateTimeNoSecondsFormat, // ISO datetime without seconds } -var PostActionRetainPropKeys = []string{PostPropsFromWebhook, PostPropsOverrideUsername, PostPropsOverrideIconURL} +var PostActionRetainPropKeys = []string{ + PostPropsFromWebhook, + PostPropsFromBot, + PostPropsFromPlugin, + PostPropsOverrideUsername, + PostPropsOverrideIconURL, +} type DoPostActionRequest struct { - SelectedOption string `json:"selected_option,omitempty"` - Cookie string `json:"cookie,omitempty"` + SelectedOption string `json:"selected_option,omitempty"` + Cookie string `json:"cookie,omitempty"` + Query map[string]string `json:"query,omitempty"` } const ( PostActionDataSourceUsers = "users" PostActionDataSourceChannels = "channels" + + MaxMmBlocksActionsPerPost = 50 + MaxMmBlocksActionKeyLength = 64 + + MaxActionQueryEntries = 50 + MaxActionQueryKeyLength = 128 + MaxActionQueryValueLength = 2048 ) type PostAction struct { @@ -873,6 +889,7 @@ func (o *Post) StripActionIntegrations() { action.Integration = nil } } + o.StripMmBlocksActionSecrets() } func (o *Post) GetAction(id string) *PostAction { @@ -883,6 +900,137 @@ func (o *Post) GetAction(id string) *PostAction { } } } + if spec := o.GetMmBlocksActionSpec(id); spec != nil && spec.Type == MmBlocksActionTypeExternal && spec.URL != "" { + // Synthesize a PostAction so the existing click pipeline can + // dispatch without branching on action source. Pre-merge the + // spec's static per-action query into the URL here; per-click + // query (from DoPostActionRequest.Query) is merged on top by the + // caller via MergeQueryIntoURL, with per-click overriding static + // values on overlapping keys. + url := spec.URL + if len(spec.Query) > 0 { + merged, err := MergeQueryIntoURL(spec.URL, spec.Query) + if err != nil { + // Spec URL is malformed. ValidateMmBlocksActions + // should have rejected it at save time, so this is a + // belt-and-suspenders guard. Returning nil routes the + // caller through the standard "action not found" + // 404 path rather than firing a request to a URL + // that's missing the static query params. + return nil + } + url = merged + } + return &PostAction{ + Id: id, + Type: PostActionTypeButton, + Integration: &PostActionIntegration{ + URL: url, + Context: spec.Context, + }, + } + } + return nil +} + +var mmBlocksActionIDRegex = regexp.MustCompile(`^[A-Za-z0-9]+$`) + +// ValidateMmBlocksActions verifies the post's mm_blocks_actions prop has the +// expected shape and bounds. Each entry must coerce to a valid spec via +// mmBlocksEntryMapToSpec. +func ValidateMmBlocksActions(o *Post) error { + raw := o.GetProp(PostPropsMmBlocksActions) + if raw == nil { + return nil + } + actions, ok := coerceToStringAnyMap(raw) + if !ok { + return fmt.Errorf("mm_blocks_actions must be a map") + } + if len(actions) > MaxMmBlocksActionsPerPost { + return fmt.Errorf("mm_blocks_actions exceeds maximum of %d entries", MaxMmBlocksActionsPerPost) + } + for key, entry := range actions { + if len(key) > MaxMmBlocksActionKeyLength { + return fmt.Errorf("mm_blocks_actions key exceeds %d chars", MaxMmBlocksActionKeyLength) + } + if !mmBlocksActionIDRegex.MatchString(key) { + return fmt.Errorf("mm_blocks_actions key %q must be alphanumeric", key) + } + entryMap, ok := coerceToStringAnyMap(entry) + if !ok { + return fmt.Errorf("mm_blocks_actions entry %q must be an object", key) + } + spec := mmBlocksEntryMapToSpec(entryMap) + if spec == nil { + return fmt.Errorf("mm_blocks_actions entry %q has invalid type or shape", key) + } + if spec.Type == MmBlocksActionTypeExternal { + if err := validateIntegrationURL(spec.URL); err != nil { + return fmt.Errorf("mm_blocks_actions entry %q: %w", key, err) + } + // Bound the per-spec static query so a bot cannot stash + // unbounded data in the post that gets merged into the + // outgoing URL on every click. + if err := ValidateActionQuery(spec.Query); err != nil { + return fmt.Errorf("mm_blocks_actions entry %q static query: %w", key, err) + } + // Bound entry count and key length on the static context. + // Values are arbitrary JSON, so size is constrained by the + // outer post-size limit; we cap entries to prevent crafted + // posts from inflating GetAction's clone cost. + if len(spec.Context) > MaxActionQueryEntries { + return fmt.Errorf("mm_blocks_actions entry %q context exceeds maximum of %d entries", key, MaxActionQueryEntries) + } + for k := range spec.Context { + if len(k) > MaxActionQueryKeyLength { + return fmt.Errorf("mm_blocks_actions entry %q context key exceeds %d chars", key, MaxActionQueryKeyLength) + } + } + } + } + return nil +} + +// ValidateActionQuery bounds the size of user-supplied per-click query +// parameters so a crafted post cannot trigger unbounded memory use in the +// plugin-request path. +func ValidateActionQuery(q map[string]string) error { + if len(q) > MaxActionQueryEntries { + return fmt.Errorf("query exceeds maximum of %d entries", MaxActionQueryEntries) + } + for key, value := range q { + if len(key) > MaxActionQueryKeyLength { + return fmt.Errorf("query key exceeds %d chars", MaxActionQueryKeyLength) + } + if len(value) > MaxActionQueryValueLength { + return fmt.Errorf("query value for %q exceeds %d chars", key, MaxActionQueryValueLength) + } + } + return nil +} + +func validateIntegrationURL(rawURL string) error { + if rawURL == "" { + return fmt.Errorf("must have a non-empty URL") + } + if !(strings.HasPrefix(rawURL, "/plugins/") || strings.HasPrefix(rawURL, "plugins/") || IsValidHTTPURL(rawURL)) { + return fmt.Errorf("must have a valid integration URL") + } + // Reject path-traversal segments. /plugins/ URLs are routed by the + // local server, so a `..` segment can escape the plugin namespace and + // hit unrelated server routes. url.Parse decodes percent-encoded path + // bytes into u.Path, which is the same single decode pass that + // doPluginRequest performs at dispatch — so encoded forms like + // %2e%2e%2f are caught here symmetrically with how the router would + // resolve them. + u, parseErr := url.Parse(rawURL) + if parseErr != nil { + return fmt.Errorf("must have a valid integration URL: %w", parseErr) + } + if strings.Contains(u.Path, "/../") || strings.HasSuffix(u.Path, "/..") { + return fmt.Errorf("integration URL must not contain path traversal segments") + } return nil } diff --git a/server/public/model/integration_action_test.go b/server/public/model/integration_action_test.go index baefc9f968b..69a751982d2 100644 --- a/server/public/model/integration_action_test.go +++ b/server/public/model/integration_action_test.go @@ -12,6 +12,7 @@ import ( "encoding/json" "io" "math/big" + "strconv" "strings" "testing" "time" @@ -1676,3 +1677,742 @@ func TestDialogElementDateTimeValidation(t *testing.T) { assert.True(t, effective.ManualTimeEntry, "deprecated field alone should enable manual entry after EffectiveDateTimeConfig") }) } + +func TestValidateActionQuery(t *testing.T) { + t.Run("nil map is valid", func(t *testing.T) { + assert.NoError(t, ValidateActionQuery(nil)) + }) + + t.Run("empty map is valid", func(t *testing.T) { + assert.NoError(t, ValidateActionQuery(map[string]string{})) + }) + + t.Run("within bounds is valid", func(t *testing.T) { + ctx := map[string]string{ + "alpha": "one", + "beta": "two", + } + assert.NoError(t, ValidateActionQuery(ctx)) + }) + + t.Run("exceeds MaxActionQueryEntries", func(t *testing.T) { + ctx := make(map[string]string, MaxActionQueryEntries+1) + for i := range MaxActionQueryEntries + 1 { + ctx[strconv.Itoa(i)] = "v" + } + err := ValidateActionQuery(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") + }) + + t.Run("key length exactly MaxActionQueryKeyLength is allowed", func(t *testing.T) { + ctx := map[string]string{ + strings.Repeat("k", MaxActionQueryKeyLength): "value", + } + assert.NoError(t, ValidateActionQuery(ctx)) + }) + + t.Run("key length MaxActionQueryKeyLength+1 is rejected", func(t *testing.T) { + ctx := map[string]string{ + strings.Repeat("k", MaxActionQueryKeyLength+1): "value", + } + err := ValidateActionQuery(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "key exceeds") + }) + + t.Run("value length exactly MaxActionQueryValueLength is allowed", func(t *testing.T) { + ctx := map[string]string{ + "key": strings.Repeat("v", MaxActionQueryValueLength), + } + assert.NoError(t, ValidateActionQuery(ctx)) + }) + + t.Run("value length MaxActionQueryValueLength+1 is rejected", func(t *testing.T) { + ctx := map[string]string{ + "key": strings.Repeat("v", MaxActionQueryValueLength+1), + } + err := ValidateActionQuery(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "value for") + }) + + t.Run("multiple violations triggers an error", func(t *testing.T) { + // Too many entries AND every value is over-length. First detected + // violation wins; only assert that an error is returned. + ctx := make(map[string]string, MaxActionQueryEntries+1) + for i := range MaxActionQueryEntries + 1 { + ctx[strconv.Itoa(i)] = strings.Repeat("v", MaxActionQueryValueLength+1) + } + err := ValidateActionQuery(ctx) + require.Error(t, err) + }) +} + +func mmBlocksExternalEntry(url string, context map[string]any) map[string]any { + entry := map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": url, + } + if context != nil { + entry["context"] = context + } + return entry +} + +func TestGetMmBlocksActionSpec(t *testing.T) { + t.Run("prop absent returns nil", func(t *testing.T) { + p := &Post{} + assert.Nil(t, p.GetMmBlocksActionSpec("btn1")) + }) + + t.Run("empty action id returns nil", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + assert.Nil(t, p.GetMmBlocksActionSpec("")) + }) + + t.Run("id not found returns nil", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + assert.Nil(t, p.GetMmBlocksActionSpec("missing")) + }) + + t.Run("external entry returns spec with url and context", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", map[string]any{"k": "v"}), + }) + got := p.GetMmBlocksActionSpec("btn1") + require.NotNil(t, got) + assert.Equal(t, MmBlocksActionTypeExternal, got.Type) + assert.Equal(t, "http://example.com/hook", got.URL) + assert.Equal(t, "v", got.Context["k"]) + }) + + t.Run("entry missing type returns nil", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{"url": "http://example.com/hook"}, + }) + assert.Nil(t, p.GetMmBlocksActionSpec("btn1")) + }) + + t.Run("entry with unknown type returns nil", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": "bogus", + "url": "http://example.com/hook", + }, + }) + assert.Nil(t, p.GetMmBlocksActionSpec("btn1")) + }) + + t.Run("wrong-shape prop returns nil", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, "not-a-map") + assert.Nil(t, p.GetMmBlocksActionSpec("btn1")) + }) + + t.Run("entry value not an object returns nil", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": "not-an-object", + }) + assert.Nil(t, p.GetMmBlocksActionSpec("btn1")) + }) +} + +func TestValidateMmBlocksActions(t *testing.T) { + t.Run("absent prop returns no error", func(t *testing.T) { + p := &Post{} + assert.NoError(t, ValidateMmBlocksActions(p)) + }) + + t.Run("string prop is rejected (cookie transport not yet supported)", func(t *testing.T) { + // The cookie-transport PR will add proper validation for + // encrypted-string payloads. Until then, any string value is + // rejected so an integration session cannot bypass the + // alphanumeric-key, URL, and bounds checks by simply storing a + // raw string at the prop key. + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, "encrypted-cookie-blob") + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be a map") + }) + + t.Run("valid external entries return no error", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + "btn2": mmBlocksExternalEntry("/plugins/myplugin/action", nil), + "btn3": mmBlocksExternalEntry("plugins/myplugin/action", nil), + }) + assert.NoError(t, ValidateMmBlocksActions(p)) + }) + + t.Run("exceeding MaxMmBlocksActionsPerPost returns error", func(t *testing.T) { + actions := make(map[string]any, MaxMmBlocksActionsPerPost+1) + for i := range MaxMmBlocksActionsPerPost + 1 { + actions["btn"+strconv.Itoa(i)] = mmBlocksExternalEntry("http://example.com/hook", nil) + } + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, actions) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") + }) + + t.Run("action id with hyphen is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "foo-bar": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be alphanumeric") + }) + + t.Run("action id at MaxMmBlocksActionKeyLength is allowed", func(t *testing.T) { + key := strings.Repeat("a", MaxMmBlocksActionKeyLength) + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + key: mmBlocksExternalEntry("http://example.com/hook", nil), + }) + assert.NoError(t, ValidateMmBlocksActions(p)) + }) + + t.Run("action id over MaxMmBlocksActionKeyLength is rejected", func(t *testing.T) { + key := strings.Repeat("a", MaxMmBlocksActionKeyLength+1) + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + key: mmBlocksExternalEntry("http://example.com/hook", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds") + }) + + t.Run("action id with underscore is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "foo_bar": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be alphanumeric") + }) + + t.Run("action id with space is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "FOO bar": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be alphanumeric") + }) + + t.Run("empty URL is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "non-empty URL") + }) + + t.Run("path traversal in /plugins/ URL is rejected", func(t *testing.T) { + // Defense-in-depth: a `..` segment in a /plugins/ URL can escape the + // plugin namespace at request time. Bot-authored mm_blocks specs are + // the origin point so we reject at save. + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("/plugins/../../../etc/passwd", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "path traversal") + }) + + t.Run("trailing /.. in /plugins/ URL is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("/plugins/myplugin/..", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "path traversal") + }) + + t.Run("percent-encoded traversal in /plugins/ URL is rejected", func(t *testing.T) { + // doPluginRequest decodes the path via url.Parse before path.Clean, + // so an encoded "%2e%2e%2f" would otherwise route to a different + // plugin than the validator thinks it's protecting. Validator must + // decode symmetrically to catch this at save time. + for _, encoded := range []string{ + "/plugins/innocent/%2e%2e%2f/target/handler", + "/plugins/innocent/%2E%2E%2F/target/handler", + "/plugins/innocent/..%2f/target/handler", + "/plugins/innocent/%2e%2e/", + "/plugins/innocent/%2e%2e", + } { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry(encoded, nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err, "url=%q must be rejected", encoded) + assert.Contains(t, err.Error(), "path traversal", "url=%q", encoded) + } + }) + + t.Run("entry missing type is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{"url": "http://example.com/hook"}, + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid type or shape") + }) + + t.Run("entry with unknown type is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": "bogus", + "url": "http://example.com/hook", + }, + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid type or shape") + }) + + t.Run("entry value not an object is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": "not-an-object", + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be an object") + }) + + t.Run("javascript URL is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("javascript://alert(1)", nil), + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "valid integration URL") + }) + + t.Run("http URL is accepted", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://legit.com", nil), + }) + assert.NoError(t, ValidateMmBlocksActions(p)) + }) + + t.Run("/plugins/ URL is accepted", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("/plugins/foo", nil), + }) + assert.NoError(t, ValidateMmBlocksActions(p)) + }) + + t.Run("wrong-shape raw prop is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, []string{"not-a-map"}) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be a map") + }) + + t.Run("static query exceeding entry cap is rejected", func(t *testing.T) { + query := make(map[string]any, MaxActionQueryEntries+1) + for i := range MaxActionQueryEntries + 1 { + query["k"+strconv.Itoa(i)] = "v" + } + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook", + "query": query, + }, + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "static query") + }) + + t.Run("static query value exceeding length cap is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook", + "query": map[string]any{"k": strings.Repeat("a", MaxActionQueryValueLength+1)}, + }, + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "static query") + }) + + t.Run("static context exceeding entry cap is rejected", func(t *testing.T) { + ctx := make(map[string]any, MaxActionQueryEntries+1) + for i := range MaxActionQueryEntries + 1 { + ctx["k"+strconv.Itoa(i)] = "v" + } + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook", + "context": ctx, + }, + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "context exceeds maximum") + }) + + t.Run("static context key exceeding length cap is rejected", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook", + "context": map[string]any{strings.Repeat("a", MaxActionQueryKeyLength+1): "v"}, + }, + }) + err := ValidateMmBlocksActions(p) + require.Error(t, err) + assert.Contains(t, err.Error(), "context key exceeds") + }) +} + +func TestStripActionIntegrations_MmBlocksActions(t *testing.T) { + t.Run("strips mm_blocks_actions prop", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + p.StripActionIntegrations() + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + }) + + t.Run("post without mm_blocks_actions prop does not panic", func(t *testing.T) { + p := &Post{} + assert.NotPanics(t, func() { + p.StripActionIntegrations() + }) + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + }) + + t.Run("post with both attachments and mm_blocks_actions cleans both", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsAttachments, []*MessageAttachment{ + { + Actions: []*PostAction{ + { + Id: "a1", + Name: "Button", + Type: PostActionTypeButton, + Integration: &PostActionIntegration{URL: "http://example.com/hook"}, + }, + }, + }, + }) + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + + p.StripActionIntegrations() + + // mm_blocks_actions prop should be removed entirely. + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + + // Attachment actions should remain but with nil Integration. + attachments := p.Attachments() + require.Len(t, attachments, 1) + require.Len(t, attachments[0].Actions, 1) + assert.Nil(t, attachments[0].Actions[0].Integration) + }) +} + +func TestGetAction_MmBlocksFallback(t *testing.T) { + t.Run("returns attachment action when present", func(t *testing.T) { + attachmentAction := &PostAction{ + Id: "a1", + Name: "Attach Button", + Type: PostActionTypeButton, + Integration: &PostActionIntegration{URL: "http://example.com/attach"}, + } + p := &Post{} + p.AddProp(PostPropsAttachments, []*MessageAttachment{ + {Actions: []*PostAction{attachmentAction}}, + }) + + got := p.GetAction("a1") + require.NotNil(t, got) + assert.Same(t, attachmentAction, got) + }) + + t.Run("synthesizes PostAction from mm_blocks_actions when no attachment match", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", map[string]any{"k": "v"}), + }) + + got := p.GetAction("btn1") + require.NotNil(t, got) + assert.Equal(t, "btn1", got.Id) + assert.Equal(t, PostActionTypeButton, got.Type) + require.NotNil(t, got.Integration) + assert.Equal(t, "http://example.com/hook", got.Integration.URL) + assert.Equal(t, "v", got.Integration.Context["k"]) + }) + + t.Run("synthesized URL pre-merges spec static query", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook", + "query": map[string]any{"source": "fleet-status"}, + }, + }) + + got := p.GetAction("btn1") + require.NotNil(t, got) + require.NotNil(t, got.Integration) + assert.Equal(t, "http://example.com/hook?source=fleet-status", got.Integration.URL) + }) + + t.Run("synthesized URL preserves existing query and adds spec static query", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook?team=alpha", + "query": map[string]any{"source": "fleet-status"}, + }, + }) + + got := p.GetAction("btn1") + require.NotNil(t, got) + require.NotNil(t, got.Integration) + // url.Values.Encode() sorts keys alphabetically. + assert.Contains(t, got.Integration.URL, "source=fleet-status") + assert.Contains(t, got.Integration.URL, "team=alpha") + }) + + t.Run("attachment wins when id matches both attachment and mm_blocks action", func(t *testing.T) { + attachmentAction := &PostAction{ + Id: "btn1", + Name: "Attach Button", + Type: PostActionTypeButton, + Integration: &PostActionIntegration{URL: "http://example.com/attach"}, + } + p := &Post{} + p.AddProp(PostPropsAttachments, []*MessageAttachment{ + {Actions: []*PostAction{attachmentAction}}, + }) + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/inline", nil), + }) + + got := p.GetAction("btn1") + require.NotNil(t, got) + assert.Same(t, attachmentAction, got) + assert.Equal(t, "http://example.com/attach", got.Integration.URL) + }) + + t.Run("returns nil when id matches neither", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsAttachments, []*MessageAttachment{ + {Actions: []*PostAction{{Id: "other", Name: "X", Type: PostActionTypeButton, Integration: &PostActionIntegration{URL: "http://example.com"}}}}, + }) + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "something": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + + assert.Nil(t, p.GetAction("missing")) + }) + + t.Run("returns nil when spec URL is unparseable and static query merge fails", func(t *testing.T) { + // Defense-in-depth: ValidateMmBlocksActions should reject this at + // save time, but if a malformed URL slips through, GetAction must + // not silently fire the bare URL with the static query dropped. + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/%%%bad", + "query": map[string]any{"source": "fleet"}, + }, + }) + + assert.Nil(t, p.GetAction("btn1")) + }) +} + +func TestMergeQueryIntoURL(t *testing.T) { + t.Run("empty query returns rawURL unchanged", func(t *testing.T) { + got, err := MergeQueryIntoURL("http://example.com/hook", nil) + require.NoError(t, err) + assert.Equal(t, "http://example.com/hook", got) + + got, err = MergeQueryIntoURL("http://example.com/hook", map[string]string{}) + require.NoError(t, err) + assert.Equal(t, "http://example.com/hook", got) + }) + + t.Run("URL without existing query gets query appended", func(t *testing.T) { + got, err := MergeQueryIntoURL("http://example.com/hook", map[string]string{"a": "1"}) + require.NoError(t, err) + assert.Equal(t, "http://example.com/hook?a=1", got) + }) + + t.Run("URL with existing query merges non-overlapping keys", func(t *testing.T) { + got, err := MergeQueryIntoURL("http://example.com/hook?team=alpha", map[string]string{"source": "fleet"}) + require.NoError(t, err) + // Encode() sorts keys alphabetically. + assert.Contains(t, got, "team=alpha") + assert.Contains(t, got, "source=fleet") + }) + + t.Run("query map overrides existing key on overlap", func(t *testing.T) { + got, err := MergeQueryIntoURL("http://example.com/hook?tail=999", map[string]string{"tail": "214"}) + require.NoError(t, err) + assert.Equal(t, "http://example.com/hook?tail=214", got) + }) + + t.Run("URL fragment is preserved", func(t *testing.T) { + got, err := MergeQueryIntoURL("http://example.com/hook#anchor", map[string]string{"a": "1"}) + require.NoError(t, err) + assert.Equal(t, "http://example.com/hook?a=1#anchor", got) + }) + + t.Run("special characters in values are URL-encoded", func(t *testing.T) { + got, err := MergeQueryIntoURL("http://example.com/hook", map[string]string{"q": "a b&c=d"}) + require.NoError(t, err) + // space → +, & and = → %26 / %3D + assert.Contains(t, got, "q=a+b%26c%3Dd") + }) + + t.Run("relative URL with empty path accepts query merge", func(t *testing.T) { + got, err := MergeQueryIntoURL("/plugins/myplugin/action", map[string]string{"a": "1"}) + require.NoError(t, err) + assert.Equal(t, "/plugins/myplugin/action?a=1", got) + }) + + t.Run("malformed URL returns parse error", func(t *testing.T) { + _, err := MergeQueryIntoURL("://not-a-url", map[string]string{"a": "1"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "parse url") + }) +} + +func TestMmBlocksContextMap(t *testing.T) { + t.Run("empty string returns nil", func(t *testing.T) { + assert.Nil(t, MmBlocksContextMap("")) + }) + + t.Run("valid JSON object string is parsed into a map", func(t *testing.T) { + got := MmBlocksContextMap(`{"k":"v","n":1}`) + require.NotNil(t, got) + assert.Equal(t, "v", got["k"]) + // JSON numbers decode to float64. + assert.Equal(t, float64(1), got["n"]) + }) + + t.Run("non-JSON string is wrapped under context key", func(t *testing.T) { + got := MmBlocksContextMap("hello world") + require.NotNil(t, got) + assert.Equal(t, "hello world", got["context"]) + }) + + t.Run("JSON null falls back to wrap (m is nil after unmarshal)", func(t *testing.T) { + got := MmBlocksContextMap("null") + require.NotNil(t, got) + assert.Equal(t, "null", got["context"]) + }) + + t.Run("JSON array falls back to wrap (target type mismatch)", func(t *testing.T) { + got := MmBlocksContextMap("[1,2,3]") + require.NotNil(t, got) + assert.Equal(t, "[1,2,3]", got["context"]) + }) + + t.Run("JSON number falls back to wrap (target type mismatch)", func(t *testing.T) { + got := MmBlocksContextMap("42") + require.NotNil(t, got) + assert.Equal(t, "42", got["context"]) + }) + + t.Run("malformed JSON falls back to wrap", func(t *testing.T) { + got := MmBlocksContextMap(`{"unclosed":`) + require.NotNil(t, got) + assert.Equal(t, `{"unclosed":`, got["context"]) + }) +} + +func TestStripMmBlocksActionSecrets(t *testing.T) { + t.Run("absent prop is a no-op", func(t *testing.T) { + p := &Post{} + assert.NotPanics(t, func() { + p.StripMmBlocksActionSecrets() + }) + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + }) + + t.Run("map-form prop is deleted", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + p.StripMmBlocksActionSecrets() + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + }) + + t.Run("string-form prop is deleted (cookie transport not yet supported)", func(t *testing.T) { + // Until the cookie-transport PR ships proper handling, any string + // value is treated as opaque garbage and stripped wholesale — + // matches the validator's reject-strings policy. + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, "encrypted-cookie-blob") + p.StripMmBlocksActionSecrets() + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + }) + + t.Run("other props on the post are not touched", func(t *testing.T) { + p := &Post{} + p.AddProp(PostPropsMmBlocksActions, map[string]any{ + "btn1": mmBlocksExternalEntry("http://example.com/hook", nil), + }) + p.AddProp(PostPropsAttachments, []*MessageAttachment{{Text: "keep me"}}) + p.AddProp(PostPropsFromBot, "true") + + p.StripMmBlocksActionSecrets() + + assert.Nil(t, p.GetProp(PostPropsMmBlocksActions)) + assert.NotNil(t, p.GetProp(PostPropsAttachments)) + assert.Equal(t, "true", p.GetProp(PostPropsFromBot)) + }) +} diff --git a/server/public/model/mm_blocks_actions.go b/server/public/model/mm_blocks_actions.go new file mode 100644 index 00000000000..dbea4c869aa --- /dev/null +++ b/server/public/model/mm_blocks_actions.go @@ -0,0 +1,154 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +// Server-side definitions for the post.props.mm_blocks_actions registry that +// underpins the markdown-actions feature. Mirrors the canonical model +// landing in the broader mm_blocks framework PR; cookie transport +// (MmBlocksActionCookie, AddMmBlocksActionCookies, ParseDecryptedActionCookiePayload) +// is intentionally omitted here and will be filled in by that PR. Until then, +// mm_blocks_actions is resolved on click via DB lookup +// (GetMmBlocksActionSpec) and stripped from ephemeral broadcasts so dead +// buttons don't render. + +package model + +import ( + "encoding/json" + "fmt" + "maps" + "net/url" +) + +const ( + MmBlocksActionTypeExternal = "external" +) + +// MmBlocksActionSpec is the server-side definition for one entry in props.mm_blocks_actions. +type MmBlocksActionSpec struct { + Type string + URL string + Query map[string]string + Context map[string]any +} + +// GetMmBlocksActionSpec returns the action definition for actionID from props.mm_blocks_actions, if present. +func (o *Post) GetMmBlocksActionSpec(actionID string) *MmBlocksActionSpec { + raw := o.GetProp(PostPropsMmBlocksActions) + if raw == nil || actionID == "" { + return nil + } + actionsTop, ok := coerceToStringAnyMap(raw) + if !ok { + return nil + } + entry, ok := actionsTop[actionID] + if !ok || entry == nil { + return nil + } + entryMap, ok := coerceToStringAnyMap(entry) + if !ok { + return nil + } + return mmBlocksEntryMapToSpec(entryMap) +} + +// mmBlocksEntryMapToSpec maps one props.mm_blocks_actions[actionID] object to MmBlocksActionSpec. +func mmBlocksEntryMapToSpec(entryMap map[string]any) *MmBlocksActionSpec { + typ, _ := entryMap["type"].(string) + if typ == "" { + return nil + } + if typ != MmBlocksActionTypeExternal { + return nil + } + spec := &MmBlocksActionSpec{Type: typ} + spec.URL, _ = entryMap["url"].(string) + spec.Context = contextMapFromProp(entryMap["context"]) + spec.Query = stringMapFromPropValue(entryMap["query"]) + return spec +} + +// MmBlocksContextMap parses a context JSON string or treats a non-JSON string as a single context value. +func MmBlocksContextMap(contextString string) map[string]any { + if contextString == "" { + return nil + } + var m map[string]any + if err := json.Unmarshal([]byte(contextString), &m); err == nil && m != nil { + return m + } + return map[string]any{"context": contextString} +} + +// MergeQueryIntoURL merges q into rawURL's query string; existing keys are overwritten by q. +func MergeQueryIntoURL(rawURL string, q map[string]string) (string, error) { + if len(q) == 0 { + return rawURL, nil + } + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("parse url: %w", err) + } + values := u.Query() + for k, v := range q { + values.Set(k, v) + } + u.RawQuery = values.Encode() + return u.String(), nil +} + +// StripMmBlocksActionSecrets removes server-only fields from +// props.mm_blocks_actions for wire serialization. The current +// implementation deletes the prop wholesale; the cookie-transport PR will +// extend this to preserve encrypted-string cookie payloads in place. +func (o *Post) StripMmBlocksActionSecrets() { + if o.GetProp(PostPropsMmBlocksActions) == nil { + return + } + o.DelProp(PostPropsMmBlocksActions) +} + +// contextMapFromProp normalizes props.mm_blocks_actions[*].context to map[string]any (JSON object or string). +func contextMapFromProp(v any) map[string]any { + if v == nil { + return nil + } + if s, ok := v.(string); ok { + return MmBlocksContextMap(s) + } + if m, ok := coerceToStringAnyMap(v); ok { + // Clone so callers cannot mutate the live post.Props map. A + // nested mutation through the returned map would otherwise race + // with concurrent post.Props readers. + return maps.Clone(m) + } + return nil +} + +func stringMapFromPropValue(v any) map[string]string { + m, ok := coerceToStringAnyMap(v) + if !ok || len(m) == 0 { + return nil + } + out := make(map[string]string, len(m)) + for k, val := range m { + if s, ok := val.(string); ok { + out[k] = s + } + } + if len(out) == 0 { + return nil + } + return out +} + +func coerceToStringAnyMap(v any) (map[string]any, bool) { + if v == nil { + return nil, false + } + m, ok := v.(map[string]any) + if ok { + return m, true + } + return nil, false +} diff --git a/server/public/model/post.go b/server/public/model/post.go index 3c1d52cb9ae..50f41586673 100644 --- a/server/public/model/post.go +++ b/server/public/model/post.go @@ -93,6 +93,7 @@ const ( PostPropsFromOAuthApp = "from_oauth_app" PostPropsWebhookDisplayName = "webhook_display_name" PostPropsAttachments = "attachments" + PostPropsMmBlocksActions = "mm_blocks_actions" PostPropsFromPlugin = "from_plugin" PostPropsMentionHighlightDisabled = "mentionHighlightDisabled" PostPropsGroupHighlightDisabled = "disable_group_highlight" @@ -619,6 +620,7 @@ func ContainsIntegrationsReservedProps(props StringInterface) []string { PostPropsWebhookDisplayName, PostPropsOverrideIconURL, PostPropsOverrideIconEmoji, + PostPropsMmBlocksActions, } for _, key := range reservedProps { @@ -843,6 +845,12 @@ func (o *Post) propsIsValid() error { } } + if props[PostPropsMmBlocksActions] != nil { + if err := ValidateMmBlocksActions(o); err != nil { + multiErr = multierror.Append(multiErr, fmt.Errorf("invalid mm_blocks_actions: %w", err)) + } + } + for i, a := range o.Attachments() { if err := a.IsValid(); err != nil { multiErr = multierror.Append(multiErr, multierror.Prefix(err, fmt.Sprintf("message attachtment at index %d is invalid:", i))) @@ -1197,6 +1205,14 @@ func (o *Post) CleanPost() *Post { type UpdatePostOptions struct { SafeUpdate bool IsRestorePost bool + + // AllowMmBlocksActionsUpdate grants the caller permission to add, + // remove, or modify the mm_blocks_actions prop. Without it, + // non-integration sessions cannot change mm_blocks_actions and the + // prop is reset to its prior value. Set only from trusted paths (e.g. + // the post-action integration response handler which has already + // validated the incoming value). + AllowMmBlocksActionsUpdate bool } func DefaultUpdatePostOptions() *UpdatePostOptions { diff --git a/server/public/model/post_test.go b/server/public/model/post_test.go index 1d0b17b2d73..62739b184b9 100644 --- a/server/public/model/post_test.go +++ b/server/public/model/post_test.go @@ -171,10 +171,17 @@ func TestPost_ContainsIntegrationsReservedProps(t *testing.T) { PostPropsOverrideUsername: "overridden_username", PostPropsOverrideIconURL: "a-custom-url", PostPropsOverrideIconEmoji: ":custom_emoji_name:", + PostPropsMmBlocksActions: map[string]any{ + "btn1": map[string]any{ + "type": MmBlocksActionTypeExternal, + "url": "http://example.com/hook", + }, + }, }, } keys2 := post2.ContainsIntegrationsReservedProps() - require.Len(t, keys2, 5) + require.Len(t, keys2, 6) + require.Contains(t, keys2, PostPropsMmBlocksActions) } func TestPostPatch_ContainsIntegrationsReservedProps(t *testing.T) { diff --git a/webapp/channels/src/components/inline_action_button/index.test.tsx b/webapp/channels/src/components/inline_action_button/index.test.tsx new file mode 100644 index 00000000000..bed76138270 --- /dev/null +++ b/webapp/channels/src/components/inline_action_button/index.test.tsx @@ -0,0 +1,418 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import React from 'react'; + +import {doPostActionWithQuery} from 'mattermost-redux/actions/posts'; + +import {act, fireEvent, renderWithContext, screen, userEvent} from 'tests/react_testing_utils'; + +import InlineActionButton from './index'; + +jest.mock('mattermost-redux/actions/posts', () => ({ + doPostActionWithQuery: jest.fn(), +})); + +const mockedDoPostActionWithQuery = doPostActionWithQuery as jest.MockedFunction; + +/** + * Creates a thunk-shaped mock whose inner promise is externally controllable. + * The thunk returned by `doPostActionWithQuery` is invoked by redux-thunk + * middleware; the returned promise is what the component awaits. + */ +function setupControllablePromise() { + let resolveFn: (value: unknown) => void = () => {}; + const promise = new Promise((resolve) => { + resolveFn = resolve; + }); + + mockedDoPostActionWithQuery.mockImplementation(() => { + return (() => promise) as unknown as ReturnType; + }); + + return {promise, resolve: () => resolveFn({data: {}})}; +} + +describe('InlineActionButton', () => { + const baseProps = { + href: 'mmaction://mx?tail=214&mds=C130J', + postId: 'abc', + children: 'Click me', + }; + + beforeEach(() => { + mockedDoPostActionWithQuery.mockReset(); + }); + + test('renders with children as button label', () => { + mockedDoPostActionWithQuery.mockImplementation( + () => (() => Promise.resolve({data: {}})) as unknown as ReturnType, + ); + + renderWithContext(); + + const button = screen.getByRole('button'); + expect(button).toBeVisible(); + expect(button).toHaveTextContent('Click me'); + }); + + test('dispatches thunk with parsed action ID and query on click', async () => { + const {resolve} = setupControllablePromise(); + + renderWithContext(); + + await userEvent.click(screen.getByRole('button')); + + expect(mockedDoPostActionWithQuery).toHaveBeenCalledTimes(1); + expect(mockedDoPostActionWithQuery).toHaveBeenCalledWith('abc', 'mx', {tail: '214', mds: 'C130J'}); + + // Resolve pending dispatch inside act so the trailing setState commits cleanly. + await act(async () => { + resolve(); + }); + }); + + test('href without query results in empty query', async () => { + const {resolve} = setupControllablePromise(); + + renderWithContext( + , + ); + + await userEvent.click(screen.getByRole('button')); + + expect(mockedDoPostActionWithQuery).toHaveBeenCalledTimes(1); + expect(mockedDoPostActionWithQuery).toHaveBeenCalledWith('abc', 'mx', {}); + + await act(async () => { + resolve(); + }); + }); + + test('mixed-case action ID is preserved (URL.hostname would lowercase it)', async () => { + const {resolve} = setupControllablePromise(); + + renderWithContext( + , + ); + + await userEvent.click(screen.getByRole('button')); + + // Server action ID regex allows [A-Za-z0-9]+; losing case would + // cause lookups to 404 when mm_blocks_actions keys are mixed-case. + expect(mockedDoPostActionWithQuery).toHaveBeenCalledWith('abc', 'MxPlan42', {tail: '214'}); + + await act(async () => { + resolve(); + }); + }); + + test('double-click prevented by ref guard', async () => { + // Use a never-resolving promise so the first dispatch stays in-flight. + mockedDoPostActionWithQuery.mockImplementation( + () => (() => new Promise(() => {})) as unknown as ReturnType, + ); + + renderWithContext(); + const button = screen.getByRole('button'); + + // Fire two synchronous clicks before any microtask can run. + // fireEvent.click invokes the handler synchronously; the ref guard + // must block the second invocation before setState re-render lands. + fireEvent.click(button); + fireEvent.click(button); + + expect(mockedDoPostActionWithQuery).toHaveBeenCalledTimes(1); + + // Let any pending microtasks settle so teardown is clean. The dispatch + // promise never resolves, which is fine — we only care about the guard. + await act(async () => { + await Promise.resolve(); + }); + }); + + test('button uses aria-disabled (not native disabled) while executing — keeps it focusable for screen readers', async () => { + const {resolve} = setupControllablePromise(); + + renderWithContext(); + const button = screen.getByRole('button'); + + await userEvent.click(button); + + // Native `disabled` would remove the button from tab order; use + // aria-disabled so keyboard / screen-reader users can still + // navigate to it and hear "executing" announced via aria-busy. + // WCAG 2.1.1 + 4.1.3. + expect(button).not.toBeDisabled(); + expect(button).toHaveAttribute('aria-disabled', 'true'); + expect(button).toHaveAttribute('aria-busy', 'true'); + + await act(async () => { + resolve(); + }); + + // Idle state omits aria-disabled and aria-busy entirely (using + // {executing || undefined} idiom) so screen readers don't + // announce "not busy" / "not disabled" superfluously. + expect(button).not.toBeDisabled(); + expect(button).not.toHaveAttribute('aria-disabled'); + expect(button).not.toHaveAttribute('aria-busy'); + }); + + test('ref guard no-ops repeat clicks while aria-disabled (no native disabled to suppress)', async () => { + // Without native `disabled`, the browser fires onClick on + // aria-disabled buttons. The component's executingRef guard must + // catch the second click and no-op. + mockedDoPostActionWithQuery.mockImplementation( + () => (() => new Promise(() => {})) as unknown as ReturnType, + ); + + renderWithContext(); + const button = screen.getByRole('button'); + + fireEvent.click(button); + fireEvent.click(button); + + expect(mockedDoPostActionWithQuery).toHaveBeenCalledTimes(1); + + await act(async () => { + await Promise.resolve(); + }); + }); + + test('unmount during dispatch does not warn about setState on unmounted component', async () => { + const {resolve} = setupControllablePromise(); + + const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(() => {}); + + const {unmount} = renderWithContext(); + + await userEvent.click(screen.getByRole('button')); + + unmount(); + + // Resolve the in-flight dispatch after unmount; the component's + // mountedRef guard should prevent a stale setState call. + await act(async () => { + resolve(); + }); + + const unmountedWarnings = consoleErrorSpy.mock.calls.filter((args) => { + const msg = args[0]; + return typeof msg === 'string' && msg.includes('unmounted'); + }); + expect(unmountedWarnings).toHaveLength(0); + + consoleErrorSpy.mockRestore(); + }); + + test('renders {children} as plain text when postId is empty', () => { + const {container} = renderWithContext( + , + ); + + // No button is rendered; the link body shows as plain text instead + // so the user sees something readable rather than a broken affordance. + expect(screen.queryByRole('button')).toBeNull(); + expect(container).toHaveTextContent('Click me'); + }); + + test('renders {children} as plain text when href has wrong scheme', () => { + const {container} = renderWithContext( + , + ); + + expect(screen.queryByRole('button')).toBeNull(); + expect(container).toHaveTextContent('Click me'); + }); + + test('renders {children} as plain text for opaque mmaction: URI (no //)', () => { + // getScheme()-style accept of "mmaction:foo" without "//" would + // mis-slice the authority. Component must reject and fall back. + const {container} = renderWithContext( + , + ); + + expect(screen.queryByRole('button')).toBeNull(); + expect(container).toHaveTextContent('Click me'); + }); + + test('renders {children} as plain text for non-alphanumeric action ID', () => { + // Server regex is ^[A-Za-z0-9]+$; URL authority chars (port, + // userinfo, dash, dot) would never resolve server-side. + for (const href of ['mmaction://plan:443', 'mmaction://user@plan', 'mmaction://my-plan', 'mmaction://my.plan']) { + const {container, unmount} = renderWithContext( + , + ); + expect(screen.queryByRole('button')).toBeNull(); + expect(container).toHaveTextContent('Click me'); + unmount(); + } + }); + + test('renders {children} as plain text when params exceed length cap', () => { + // 2049-char query string is over MAX_PARAMS_LENGTH (2048). + const {container} = renderWithContext( + , + ); + + expect(screen.queryByRole('button')).toBeNull(); + expect(container).toHaveTextContent('Click me'); + }); + + test('aria-label without label prop: idle has none (children carries the name), executing has executing-label', async () => { + const {resolve} = setupControllablePromise(); + + renderWithContext(); + const button = screen.getByRole('button'); + + // Idle, no label prop: accessible name comes from {children}. + expect(button).not.toHaveAttribute('aria-label'); + + await userEvent.click(button); + + expect(button).toHaveAttribute('aria-label', 'Executing...'); + + await act(async () => { + resolve(); + }); + + expect(button).not.toHaveAttribute('aria-label'); + }); + + test('aria-label uses label prop at idle (icon-only callers must pass it)', async () => { + const {resolve} = setupControllablePromise(); + + renderWithContext( + +