diff --git a/src/acl.c b/src/acl.c index 177077d45f..7eb0b6041b 100644 --- a/src/acl.c +++ b/src/acl.c @@ -499,6 +499,11 @@ void ACLFreeUserAndKillClients(user *u) { * more defensive to set the default user and put * it in non authenticated mode. */ deauthenticateAndCloseClient(c); + continue; + } + /* Kill clients that still hold subscriptions from the deleted user */ + if (dictFind(c->pubsub_subscriptions, u)) { + deauthenticateAndCloseClient(c); } } ACLFreeUser(u); @@ -523,6 +528,14 @@ void ACLCopyUser(user *dst, user *src) { } } +/* Set the user for a client, performing any necessary bookkeeping such as + * updating broadcast tracking state for the user switch. */ +void clientSetUser(client *c, user *new_user) { + user *old = c->user; + c->user = new_user; + trackingBroadcastPostUserSwitch(c, old); +} + /* Given a command ID, this function set by reference 'word' and 'bit' * so that user->allowed_commands[word] will address the right word * where the corresponding bit for the provided ID is stored, and @@ -1497,7 +1510,7 @@ void addAuthErrReply(client *c, robj *err) { int checkPasswordBasedAuth(client *c, robj *username, robj *password) { if (ACLCheckUserCredentials(username,password) == C_OK) { c->authenticated = 1; - c->user = ACLGetUserByName(username->ptr,sdslen(username->ptr)); + clientSetUser(c, ACLGetUserByName(username->ptr,sdslen(username->ptr))); moduleNotifyUserChanged(c); return AUTH_OK; } else { @@ -1969,57 +1982,52 @@ list *getUpcomingChannelList(user *new, user *original) { return upcoming; } -/* Check if the client should be killed because it is subscribed to channels that were - * permitted in the past, are not in the `upcoming` channel list. */ -int ACLShouldKillPubsubClient(client *c, list *upcoming) { +/* Check if a specific user's subscriptions violate the given channel list. + * Returns 1 if any violation is found, 0 otherwise. 
*/ +static int ACLShouldKillForUserSubs(pubsubUserSubs *subs, list *upcoming) { + serverAssert(!pubsubUserSubsIsEmpty(subs)); robj *o; int kill = 0; + dictIterator di; + dictEntry *de; - if (getClientType(c) == CLIENT_TYPE_PUBSUB) { - /* Check for pattern violations. */ - dictIterator di; - dictEntry *de; - dictInitIterator(&di, c->pubsub_patterns); + /* Check for pattern violations. */ + dictInitIterator(&di, subs->patterns); + while (!kill && ((de = dictNext(&di)) != NULL)) { + o = dictGetKey(de); + int res = ACLCheckChannelAgainstList(upcoming, o->ptr, sdslen(o->ptr), 1); + kill = (res == ACL_DENIED_CHANNEL); + } + dictResetIterator(&di); + + /* Check for global channel violations. */ + if (!kill) { + dictInitIterator(&di, subs->channels); while (!kill && ((de = dictNext(&di)) != NULL)) { o = dictGetKey(de); - int res = ACLCheckChannelAgainstList(upcoming, o->ptr, sdslen(o->ptr), 1); + int res = ACLCheckChannelAgainstList(upcoming, o->ptr, sdslen(o->ptr), 0); kill = (res == ACL_DENIED_CHANNEL); } dictResetIterator(&di); - - /* Check for channel violations. */ - if (!kill) { - /* Check for global channels violation. */ - dictInitIterator(&di, c->pubsub_channels); - - while (!kill && ((de = dictNext(&di)) != NULL)) { - o = dictGetKey(de); - int res = ACLCheckChannelAgainstList(upcoming, o->ptr, sdslen(o->ptr), 0); - kill = (res == ACL_DENIED_CHANNEL); - } - dictResetIterator(&di); - } - if (!kill) { - /* Check for shard channels violation. */ - dictInitIterator(&di, c->pubsubshard_channels); - while (!kill && ((de = dictNext(&di)) != NULL)) { - o = dictGetKey(de); - int res = ACLCheckChannelAgainstList(upcoming, o->ptr, sdslen(o->ptr), 0); - kill = (res == ACL_DENIED_CHANNEL); - } - dictResetIterator(&di); - } - - if (kill) { - return 1; - } } - return 0; + + /* Check for shard channel violations. 
*/ + if (!kill) { + dictInitIterator(&di, subs->shard_channels); + while (!kill && ((de = dictNext(&di)) != NULL)) { + o = dictGetKey(de); + int res = ACLCheckChannelAgainstList(upcoming, o->ptr, sdslen(o->ptr), 0); + kill = (res == ACL_DENIED_CHANNEL); + } + dictResetIterator(&di); + } + + return kill; } /* Check if the user's existing pub/sub clients violate the ACL pub/sub * permissions specified via the upcoming argument, and kill them if so. */ -void ACLKillPubsubClientsIfNeeded(user *new, user *original) { +static void ACLKillPubsubClientsIfNeeded(user *new, user *original) { /* Do nothing if there are no subscribers. */ if (pubsubTotalSubscriptions() == 0) return; @@ -2033,14 +2041,16 @@ void ACLKillPubsubClientsIfNeeded(user *new, user *original) { listNode *ln; /* Permissions have changed, so we need to iterate through all - * the clients and disconnect those that are no longer valid. - * Scan all connected clients to find the user's pub/subs. */ + * the clients and disconnect those that hold subscriptions + * created under this user that are no longer valid. */ listRewind(server.clients,&li); while ((ln = listNext(&li)) != NULL) { client *c = listNodeValue(ln); - if (c->user != original) - continue; - if (ACLShouldKillPubsubClient(c, channels)) + if (getClientType(c) != CLIENT_TYPE_PUBSUB) continue; + dictEntry *de = dictFind(c->pubsub_subscriptions, original); + if (!de) continue; + pubsubUserSubs *subs = dictGetVal(de); + if (ACLShouldKillForUserSubs(subs, channels)) deauthenticateAndCloseClient(c); } @@ -2441,7 +2451,15 @@ sds ACLLoadFromFile(const char *filename) { if (sdslen(errors) == 0) { /* The default user pointer is referenced in different places: instead * of replacing such occurrences it is much simpler to copy the new - * default user configuration in the old one. */ + * default user configuration in the old one. 
Snapshot the old default + * into old_users before mutation so the provenance loop can compare + * against the pre-load permissions. */ + user *old_default_copy = zmalloc(sizeof(user)); + memset(old_default_copy, 0, sizeof(user)); + serverAssert(DefaultUser); + old_default_copy->name = sdsdup(DefaultUser->name); + ACLCopyUser(old_default_copy, DefaultUser); + user *new_default = ACLGetUserByName("default",7); if (!new_default) { new_default = ACLCreateDefaultUser(); @@ -2450,7 +2468,7 @@ sds ACLLoadFromFile(const char *filename) { ACLCopyUser(DefaultUser,new_default); ACLFreeUser(new_default); raxInsert(Users,(unsigned char*)"default",7,DefaultUser,NULL); - raxRemove(old_users,(unsigned char*)"default",7,NULL); + raxInsert(old_users,(unsigned char*)"default",7,old_default_copy,NULL); /* If there are some subscribers, we need to check if we need to drop some clients. */ rax *user_channels = NULL; @@ -2467,21 +2485,67 @@ sds ACLLoadFromFile(const char *filename) { /* a MASTER client can do everything (and user = NULL) so we can skip it */ if (c->flags & CLIENT_MASTER) continue; - user *original = c->user; - list *channels = NULL; - user *new = ACLGetUserByName(c->user->name, sdslen(c->user->name)); - if (new && user_channels) { - if (!raxFind(user_channels, (unsigned char*)(new->name), sdslen(new->name), (void**)&channels)) { - channels = getUpcomingChannelList(new, original); - raxInsert(user_channels, (unsigned char*)(new->name), sdslen(new->name), channels, NULL); - } - } - /* When the new channel list is NULL, it means the new user's channel list is a superset of the old user's list. */ - if (!new || (channels && ACLShouldKillPubsubClient(c, channels))) { + + /* Reassign c->user to the new user object (or kill if gone). */ + user *new_current = ACLGetUserByName(c->user->name, sdslen(c->user->name)); + if (!new_current) { deauthenticateAndCloseClient(c); continue; } - c->user = new; + + /* Phase 1: Validate provenance entries (read-only, no mutation). 
+ * Old user pointers are still alive — old_users is freed after + * the walk — so old_user_ptr->name is safe to dereference. */ + int must_kill = 0; + if (user_channels) { + dictIterator di; + dictEntry *entry; + dictInitIterator(&di, c->pubsub_subscriptions); + while ((entry = dictNext(&di)) != NULL) { + user *old_user_ptr = dictGetKey(entry); + if (pubsubUserIsNoAuth(old_user_ptr)) continue; + sds prov_username = old_user_ptr->name; + + user *new_prov = ACLGetUserByName(prov_username, sdslen(prov_username)); + if (!new_prov) { + must_kill = 1; + break; + } + + list *channels = NULL; + if (!raxFind(user_channels, (unsigned char*)prov_username, sdslen(prov_username), (void**)&channels)) { + user *old_prov = NULL; + raxFind(old_users, (unsigned char*)prov_username, sdslen(prov_username), (void**)&old_prov); + if (!old_prov) { + must_kill = 1; + break; + } + channels = getUpcomingChannelList(new_prov, old_prov); + raxInsert(user_channels, (unsigned char*)prov_username, sdslen(prov_username), channels, NULL); + } + + if (channels != NULL) { + pubsubUserSubs *subs = dictGetVal(entry); + if (ACLShouldKillForUserSubs(subs, channels)) { + must_kill = 1; + break; + } + } + } + dictResetIterator(&di); + } + + if (must_kill) { + deauthenticateAndCloseClient(c); + continue; + } + + /* Phase 2: Client survived — re-key provenance entries from old + * user pointers to new user pointers, then reassign c->user. */ + if (dictSize(c->pubsub_subscriptions) > 0) { + pubsubRekeySubscriptionsForACLLoad(c); + } + clientSetUser(c, new_current); } if (user_channels) @@ -3241,7 +3305,7 @@ static void internalAuth(client *c) { c->authenticated = 1; /* Set the user to the unrestricted user, if it is not already set (default). 
*/ if (c->user != NULL) { - c->user = NULL; + clientSetUser(c, NULL); moduleNotifyUserChanged(c); } addReply(c, shared.ok); diff --git a/src/defrag.c b/src/defrag.c index 913e457c25..ccebeaca62 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -127,13 +127,17 @@ typedef struct { } defragSubexpiresCtx; /* Context for pubsub kvstores */ -typedef dict *(*getClientChannelsFn)(client *); typedef struct { kvstoreIterState kvstate; - getClientChannelsFn getPubSubChannels; + int shard; } defragPubSubCtx; static_assert(offsetof(defragPubSubCtx, kvstate) == 0, "defragStageKvstoreHelper requires this"); +/* Context for client-side pubsub defrag (iterates server.clients_index) */ +typedef struct { + uint64_t last_client_id_raw; /* big-endian rax key of last processed client, 0 = start */ +} defragPubsubClientCtx; + typedef struct { sds module_name; unsigned long cursor; @@ -1278,17 +1282,30 @@ void defragPubsubScanCallback(void *privdata, const dictEntry *de, dictEntryLink /* The channel name is shared by the client's pubsub(shard) and server's * pubsub(shard), after defraging the channel name, we need to update - * the reference in the clients' dictionary. */ + * the reference in the clients' per-user inner dicts. */ dictIterator di; dictEntry *clientde; dictInitIterator(&di, clients); while((clientde = dictNext(&di)) != NULL) { client *c = dictGetKey(clientde); - dict *client_channels = ctx->getPubSubChannels(c); - uint64_t hash = dictGetHash(client_channels, newchannel); - dictEntry *pubsub_channel = dictFindByHashAndPtr(client_channels, channel, hash); - serverAssert(pubsub_channel); - dictSetKey(ctx->getPubSubChannels(c), pubsub_channel, newchannel); + /* Scan the per-user dict to find which inner dict holds the old pointer */ + dictIterator udi; + dictEntry *userEntry; + int found = 0; + dictInitIterator(&udi, c->pubsub_subscriptions); + while ((userEntry = dictNext(&udi)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + dict *inner = ctx->shard ? 
subs->shard_channels : subs->channels; + uint64_t hash = dictGetHash(inner, newchannel); + dictEntry *pubsub_channel = dictFindByHashAndPtr(inner, channel, hash); + if (pubsub_channel) { + dictSetKey(inner, pubsub_channel, newchannel); + found = 1; + break; + } + } + dictResetIterator(&udi); + serverAssert(found); } dictResetIterator(&di); } @@ -1609,6 +1626,65 @@ static doneStatus defragStagePubsubKvstore(void *ctx, monotime endtime) { defragPubsubScanCallback, NULL, &defragfns); } +/* Defrag client-side per-user pubsub dict structures. + * This handles the outer dict tables, pubsubUserSubs structs, and inner dict + * tables. Outer dict keys are user* pointers (user objects are not moved by + * active defrag today; if that changes, all user-pointer indexes must be + * updated). Inner dict key objects (robj) are NOT touched here - they are + * handled by the server-side defrag callbacks above. + * + * Iteration uses server.clients_index (a rax keyed by big-endian client ID) + * so that clients deleted between calls are safely skipped, and no list + * mutation is required. 
*/ +static doneStatus defragStagePubsubClientSide(void *ctx, monotime endtime) { + defragPubsubClientCtx *dctx = ctx; + unsigned int iterations = 0; + + raxIterator ri; + raxStart(&ri, server.clients_index); + raxSeek(&ri, ">", (unsigned char *)&dctx->last_client_id_raw, + sizeof(dctx->last_client_id_raw)); + + while (raxNext(&ri)) { + client *c = ri.data; + + if (dictSize(c->pubsub_subscriptions) > 0) { + dict *newd = dictDefragTables(c->pubsub_subscriptions); + if (newd) c->pubsub_subscriptions = newd; + + dictIterator di; + dictEntry *de; + dictInitIterator(&di, c->pubsub_subscriptions); + while ((de = dictNext(&di)) != NULL) { + pubsubUserSubs *subs = dictGetVal(de); + pubsubUserSubs *newsubs = activeDefragAlloc(subs); + if (newsubs) { + dictSetVal(c->pubsub_subscriptions, de, newsubs); + subs = newsubs; + } + + dict *newinner; + if ((newinner = dictDefragTables(subs->channels))) + subs->channels = newinner; + if ((newinner = dictDefragTables(subs->patterns))) + subs->patterns = newinner; + if ((newinner = dictDefragTables(subs->shard_channels))) + subs->shard_channels = newinner; + } + dictResetIterator(&di); + } + + if (++iterations >= 16 && getMonotonicUs() >= endtime) { + memcpy(&dctx->last_client_id_raw, ri.key, sizeof(dctx->last_client_id_raw)); + raxStop(&ri); + return DEFRAG_NOT_DONE; + } + } + + raxStop(&ri); + return DEFRAG_DONE; +} + static doneStatus defragLuaScripts(void *ctx, monotime endtime) { UNUSED(endtime); UNUSED(ctx); @@ -1935,15 +2011,20 @@ static void beginDefragCycle(void) { /* Add stage for pubsub channels. */ defragPubSubCtx *defrag_pubsub_ctx = zmalloc(sizeof(defragPubSubCtx)); defrag_pubsub_ctx->kvstate = INIT_KVSTORE_STATE(server.pubsub_channels); - defrag_pubsub_ctx->getPubSubChannels = getClientPubSubChannels; + defrag_pubsub_ctx->shard = 0; addDefragStage(defragStagePubsubKvstore, zfree, defrag_pubsub_ctx); /* Add stage for pubsubshard channels. 
*/ defragPubSubCtx *defrag_pubsubshard_ctx = zmalloc(sizeof(defragPubSubCtx)); defrag_pubsubshard_ctx->kvstate = INIT_KVSTORE_STATE(server.pubsubshard_channels); - defrag_pubsubshard_ctx->getPubSubChannels = getClientPubSubShardChannels; + defrag_pubsubshard_ctx->shard = 1; addDefragStage(defragStagePubsubKvstore, zfree, defrag_pubsubshard_ctx); + /* Add stage for client-side pubsub per-user dict structures. */ + defragPubsubClientCtx *defrag_pubsub_client_ctx = zmalloc(sizeof(defragPubsubClientCtx)); + defrag_pubsub_client_ctx->last_client_id_raw = 0; + addDefragStage(defragStagePubsubClientSide, zfree, defrag_pubsub_client_ctx); + addDefragStage(defragLuaScripts, NULL, NULL); /* Add stages for modules. */ diff --git a/src/module.c b/src/module.c index 50a594987a..6221816b44 100644 --- a/src/module.c +++ b/src/module.c @@ -10809,8 +10809,8 @@ static int authenticateClientWithUser(RedisModuleCtx *ctx, user *user, RedisModu moduleNotifyUserChanged(ctx->client); - ctx->client->user = user; ctx->client->authenticated = 1; + clientSetUser(ctx->client, user); if (clientHasModuleAuthInProgress(ctx->client)) { ctx->client->flags |= CLIENT_MODULE_AUTH_HAS_RESULT; diff --git a/src/networking.c b/src/networking.c index 2f5384c3b9..c9494f435f 100644 --- a/src/networking.c +++ b/src/networking.c @@ -103,7 +103,7 @@ void linkClient(client *c) { static void clientSetDefaultAuth(client *c) { /* If the default user does not require authentication, the user is * directly authenticated. */ - c->user = DefaultUser; + clientSetUser(c, DefaultUser); c->authenticated = (c->user->flags & USER_FLAG_NOPASS) && !(c->user->flags & USER_FLAG_DISABLED); } @@ -193,6 +193,7 @@ client *createClient(connection *conn) { c->ctime = c->lastinteraction = server.unixtime; c->io_lastinteraction = 0; c->duration = 0; + c->user = DefaultUser; /* Set a safe default value: clientSetDefaultAuth reads c->user. 
*/ clientSetDefaultAuth(c); c->replstate = REPL_STATE_NONE; c->repl_start_cmd_stream_on_ack = 0; @@ -220,9 +221,10 @@ client *createClient(connection *conn) { initClientBlockingState(c); c->woff = 0; c->watched_keys = listCreate(); - c->pubsub_channels = dictCreate(&objectKeyPointerValueDictType); - c->pubsub_patterns = dictCreate(&objectKeyPointerValueDictType); - c->pubsubshard_channels = dictCreate(&objectKeyPointerValueDictType); + c->pubsub_subscriptions = dictCreate(&pubsubSubscriptionsDictType); + c->pubsub_channels_count = 0; + c->pubsub_patterns_count = 0; + c->pubsubshard_channels_count = 0; c->peerid = NULL; c->sockname = NULL; c->client_list_node = NULL; @@ -1614,8 +1616,8 @@ void clientAcceptHandler(connection *conn) { if (username != NULL) { user *u = ACLGetUserByName(username, sdslen(username)); if (u && !(u->flags & USER_FLAG_DISABLED)) { - c->user = u; c->authenticated = 1; + clientSetUser(c, u); moduleNotifyUserChanged(c); serverLog(LL_VERBOSE, "TLS: Auto-authenticated client as %s", server.hide_user_data_from_log ? 
"*redacted*" : u->name); @@ -2058,6 +2060,11 @@ void clearClientConnectionState(client *c) { pubsubUnsubscribeShardAllChannels(c, 0); pubsubUnsubscribeAllPatterns(c,0); unmarkClientAsPubSub(c); + dictRelease(c->pubsub_subscriptions); + c->pubsub_subscriptions = dictCreate(&pubsubSubscriptionsDictType); + c->pubsub_channels_count = 0; + c->pubsub_patterns_count = 0; + c->pubsubshard_channels_count = 0; if (c->name) { decrRefCount(c->name); @@ -2073,6 +2080,7 @@ void clearClientConnectionState(client *c) { } void deauthenticateAndCloseClient(client *c) { + disableTracking(c); c->user = DefaultUser; c->authenticated = 0; /* We will write replies to this client later, so we can't @@ -2242,9 +2250,7 @@ void freeClient(client *c) { pubsubUnsubscribeShardAllChannels(c, 0); pubsubUnsubscribeAllPatterns(c,0); unmarkClientAsPubSub(c); - dictRelease(c->pubsub_channels); - dictRelease(c->pubsub_patterns); - dictRelease(c->pubsubshard_channels); + dictRelease(c->pubsub_subscriptions); /* Free data structures. */ releaseAllBufReferences(c); /* Release all references to string objects in encoded buffers before freeing */ @@ -4110,9 +4116,9 @@ sds catClientInfoString(sds s, client *client) { " idle=%I", (long long)(server.unixtime - client->lastinteraction), " flags=%s", flags, " db=%i", client->db->id, - " sub=%i", (int) dictSize(client->pubsub_channels), - " psub=%i", (int) dictSize(client->pubsub_patterns), - " ssub=%i", (int) dictSize(client->pubsubshard_channels), + " sub=%i", (int) client->pubsub_channels_count, + " psub=%i", (int) client->pubsub_patterns_count, + " ssub=%i", (int) client->pubsubshard_channels_count, " multi=%i", (client->flags & CLIENT_MULTI) ? client->mstate.count : -1, " watch=%i", (int) listLength(client->watched_keys), " qbuf=%U", client->querybuf ? 
(unsigned long long) sdslen(client->querybuf) : 0, diff --git a/src/pubsub.c b/src/pubsub.c index b9198d2639..58008540d1 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -20,7 +20,6 @@ * for pubsub and pubsubshard feature. */ typedef struct pubsubtype { int shard; - dict *(*clientPubSubChannels)(client*); int (*subscriptionCount)(client*); kvstore **serverPubSubChannels; robj **subscribeMsg; @@ -28,25 +27,80 @@ typedef struct pubsubtype { robj **messageBulk; }pubsubtype; -/* - * Get client's global Pub/Sub channels subscription count. - */ -int clientSubscriptionsCount(client *c); +/* -------------------------------------------------------------------------- + * Per-user subscription dict helpers + * -------------------------------------------------------------------------- */ -/* - * Get client's shard level Pub/Sub channels subscription count. - */ -int clientShardSubscriptionsCount(client *c); +static void freePubsubUserSubs(dict *d, void *val) { + UNUSED(d); + pubsubUserSubs *subs = val; + dictRelease(subs->channels); + dictRelease(subs->patterns); + dictRelease(subs->shard_channels); + zfree(subs); +} -/* - * Get client's global Pub/Sub channels dict. - */ -dict* getClientPubSubChannels(client *c); +dictType pubsubSubscriptionsDictType = { + dictPtrHash, + NULL, + NULL, + NULL, + NULL, + freePubsubUserSubs, + NULL +}; -/* - * Get client's shard level Pub/Sub channels dict. - */ -dict* getClientPubSubShardChannels(client *c); +static dictType pubsubNoDestructorDictType = { + dictPtrHash, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL +}; + +/* Sentinel user for clients with c->user == NULL (e.g. CLIENT_MASTER, + * module temp clients with RM_Call "K" flag). This is a static object + * that is never registered in the ACL system, so no ACL operation will + * ever match or dereference it. It lets us keep a non-NULL dict key + * in pubsub_subscriptions without changing protocol behaviour. 
*/ +static user pubsubNoAuthUser; + +int pubsubUserIsNoAuth(user *u) { + return u == &pubsubNoAuthUser; +} + +static pubsubUserSubs *createPubsubUserSubs(void) { + pubsubUserSubs *subs = zmalloc(sizeof(*subs)); + subs->channels = dictCreate(&objectKeyPointerValueDictType); + subs->patterns = dictCreate(&objectKeyPointerValueDictType); + subs->shard_channels = dictCreate(&objectKeyPointerValueDictType); + return subs; +} + +static pubsubUserSubs *pubsubGetOrCreateUserSubs(client *c) { + user *key = c->user ? c->user : &pubsubNoAuthUser; + dictEntry *de = dictFind(c->pubsub_subscriptions, key); + if (de) return dictGetVal(de); + pubsubUserSubs *subs = createPubsubUserSubs(); + serverAssert(dictAdd(c->pubsub_subscriptions, key, subs) == DICT_OK); + return subs; +} + +int pubsubUserSubsIsEmpty(pubsubUserSubs *subs) { + return dictSize(subs->channels) == 0 + && dictSize(subs->patterns) == 0 + && dictSize(subs->shard_channels) == 0; +} + +static dict *pubsubUserSubsGetDict(pubsubUserSubs *subs, pubsubtype type) { + return type.shard ? subs->shard_channels : subs->channels; +} + +static size_t *pubsubClientCountPtr(client *c, pubsubtype type) { + return type.shard ? &c->pubsubshard_channels_count : &c->pubsub_channels_count; +} /* * Get list of channels client is subscribed to. 
@@ -60,7 +114,6 @@ void channelList(client *c, sds pat, kvstore *pubsub_channels); */ pubsubtype pubSubType = { .shard = 0, - .clientPubSubChannels = getClientPubSubChannels, .subscriptionCount = clientSubscriptionsCount, .serverPubSubChannels = &server.pubsub_channels, .subscribeMsg = &shared.subscribebulk, @@ -73,7 +126,6 @@ pubsubtype pubSubType = { */ pubsubtype pubSubShardType = { .shard = 1, - .clientPubSubChannels = getClientPubSubShardChannels, .subscriptionCount = clientShardSubscriptionsCount, .serverPubSubChannels = &server.pubsubshard_channels, .subscribeMsg = &shared.ssubscribebulk, @@ -204,26 +256,19 @@ int serverPubsubShardSubscriptionCount(void) { /* Return the number of channels + patterns a client is subscribed to. */ int clientSubscriptionsCount(client *c) { - return dictSize(c->pubsub_channels) + dictSize(c->pubsub_patterns); + return (int)(c->pubsub_channels_count + c->pubsub_patterns_count); } /* Return the number of shard level channels a client is subscribed to. */ int clientShardSubscriptionsCount(client *c) { - return dictSize(c->pubsubshard_channels); -} - -dict* getClientPubSubChannels(client *c) { - return c->pubsub_channels; -} - -dict* getClientPubSubShardChannels(client *c) { - return c->pubsubshard_channels; + return (int)c->pubsubshard_channels_count; } /* Return the number of pubsub + pubsub shard level channels * a client is subscribed to. */ int clientTotalPubSubSubscriptionCount(client *c) { - return clientSubscriptionsCount(c) + clientShardSubscriptionsCount(c); + return (int)(c->pubsub_channels_count + c->pubsub_patterns_count + + c->pubsubshard_channels_count); } void markClientAsPubSub(client *c) { @@ -240,6 +285,25 @@ void unmarkClientAsPubSub(client *c) { } } +/* Check if a client is subscribed to a channel/shard-channel under any user. 
*/ +static int pubsubClientIsSubscribedChannel(client *c, robj *channel, pubsubtype type) { + dictIterator di; + dictEntry *userEntry; + dictInitIterator(&di, c->pubsub_subscriptions); + while ((userEntry = dictNext(&di)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + serverAssert(subs != NULL); + dict *innerDict = pubsubUserSubsGetDict(subs, type); + serverAssert(innerDict != NULL); + if (dictFind(innerDict, channel)) { + dictResetIterator(&di); + return 1; + } + } + dictResetIterator(&di); + return 0; +} + /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or * 0 if the client was already subscribed to that channel. */ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { @@ -248,11 +312,14 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { int retval = 0; unsigned int slot = 0; - /* Add the channel to the client -> channels hash table */ - dictEntryLink bucket; - dictEntryLink link = dictFindLink(type.clientPubSubChannels(c),channel,&bucket); - if (link == NULL) { /* Not yet subscribed to this channel */ + /* Dedup: check if subscribed under any user */ + if (!pubsubClientIsSubscribedChannel(c, channel, type)) { retval = 1; + + /* Look up or create per-user entry for current user */ + pubsubUserSubs *subs = pubsubGetOrCreateUserSubs(c); + dict *innerDict = pubsubUserSubsGetDict(subs, type); + /* Add the client to the channel -> list of clients hash table */ if (server.cluster_enabled && type.shard) { slot = getKeySlot(channel->ptr); @@ -270,47 +337,82 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { } serverAssert(dictAdd(clients, c, NULL) != DICT_ERR); - dictSetKeyAtLink(type.clientPubSubChannels(c), channel, &bucket, 1); + serverAssert(dictAdd(innerDict, channel, NULL) != DICT_ERR); incrRefCount(channel); + + (*pubsubClientCountPtr(c, type))++; } /* Notify the client */ addReplyPubsubSubscribed(c,channel,type); return retval; } +/* Remove a channel from a 
known inner dict + server-side reverse mapping. + * Does NOT scan the outer dict or delete outer entries. + * Caller is responsible for outer entry lifecycle. */ +static void pubsubUnsubscribeKnownChannel(client *c, dict *innerDict, + robj *channel, int notify, pubsubtype type) { + dictEntry *de; + dict *clients; + int slot = 0; + + incrRefCount(channel); + serverAssert(dictDelete(innerDict, channel) == DICT_OK); + + if (server.cluster_enabled && type.shard) { + /* Compute the slot from the channel directly instead of using getKeySlot(), + * because the unsubscribe may be triggered by a different client, and + * getKeySlot() would return the cached slot of that client. */ + slot = keyHashSlot(channel->ptr, sdslen(channel->ptr)); + } + de = kvstoreDictFind(*type.serverPubSubChannels, slot, channel); + serverAssertWithInfo(c, NULL, de != NULL); + clients = dictGetVal(de); + serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK); + if (dictSize(clients) == 0) { + /* Free the dict and the associated hash entry if this was the + * last client, so that it is not possible to abuse Redis PUBSUB + * by creating millions of channels. */ + kvstoreDictDelete(*type.serverPubSubChannels, slot, channel); + } + + serverAssert(*pubsubClientCountPtr(c, type) > 0); + (*pubsubClientCountPtr(c, type))--; + + if (notify) { + addReplyPubsubUnsubscribed(c,channel,type); + } + decrRefCount(channel); +} + /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or * 0 if the client was not subscribed to the specified channel. */ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) { - dictEntry *de; - dict *clients; int retval = 0; - int slot = 0; - /* Remove the channel from the client -> channels hash table */ + /* Remove the channel from the client's subscription bookkeeping */ incrRefCount(channel); /* channel may be just a pointer to the same object we have in the hash tables. Protect it... 
*/ - if (dictDelete(type.clientPubSubChannels(c),channel) == DICT_OK) { - retval = 1; - /* Remove the client from the channel -> clients list hash table */ - if (server.cluster_enabled && type.shard) { - /* Compute the slot from the channel directly instead of using getKeySlot(), - * because the unsubscribe may be triggered by a different client, and - * getKeySlot() would return the cached slot of that client. */ - slot = keyHashSlot(channel->ptr, sdslen(channel->ptr)); - } - de = kvstoreDictFind(*type.serverPubSubChannels, slot, channel); - serverAssertWithInfo(c,NULL,de != NULL); - clients = dictGetVal(de); - serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK); - if (dictSize(clients) == 0) { - /* Free the dict and associated hash entry at all if this was - * the latest client, so that it will be possible to abuse - * Redis PUBSUB creating millions of channels. */ - kvstoreDictDelete(*type.serverPubSubChannels, slot, channel); + /* Scan per-user entries to find which one holds this channel */ + dictIterator outer; + dictEntry *userEntry; + dictInitSafeIterator(&outer, c->pubsub_subscriptions); + while ((userEntry = dictNext(&outer)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + dict *innerDict = pubsubUserSubsGetDict(subs, type); + if (dictFind(innerDict, channel)) { + pubsubUnsubscribeKnownChannel(c, innerDict, channel, notify, type); + if (pubsubUserSubsIsEmpty(subs)) { + dictDelete(c->pubsub_subscriptions, dictGetKey(userEntry)); + } + retval = 1; + break; } } + dictResetIterator(&outer); + /* Notify the client */ - if (notify) { + if (!retval && notify) { addReplyPubsubUnsubscribed(c,channel,type); } decrRefCount(channel); /* it is finally safe to release it */ @@ -335,8 +437,25 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) { dictInitIterator(&iter, clients); while ((entry = dictNext(&iter)) != NULL) { client *c = dictGetKey(entry); - int retval = dictDelete(c->pubsubshard_channels, channel); - 
serverAssertWithInfo(c,channel,retval == DICT_OK); + /* Find and remove from the per-user entry that holds it. */ + int found = 0; + dictIterator di; + dictEntry *userEntry; + dictInitSafeIterator(&di, c->pubsub_subscriptions); + while ((userEntry = dictNext(&di)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + if (dictDelete(subs->shard_channels, channel) == DICT_OK) { + found = 1; + serverAssert(c->pubsubshard_channels_count > 0); + c->pubsubshard_channels_count--; + if (pubsubUserSubsIsEmpty(subs)) { + dictDelete(c->pubsub_subscriptions, dictGetKey(userEntry)); + } + break; + } + } + dictResetIterator(&di); + serverAssertWithInfo(c, channel, found); addReplyPubsubUnsubscribed(c, channel, pubSubShardType); /* If the client has no other pubsub subscription, * move out of pubsub mode. */ @@ -350,14 +469,35 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) { kvstoreResetDictIterator(&kvs_di); } +/* Check if a client is subscribed to a pattern under any user. */ +static int pubsubClientIsSubscribedPattern(client *c, robj *pattern) { + dictIterator di; + dictEntry *userEntry; + dictInitIterator(&di, c->pubsub_subscriptions); + while ((userEntry = dictNext(&di)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + serverAssert(subs != NULL); + serverAssert(subs->patterns != NULL); + if (dictFind(subs->patterns, pattern)) { + dictResetIterator(&di); + return 1; + } + } + dictResetIterator(&di); + return 0; +} + /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. 
*/ int pubsubSubscribePattern(client *c, robj *pattern) { dictEntry *de; dict *clients; int retval = 0; - if (dictAdd(c->pubsub_patterns, pattern, NULL) == DICT_OK) { + if (!pubsubClientIsSubscribedPattern(c, pattern)) { retval = 1; + pubsubUserSubs *subs = pubsubGetOrCreateUserSubs(c); + + serverAssert(dictAdd(subs->patterns, pattern, NULL) != DICT_ERR); incrRefCount(pattern); /* Add the client to the pattern -> list of clients hash table */ de = dictFind(server.pubsub_patterns,pattern); @@ -369,55 +509,95 @@ int pubsubSubscribePattern(client *c, robj *pattern) { clients = dictGetVal(de); } serverAssert(dictAdd(clients, c, NULL) != DICT_ERR); + c->pubsub_patterns_count++; } /* Notify the client */ addReplyPubsubPatSubscribed(c,pattern); return retval; } -/* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or - * 0 if the client was not subscribed to the specified channel. */ -int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { +/* Helper: unsubscribe a pattern from a known inner dict + server-side mapping. + * Does NOT scan the outer dict or delete outer entries. */ +static void pubsubUnsubscribeKnownPattern(client *c, dict *innerDict, + robj *pattern, int notify) { dictEntry *de; dict *clients; + + incrRefCount(pattern); + serverAssert(dictDelete(innerDict, pattern) == DICT_OK); + + /* Remove the client from the pattern -> clients list hash table */ + de = dictFind(server.pubsub_patterns,pattern); + serverAssertWithInfo(c,NULL,de != NULL); + clients = dictGetVal(de); + serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK); + if (dictSize(clients) == 0) { + /* Free the dict and associated hash entry at all if this was + * the latest client. */ + dictDelete(server.pubsub_patterns,pattern); + } + + serverAssert(c->pubsub_patterns_count > 0); + c->pubsub_patterns_count--; + + if (notify) addReplyPubsubPatUnsubscribed(c,pattern); + decrRefCount(pattern); +} + +/* Unsubscribe a client from a pattern. 
Returns 1 if the operation succeeded, or + * 0 if the client was not subscribed to the specified pattern. */ +int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { int retval = 0; incrRefCount(pattern); /* Protect the object. May be the same we remove */ - if (dictDelete(c->pubsub_patterns, pattern) == DICT_OK) { - retval = 1; - /* Remove the client from the pattern -> clients list hash table */ - de = dictFind(server.pubsub_patterns,pattern); - serverAssertWithInfo(c,NULL,de != NULL); - clients = dictGetVal(de); - serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK); - if (dictSize(clients) == 0) { - /* Free the dict and associated hash entry at all if this was - * the latest client. */ - dictDelete(server.pubsub_patterns,pattern); + + dictIterator outer; + dictEntry *userEntry; + dictInitSafeIterator(&outer, c->pubsub_subscriptions); + while ((userEntry = dictNext(&outer)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + if (dictFind(subs->patterns, pattern)) { + pubsubUnsubscribeKnownPattern(c, subs->patterns, pattern, notify); + if (pubsubUserSubsIsEmpty(subs)) { + dictDelete(c->pubsub_subscriptions, dictGetKey(userEntry)); + } + retval = 1; + break; } } + dictResetIterator(&outer); /* Notify the client */ - if (notify) addReplyPubsubPatUnsubscribed(c,pattern); + if (!retval && notify) addReplyPubsubPatUnsubscribed(c,pattern); decrRefCount(pattern); return retval; } -/* Unsubscribe from all the channels. Return the number of channels the - * client was subscribed to. */ -int pubsubUnsubscribeAllChannelsInternal(client *c, int notify, pubsubtype type) { +/* Unsubscribe from all the channels of a given type. Return the number of + * channels the client was subscribed to. 
*/ +static int pubsubUnsubscribeAllChannelsInternal(client *c, int notify, pubsubtype type) { int count = 0; - if (dictSize(type.clientPubSubChannels(c)) > 0) { - dictIterator di; - dictEntry *de; - - dictInitSafeIterator(&di, type.clientPubSubChannels(c)); - while((de = dictNext(&di)) != NULL) { - robj *channel = dictGetKey(de); - - count += pubsubUnsubscribeChannel(c,channel,notify,type); + dictIterator outer; + dictEntry *userEntry; + dictInitSafeIterator(&outer, c->pubsub_subscriptions); + while ((userEntry = dictNext(&outer)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + dict *innerDict = pubsubUserSubsGetDict(subs, type); + if (dictSize(innerDict) > 0) { + dictIterator inner; + dictEntry *de; + dictInitSafeIterator(&inner, innerDict); + while ((de = dictNext(&inner)) != NULL) { + robj *channel = dictGetKey(de); + pubsubUnsubscribeKnownChannel(c, innerDict, channel, notify, type); + count++; + } + dictResetIterator(&inner); + } + if (pubsubUserSubsIsEmpty(subs)) { + dictDelete(c->pubsub_subscriptions, dictGetKey(userEntry)); } - dictResetIterator(&di); } + dictResetIterator(&outer); /* We were subscribed to nothing? Still reply to the client. 
*/ if (notify && count == 0) { addReplyPubsubUnsubscribed(c,NULL,type); @@ -446,17 +626,27 @@ int pubsubUnsubscribeShardAllChannels(client *c, int notify) { int pubsubUnsubscribeAllPatterns(client *c, int notify) { int count = 0; - if (dictSize(c->pubsub_patterns) > 0) { - dictIterator di; - dictEntry *de; - - dictInitSafeIterator(&di, c->pubsub_patterns); - while ((de = dictNext(&di)) != NULL) { - robj *pattern = dictGetKey(de); - count += pubsubUnsubscribePattern(c, pattern, notify); + dictIterator outer; + dictEntry *userEntry; + dictInitSafeIterator(&outer, c->pubsub_subscriptions); + while ((userEntry = dictNext(&outer)) != NULL) { + pubsubUserSubs *subs = dictGetVal(userEntry); + if (dictSize(subs->patterns) > 0) { + dictIterator inner; + dictEntry *de; + dictInitSafeIterator(&inner, subs->patterns); + while ((de = dictNext(&inner)) != NULL) { + robj *pattern = dictGetKey(de); + pubsubUnsubscribeKnownPattern(c, subs->patterns, pattern, notify); + count++; + } + dictResetIterator(&inner); + } + if (pubsubUserSubsIsEmpty(subs)) { + dictDelete(c->pubsub_subscriptions, dictGetKey(userEntry)); } - dictResetIterator(&di); } + dictResetIterator(&outer); /* We were subscribed to nothing? Still reply to the client. 
*/ if (notify && count == 0) addReplyPubsubPatUnsubscribed(c,NULL); @@ -755,15 +945,59 @@ void sunsubscribeCommand(client *c) { } size_t pubsubMemOverhead(client *c) { - /* PubSub patterns */ - size_t mem = dictMemUsage(c->pubsub_patterns); - /* Global PubSub channels */ - mem += dictMemUsage(c->pubsub_channels); - /* Sharded PubSub channels */ - mem += dictMemUsage(c->pubsubshard_channels); + size_t mem = dictMemUsage(c->pubsub_subscriptions); + dictIterator di; + dictEntry *de; + dictInitIterator(&di, c->pubsub_subscriptions); + while ((de = dictNext(&di)) != NULL) { + pubsubUserSubs *subs = dictGetVal(de); + mem += zmalloc_size(subs); + mem += dictMemUsage(subs->channels); + mem += dictMemUsage(subs->patterns); + mem += dictMemUsage(subs->shard_channels); + } + dictResetIterator(&di); return mem; } +/* Re-key c->pubsub_subscriptions from old user pointers to new user pointers + * after ACL LOAD. Called only after phase 1 validation has confirmed the client + * survives (all provenance users still exist with acceptable permissions). + * + * Builds a new dict with new user* keys and transfers ownership of the + * pubsubUserSubs values. The old dict is released without value destructors. */ +void pubsubRekeySubscriptionsForACLLoad(client *c) { + dict *new_dict = dictCreate(&pubsubSubscriptionsDictType); + + /* Walk the old dict and re-insert each entry under the corresponding + * new user pointer. old_user_ptr is still alive here (old_users rax + * is freed after the full client walk), so ->name is safe to read. */ + dictIterator di; + dictEntry *entry; + dictInitIterator(&di, c->pubsub_subscriptions); + while ((entry = dictNext(&di)) != NULL) { + user *old_user_ptr = dictGetKey(entry); + pubsubUserSubs *subs = dictGetVal(entry); + + if (pubsubUserIsNoAuth(old_user_ptr)) { + /* Sentinel key is a stable static pointer — carry it as-is. 
*/ + serverAssert(dictAdd(new_dict, old_user_ptr, subs) == DICT_OK); + } else { + user *new_user = ACLGetUserByName(old_user_ptr->name, sdslen(old_user_ptr->name)); + serverAssert(new_user != NULL); + serverAssert(dictAdd(new_dict, new_user, subs) == DICT_OK); + } + } + dictResetIterator(&di); + + /* Swap the old dict out without freeing the values — new_dict now owns + * them. We temporarily switch the old dict's type to one with no + * destructors so dictRelease only frees the table structure. */ + c->pubsub_subscriptions->type = &pubsubNoDestructorDictType; + dictRelease(c->pubsub_subscriptions); + c->pubsub_subscriptions = new_dict; +} + int pubsubTotalSubscriptions(void) { return dictSize(server.pubsub_patterns) + kvstoreSize(server.pubsub_channels) + diff --git a/src/server.h b/src/server.h index 9318eec686..ed21375cef 100644 --- a/src/server.h +++ b/src/server.h @@ -1380,6 +1380,13 @@ typedef struct { robj *acl_string; /* cached string represent of ACLs */ } user; +/* Per-user Pub/Sub subscription state. */ +typedef struct pubsubUserSubs { + dict *channels; /* channels a client is interested in (SUBSCRIBE) */ + dict *patterns; /* patterns a client is interested in (PSUBSCRIBE) */ + dict *shard_channels; /* shard level channels a client is interested in (SSUBSCRIBE) */ +} pubsubUserSubs; + /* With multiplexing we need to take per-client state. * Clients are taken in a linked list. */ @@ -1574,9 +1581,10 @@ typedef struct client { blockingState bstate; /* blocking state */ long long woff; /* Last write global replication offset. 
*/ list *watched_keys; /* Keys WATCHED for MULTI/EXEC CAS */ - dict *pubsub_channels; /* channels a client is interested in (SUBSCRIBE) */ - dict *pubsub_patterns; /* patterns a client is interested in (PSUBSCRIBE) */ - dict *pubsubshard_channels; /* shard level channels a client is interested in (SSUBSCRIBE) */ + dict *pubsub_subscriptions; /* user* -> pubsubUserSubs* */ + size_t pubsub_channels_count; + size_t pubsub_patterns_count; + size_t pubsubshard_channels_count; sds peerid; /* Cached peer ID. */ sds sockname; /* Cached connection target address. */ listNode *client_list_node; /* list node in client list */ @@ -3062,6 +3070,7 @@ extern struct sharedObjectsStruct shared; extern dictType objectKeyPointerValueDictType; extern dictType objectKeyNoValueDictType; extern dictType objectKeyHeapPointerValueDictType; +extern dictType pubsubSubscriptionsDictType; extern dictType setDictType; extern dictType BenchmarkDictType; extern dictType zsetDictType; @@ -3361,6 +3370,8 @@ uint64_t trackingGetTotalItems(void); uint64_t trackingGetTotalKeys(void); uint64_t trackingGetTotalPrefixes(void); void trackingBroadcastInvalidationMessages(void); +void trackingBroadcastPostUserSwitch(client *c, user *old_user); +void clientSetUser(client *c, user *new_user); int checkPrefixCollisionsOrReply(client *c, robj **prefix, size_t numprefix); /* List data type */ @@ -3885,8 +3896,12 @@ int serverPubsubShardSubscriptionCount(void); size_t pubsubMemOverhead(client *c); void unmarkClientAsPubSub(client *c); int pubsubTotalSubscriptions(void); -dict *getClientPubSubChannels(client *c); -dict *getClientPubSubShardChannels(client *c); +int clientSubscriptionsCount(client *c); +int clientShardSubscriptionsCount(client *c); +int clientTotalPubSubSubscriptionCount(client *c); +void pubsubRekeySubscriptionsForACLLoad(client *c); +int pubsubUserIsNoAuth(user *u); +int pubsubUserSubsIsEmpty(pubsubUserSubs *subs); /* Keyspace events notification */ void notifyKeyspaceEvent(int type, const char 
*event, robj *key, int dbid); diff --git a/src/tracking.c b/src/tracking.c index c235d5812a..615e1a8bf5 100644 --- a/src/tracking.c +++ b/src/tracking.c @@ -9,6 +9,7 @@ */ #include "server.h" +#include "vector.h" /* The tracking table is constituted by a radix tree of keys, each pointing * to a radix tree of client IDs, used to track the clients that may have @@ -38,12 +39,15 @@ typedef struct bcastState { prefix. */ } bcastState; -/* Remove the tracking state from the client 'c'. Note that there is not much - * to do for us here, if not to decrement the counter of the clients in - * tracking mode, because we just store the ID of the client in the tracking - * table, so we'll remove the ID reference in a lazy way. Otherwise when a - * client with many entries in the table is removed, it would cost a lot of - * time to do the cleanup. */ +/* Remove the tracking state from the client 'c'. + * + * For BCAST mode, the client is immediately removed from its per-user + * vector in every prefix it subscribes to, and empty user/prefix entries + * are freed. + * + * For normal (non-BCAST) tracking, the client's ID references in the + * tracking table are removed lazily to avoid expensive cleanup when a + * client with many cached keys disconnects. */ void disableTracking(client *c) { /* If this client is in broadcasting mode, we need to unsubscribe it * from all the prefixes it is registered to. */ @@ -56,7 +60,23 @@ void disableTracking(client *c) { int found = raxFind(PrefixTable,ri.key,ri.key_len,&result); serverAssert(found); bcastState *bs = result; - raxRemove(bs->clients,(unsigned char*)&c,sizeof(c),NULL); + + /* Find the user vector and swap-remove this client from it. 
*/ + vec *user_clients; + found = raxFind(bs->clients, + (unsigned char*)&c->user, sizeof(c->user), + (void**)&user_clients); + serverAssert(found); + ssize_t idx = vecIndexOf(user_clients, c); + serverAssert(idx >= 0); + vecSwapRemoveAt(user_clients, idx); + if (vecSize(user_clients) == 0) { + vecRelease(user_clients); + zfree(user_clients); + raxRemove(bs->clients, + (unsigned char*)&c->user, sizeof(c->user), NULL); + } + /* Was it the last client? Remove the prefix from the * table. */ if (raxSize(bs->clients) == 0) { @@ -134,7 +154,7 @@ int checkPrefixCollisionsOrReply(client *c, robj **prefixes, size_t numprefix) { /* Set the client 'c' to track the prefix 'prefix'. If the client 'c' is * already registered for the specified prefix, no operation is performed. */ -void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) { +static void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) { void *result; bcastState *bs; /* If this is the first client subscribing to such prefix, create @@ -147,7 +167,22 @@ void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) { } else { bs = result; } - if (raxTryInsert(bs->clients,(unsigned char*)&c,sizeof(c),NULL,NULL)) { + + /* Find or create the per-user client vector. */ + vec *user_clients; + if (!raxFind(bs->clients, + (unsigned char*)&c->user, sizeof(c->user), + (void**)&user_clients)) + { + user_clients = zmalloc(sizeof(vec)); + vecInit(user_clients, NULL, 0); + raxInsert(bs->clients, + (unsigned char*)&c->user, sizeof(c->user), + user_clients, NULL); + } + + if (vecIndexOf(user_clients, c) < 0) { + vecPush(user_clients, c); if (c->client_tracking_prefixes == NULL) c->client_tracking_prefixes = raxNew(); raxInsert(c->client_tracking_prefixes, @@ -552,28 +587,30 @@ void trackingLimitUsedSlots(void) { } /* Generate Redis protocol for an array containing all the key names - * in the 'keys' radix tree. 
If the client is not NULL, the list will not - * include keys that were modified the last time by this client, in order - * to implement the NOLOOP option. + * in the 'keys' radix tree, filtered by ACL permissions of user 'u' and + * optionally by NOLOOP (skipping keys last modified by 'noloop_client'). + * + * If 'u' is non-NULL, keys the user is not permitted to observe are excluded. + * If 'c' is non-NULL, keys whose last modifier (ri.data) matches + * that client are excluded. * * If the resulting array would be empty, NULL is returned instead. */ -sds trackingBuildBroadcastReply(client *c, rax *keys) { +sds trackingBuildBroadcastReply(user *u, client *c, rax *keys) { + debugServerAssert(!c || c->flags & CLIENT_TRACKING_NOLOOP); raxIterator ri; - uint64_t count; + uint64_t count = 0; - if (c == NULL) { - count = raxSize(keys); - } else { - count = 0; - raxStart(&ri,keys); - raxSeek(&ri,"^",NULL,0); - while(raxNext(&ri)) { - if (ri.data != c) count++; - } - raxStop(&ri); - - if (count == 0) return NULL; + raxStart(&ri, keys); + raxSeek(&ri, "^", NULL, 0); + while(raxNext(&ri)) { + if (c && ri.data == c) continue; + if (u && ACLUserCheckKeyPerm(u, (const char*)ri.key, ri.key_len, + ACL_READ_PERMISSION) != ACL_OK) continue; + count++; } + raxStop(&ri); + + if (count == 0) return NULL; /* Create the array reply with the list of keys once, then send * it to all the clients subscribed to this prefix. 
*/ @@ -588,6 +625,8 @@ sds trackingBuildBroadcastReply(client *c, rax *keys) { raxSeek(&ri,"^",NULL,0); while(raxNext(&ri)) { if (c && ri.data == c) continue; + if (u && ACLUserCheckKeyPerm(u, (const char*)ri.key, ri.key_len, + ACL_READ_PERMISSION) != ACL_OK) continue; len = ll2string(buf,sizeof(buf),ri.key_len); proto = sdscatlen(proto,"$",1); proto = sdscatlen(proto,buf,len); @@ -599,11 +638,127 @@ sds trackingBuildBroadcastReply(client *c, rax *keys) { return proto; } +/* Send pending BCAST invalidation messages for a single prefix's + * bcastState, then reset bs->keys. Iterates user buckets, builds + * one proto per user, and sends to each client in the bucket. */ +static void trackingBcastInvalidationsForPrefix(bcastState *bs) { + if (raxSize(bs->keys) == 0) return; + + raxIterator ri; + raxStart(&ri, bs->clients); + raxSeek(&ri, "^", NULL, 0); + while(raxNext(&ri)) { + user *u; + memcpy(&u, ri.key, sizeof(u)); + vec *user_clients = ri.data; + + sds proto = trackingBuildBroadcastReply(u, NULL, bs->keys); + + for (size_t j = 0; j < vecSize(user_clients); j++) { + client *c = vecGet(user_clients, j); + + if (c->flags & CLIENT_TRACKING_NOLOOP) { + sds adhoc = trackingBuildBroadcastReply(u, c, bs->keys); + if (!adhoc) continue; + sendTrackingMessage(c, adhoc, + sdslen(adhoc), 1); + sdsfree(adhoc); + continue; + } + if (!proto) continue; + + sendTrackingMessage(c, proto, sdslen(proto), 1); + } + + sdsfree(proto); + } + raxStop(&ri); + + raxFree(bs->keys); + bs->keys = raxNew(); +} + +/* Send pending BCAST invalidation messages for every prefix in + * 'prefixes' (a rax of prefix -> NULL, i.e. client_tracking_prefixes). + * This triggers the full broadcast cycle for each matching prefix. 
*/ +static void trackingBcastSendInvalidationsForPrefixes(rax *prefixes) { + raxIterator ri; + raxStart(&ri, prefixes); + raxSeek(&ri, "^", NULL, 0); + while(raxNext(&ri)) { + void *result; + int found = raxFind(PrefixTable, ri.key, ri.key_len, &result); + serverAssert(found); + trackingBcastInvalidationsForPrefix(result); + } + raxStop(&ri); +} + +/* Move client 'c' from its old user bucket (keyed by 'old_user') to + * the bucket for c->user in every bcastState the client subscribes to. + * Must be called AFTER c->user is updated. */ +static void trackingBcastMoveClient(client *c, user *old_user) { + user *new_user = c->user; + raxIterator ri; + raxStart(&ri, c->client_tracking_prefixes); + raxSeek(&ri, "^", NULL, 0); + while(raxNext(&ri)) { + void *result; + int found = raxFind(PrefixTable, ri.key, ri.key_len, &result); + serverAssert(found); + bcastState *bs = result; + + /* Swap-remove from old user vector. */ + vec *from_clients; + found = raxFind(bs->clients, + (unsigned char*)&old_user, sizeof(old_user), + (void**)&from_clients); + serverAssert(found); + ssize_t idx = vecIndexOf(from_clients, c); + serverAssert(idx >= 0); + vecSwapRemoveAt(from_clients, idx); + if (vecSize(from_clients) == 0) { + vecRelease(from_clients); + zfree(from_clients); + raxRemove(bs->clients, + (unsigned char*)&old_user, sizeof(old_user), NULL); + } + + /* Insert into new user vector. 
*/ + vec *to_clients; + if (!raxFind(bs->clients, + (unsigned char*)&new_user, sizeof(new_user), + (void**)&to_clients)) + { + to_clients = zmalloc(sizeof(vec)); + vecInit(to_clients, NULL, 0); + raxInsert(bs->clients, + (unsigned char*)&new_user, sizeof(new_user), + to_clients, NULL); + } + vecPush(to_clients, c); + } + raxStop(&ri); +} + +/* Handle a BCAST tracking client after a user change: flush all pending + * invalidation messages for its prefixes (so every subscriber receives + * them under the previous ACL identity), then move the client from the + * 'old_user' bucket to the bucket for c->user. + * Must be called AFTER c->user is updated. */ +void trackingBroadcastPostUserSwitch(client *c, user *old_user) { + if (!(c->flags & CLIENT_TRACKING_BCAST)) return; + if (c->user == old_user) return; + + trackingBcastSendInvalidationsForPrefixes(c->client_tracking_prefixes); + trackingBcastMoveClient(c, old_user); +} + /* This function will run the prefixes of clients in BCAST mode and * keys that were modified about each prefix, and will send the * notifications to each client in each prefix. */ void trackingBroadcastInvalidationMessages(void) { - raxIterator ri, ri2; + raxIterator ri; /* Return ASAP if there is nothing to do here. */ if (TrackingTable == NULL || !server.tracking_clients) return; @@ -611,41 +766,8 @@ void trackingBroadcastInvalidationMessages(void) { raxStart(&ri,PrefixTable); raxSeek(&ri,"^",NULL,0); - /* For each prefix... */ while(raxNext(&ri)) { - bcastState *bs = ri.data; - - if (raxSize(bs->keys)) { - /* Generate the common protocol for all the clients that are - * not using the NOLOOP option. */ - sds proto = trackingBuildBroadcastReply(NULL,bs->keys); - - /* Send this array of keys to every client in the list. */ - raxStart(&ri2,bs->clients); - raxSeek(&ri2,"^",NULL,0); - while(raxNext(&ri2)) { - client *c; - memcpy(&c,ri2.key,sizeof(c)); - if (c->flags & CLIENT_TRACKING_NOLOOP) { - /* This client may have certain keys excluded. 
*/ - sds adhoc = trackingBuildBroadcastReply(c,bs->keys); - if (adhoc) { - sendTrackingMessage(c,adhoc,sdslen(adhoc),1); - sdsfree(adhoc); - } - } else { - sendTrackingMessage(c,proto,sdslen(proto),1); - } - } - raxStop(&ri2); - - /* Clean up: we can remove everything from this state, because we - * want to only track the new keys that will be accumulated starting - * from now. */ - sdsfree(proto); - } - raxFree(bs->keys); - bs->keys = raxNew(); + trackingBcastInvalidationsForPrefix(ri.data); } raxStop(&ri); } diff --git a/src/vector.c b/src/vector.c index fc0ba13e16..70f080dbd1 100644 --- a/src/vector.c +++ b/src/vector.c @@ -100,6 +100,21 @@ void vecPush(vec *v, void *value) { v->data[v->size++] = value; } +/* Return the index of the first occurrence of 'elem', or -1 if not found. */ +ssize_t vecIndexOf(const vec *v, void *elem) { + for (size_t i = 0; i < v->size; i++) { + if (v->data[i] == elem) return (ssize_t)i; + } + return -1; +} + +/* Remove the element at 'index' by swapping with the last element. + * Does not invoke the free callback. Requires index < vecSize(v). 
*/ +void vecSwapRemoveAt(vec *v, size_t index) { + assert(index < v->size); + v->data[index] = v->data[--v->size]; +} + #ifdef REDIS_TEST #include @@ -221,6 +236,49 @@ int vectorTest(int argc, char **argv, int flags) vecRelease(&v); test_cond("vecRelease() free method is a no-op on empty vector", vecTestFreeCalls == 0); + /* vecIndexOf tests */ + vecInit(&v, NULL, 0); + test_cond("vecIndexOf() returns -1 on empty vector", + vecIndexOf(&v, &one) == -1); + vecPush(&v, &one); + vecPush(&v, &two); + vecPush(&v, &three); + test_cond("vecIndexOf() finds first element", + vecIndexOf(&v, &one) == 0); + test_cond("vecIndexOf() finds middle element", + vecIndexOf(&v, &two) == 1); + test_cond("vecIndexOf() finds last element", + vecIndexOf(&v, &three) == 2); + test_cond("vecIndexOf() returns -1 for missing element", + vecIndexOf(&v, &four) == -1); + vecRelease(&v); + + /* vecSwapRemoveAt tests */ + vecInit(&v, NULL, 0); + vecPush(&v, &one); + vecPush(&v, &two); + vecPush(&v, &three); + vecSwapRemoveAt(&v, 1); + test_cond("vecSwapRemoveAt() removes middle element and swaps with last", + vecSize(&v) == 2 && + vecGet(&v, 0) == &one && vecGet(&v, 1) == &three); + vecSwapRemoveAt(&v, 1); + test_cond("vecSwapRemoveAt() removes last element", + vecSize(&v) == 1 && vecGet(&v, 0) == &one); + vecSwapRemoveAt(&v, 0); + test_cond("vecSwapRemoveAt() removes sole element", + vecSize(&v) == 0); + vecRelease(&v); + + vecInit(&v, NULL, 0); + vecPush(&v, &one); + vecPush(&v, &two); + vecPush(&v, &three); + vecSwapRemoveAt(&v, 0); + test_cond("vecSwapRemoveAt() removes first element and swaps with last", + vecSize(&v) == 2 && + vecGet(&v, 0) == &three && vecGet(&v, 1) == &two); + vecRelease(&v); return 0; } diff --git a/src/vector.h b/src/vector.h index c89955c987..c89465567d 100644 --- a/src/vector.h +++ b/src/vector.h @@ -2,6 +2,7 @@ #define REDIS_VECTOR_H #include +#include /* * Simple append-only vector (dynamic array) of void * elements. 
@@ -96,6 +97,13 @@ void vecReserve(vec *v, size_t mincap); /* Append one element, growing storage as needed. */ void vecPush(vec *v, void *value); +/* Return the index of the first occurrence of 'elem', or -1 if not found. */ +ssize_t vecIndexOf(const vec *v, void *elem); + +/* Remove the element at 'index' by swapping with the last element. + * Does not invoke the free callback. Requires index < vecSize(v). */ +void vecSwapRemoveAt(vec *v, size_t index); + #ifdef REDIS_TEST int vectorTest(int argc, char **argv, int flags); #endif diff --git a/tests/cluster/tests/25-pubsubshard-slot-migration.tcl b/tests/cluster/tests/25-pubsubshard-slot-migration.tcl index fd774a8d7b..c7d196a36e 100644 --- a/tests/cluster/tests/25-pubsubshard-slot-migration.tcl +++ b/tests/cluster/tests/25-pubsubshard-slot-migration.tcl @@ -189,6 +189,44 @@ test "Delete a slot, verify sunsubscribe message" { $subscribeclient close } +test "Migrate a slot with multi-user shard subscriptions, verify sunsubscribe is delivered correctly" { + set channelname ch5 + set slot [$cluster cluster keyslot $channelname] + array set nodefrom [$cluster masternode_for_slot $slot] + array set nodeto [$cluster masternode_notfor_slot $slot] + + $nodefrom(link) ACL SETUSER slotuser on nopass ~* &* +@all + + set subscribeclient [redis_deferring_client_by_addr $nodefrom(host) $nodefrom(port)] + $subscribeclient deferred 1 + + $subscribeclient hello 3 AUTH slotuser slotuser + $subscribeclient read + + $subscribeclient ssubscribe $channelname + $subscribeclient read + + $subscribeclient auth default "" + $subscribeclient read + + $nodefrom(link) spublish $channelname pre-migrate + assert_equal "smessage $channelname pre-migrate" [$subscribeclient read] + + assert_equal {OK} [$nodefrom(link) cluster setslot $slot migrating $nodeto(id)] + assert_equal {OK} [$nodeto(link) cluster setslot $slot importing $nodefrom(id)] + assert_equal {OK} [$nodefrom(link) cluster setslot $slot node $nodeto(id)] + + set msg [$subscribeclient 
read] + assert {"sunsubscribe" eq [lindex $msg 0]} + assert {$channelname eq [lindex $msg 1]} + assert {"0" eq [lindex $msg 2]} + + assert_equal {OK} [$nodeto(link) cluster setslot $slot node $nodeto(id)] + + $subscribeclient close + $nodefrom(link) ACL DELUSER slotuser +} + test "Reset cluster, verify sunsubscribe message" { set channelname ch4 set slot [$cluster cluster keyslot $channelname] diff --git a/tests/unit/acl.tcl b/tests/unit/acl.tcl index 77bb37095f..123d59653d 100644 --- a/tests/unit/acl.tcl +++ b/tests/unit/acl.tcl @@ -343,6 +343,268 @@ start_server {tags {"acl external:skip"}} { $rd close } {0} + # ─── Provenance: subscription revocation across re-auth ─── + + test {Provenance: channel subscription is killed when originating user's permissions are revoked} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SUBSCRIBE secret + assert_match {subscribe secret 1} [$rd read] + # Re-auth as default — subscription stays under provuser + $rd AUTH default "" + $rd read + $rd CLIENT SETNAME prov-channel + $rd read + # Revoke provuser's channel access + r ACL SETUSER provuser resetchannels + # Client must be killed — provenance entry is under provuser + catch {$rd read} e + assert_match {*I/O error*} $e + assert_no_match {*prov-channel*} [r CLIENT LIST] + $rd close + r ACL DELUSER provuser + } + + test {Provenance: pattern subscription is killed when originating user's permissions are revoked} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd PSUBSCRIBE secret:* + assert_match {psubscribe secret:* 1} [$rd read] + $rd AUTH default "" + $rd read + $rd CLIENT SETNAME prov-pattern + $rd read + r ACL SETUSER provuser resetchannels + catch {$rd read} e + assert_match {*I/O error*} $e + assert_no_match {*prov-pattern*} [r CLIENT LIST] + $rd close + r ACL DELUSER provuser + } + + test {Provenance: 
shard channel subscription is killed when originating user's permissions are revoked} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SSUBSCRIBE secret + assert_match {ssubscribe secret 1} [$rd read] + $rd AUTH default "" + $rd read + $rd CLIENT SETNAME prov-shard + $rd read + r ACL SETUSER provuser resetchannels + catch {$rd read} e + assert_match {*I/O error*} $e + assert_no_match {*prov-shard*} [r CLIENT LIST] + $rd close + r ACL DELUSER provuser + } + + # ─── Provenance: ACL DELUSER on originating user ─── + + test {Provenance: ACL DELUSER kills client that holds subscriptions from deleted user} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SUBSCRIBE chan1 + $rd read + # Re-auth as default, subscription remains under provuser + $rd AUTH default "" + $rd read + $rd CLIENT SETNAME prov-deluser + $rd read + r ACL DELUSER provuser + catch {$rd read} e + assert_match {*I/O error*} $e + assert_no_match {*prov-deluser*} [r CLIENT LIST] + $rd close + } {0} + + # ─── Provenance: duplicate subscribe after re-auth (first user wins) ─── + + test {Provenance: duplicate subscribe after re-auth attributes to first user} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SUBSCRIBE chan1 + assert_match {subscribe chan1 1} [$rd read] + # Re-auth and subscribe to the same channel — should be a no-op + $rd AUTH default "" + $rd read + $rd SUBSCRIBE chan1 + assert_match {subscribe chan1 1} [$rd read] + $rd CLIENT SETNAME prov-dup + $rd read + # Revoke provuser (the originating user) — must kill + r ACL SETUSER provuser resetchannels + catch {$rd read} e + assert_match {*I/O error*} $e + assert_no_match {*prov-dup*} [r CLIENT LIST] + $rd close + r ACL DELUSER provuser + } + + # ─── Provenance: many user switches on one 
connection ─── + + test {Provenance: many user switches with subscriptions, revoking one kills client} { + r ACL SETUSER user1 on nopass ~* &* +@all + r ACL SETUSER user2 on nopass ~* &* +@all + r ACL SETUSER user3 on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH user1 user1 + $rd read + $rd SUBSCRIBE ch1 + $rd read + $rd AUTH user2 user2 + $rd read + $rd SUBSCRIBE ch2 + $rd read + $rd AUTH user3 user3 + $rd read + $rd SUBSCRIBE ch3 + $rd read + $rd CLIENT SETNAME multi-user + $rd read + # Verify all subscriptions deliver + r PUBLISH ch1 msg1 + assert_match {*msg1*} [$rd read] + r PUBLISH ch2 msg2 + assert_match {*msg2*} [$rd read] + r PUBLISH ch3 msg3 + assert_match {*msg3*} [$rd read] + # Revoke user2 — client must be killed + r ACL SETUSER user2 resetchannels + catch {$rd read} e + assert_match {*I/O error*} $e + assert_no_match {*multi-user*} [r CLIENT LIST] + $rd close + r ACL DELUSER user1 user2 user3 + } + + # ─── Lifecycle: RESET after multi-user subscribe ─── + + test {Provenance: RESET clears all per-user subscription state} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SUBSCRIBE ch1 + $rd read + $rd AUTH default "" + $rd read + $rd SUBSCRIBE ch2 + $rd read + # RESET should clear everything + $rd RESET + $rd read + # Client should be out of pubsub mode — normal commands should work + $rd SET testkey testval + assert_match {OK} [$rd read] + $rd DEL testkey + $rd read + # PUBSUB NUMSUB should show zero for both channels + assert_equal {ch1 0 ch2 0} [r PUBSUB NUMSUB ch1 ch2] + $rd close + r ACL DELUSER provuser + } + + # ─── Lifecycle: unsubscribe-all after multi-user subscribe ─── + + test {Provenance: UNSUBSCRIBE with no args clears all per-user channel entries} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SUBSCRIBE ch1 ch2 + $rd read ; # subscribe 
ch1 + $rd read ; # subscribe ch2 + $rd AUTH default "" + $rd read + $rd SUBSCRIBE ch3 + $rd read + # Unsubscribe all channels (no args) + $rd UNSUBSCRIBE + $rd read ; # unsubscribe ch1 + $rd read ; # unsubscribe ch2 + $rd read ; # unsubscribe ch3 + assert_equal {ch1 0 ch2 0 ch3 0} [r PUBSUB NUMSUB ch1 ch2 ch3] + $rd close + r ACL DELUSER provuser + } + + test {Provenance: PUNSUBSCRIBE with no args clears all per-user pattern entries} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd PSUBSCRIBE foo:* + $rd read + $rd AUTH default "" + $rd read + $rd PSUBSCRIBE bar:* + $rd read + $rd PUNSUBSCRIBE + $rd read ; # punsubscribe foo:* + $rd read ; # punsubscribe bar:* + assert_equal {0} [r PUBSUB NUMPAT] + $rd close + r ACL DELUSER provuser + } + + # ─── PUBSUB NUMSUB/NUMPAT correctness after provenance operations ─── + + test {Provenance: PUBSUB NUMSUB stays correct through subscribe, re-auth, and revocation} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd SUBSCRIBE ch1 + $rd read + assert_equal {ch1 1} [r PUBSUB NUMSUB ch1] + $rd AUTH default "" + $rd read + $rd SUBSCRIBE ch2 + $rd read + assert_equal {ch1 1 ch2 1} [r PUBSUB NUMSUB ch1 ch2] + # Revoke provuser — client killed, all subscriptions gone + r ACL SETUSER provuser resetchannels + catch {$rd read} e + assert_match {*I/O error*} $e + assert_equal {ch1 0 ch2 0} [r PUBSUB NUMSUB ch1 ch2] + $rd close + r ACL DELUSER provuser + } + + test {Provenance: PUBSUB NUMPAT stays correct through subscribe, re-auth, and revocation} { + r ACL SETUSER provuser on nopass ~* &* +@all + set rd [redis_deferring_client] + $rd HELLO 3 AUTH provuser provuser + $rd read + $rd PSUBSCRIBE foo:* + $rd read + assert_equal {1} [r PUBSUB NUMPAT] + $rd AUTH default "" + $rd read + $rd PSUBSCRIBE bar:* + $rd read + assert_equal {2} [r PUBSUB NUMPAT] + r ACL SETUSER 
provuser resetchannels + catch {$rd read} e + assert_match {*I/O error*} $e + assert_equal {0} [r PUBSUB NUMPAT] + $rd close + r ACL DELUSER provuser + } + + # ─── End of provenance tests ─── + test {blocked command gets rejected when reprocessed after permission change} { r auth default "" r config resetstat @@ -1125,6 +1387,159 @@ start_server [list overrides [list "dir" $server_path "acl-pubsub-default" "allc $rd2 close } + test {ACL LOAD re-keys surviving client subscriptions to new user pointers} { + reconnect + + set rd1 [redis_deferring_client] + $rd1 AUTH alice alice + $rd1 read + $rd1 CLIENT SETNAME rekey-survivor + $rd1 read + $rd1 SUBSCRIBE test1 + $rd1 read + + # alice is unchanged in the ACL file — should survive and + # subscriptions should still work after re-keying to the new + # user pointer. + r ACL LOAD + r PUBLISH test1 rekey-msg + + assert_match {*rekey-msg*} [$rd1 read] + assert_match {*rekey-survivor*} [r CLIENT LIST] + $rd1 close + } + + test {ACL LOAD re-keyed client can create new subscriptions} { + reconnect + + set rd1 [redis_deferring_client] + $rd1 HELLO 3 AUTH alice alice + $rd1 read + $rd1 SUBSCRIBE test1 + $rd1 read + + r ACL LOAD + + # After re-keying, the client should be able to subscribe to new + # channels under the same (now re-keyed) user pointer. + $rd1 SUBSCRIBE test2 + $rd1 read + + # Both the old (re-keyed) and new subscriptions should deliver. + r PUBLISH test1 old-channel + assert_match {*old-channel*} [$rd1 read] + r PUBLISH test2 new-channel + assert_match {*new-channel*} [$rd1 read] + + # Second ACL LOAD: forces the walk to dereference every outer dict + # key (old_user_ptr->name). If the first re-key left a dangling + # pointer, this will crash rather than pass silently. 
+ r ACL LOAD + r PUBLISH test1 after-second-load + assert_match {*after-second-load*} [$rd1 read] + $rd1 close + } + + test {ACL LOAD kills client when one of multiple provenance users is deleted} { + reconnect + r ACL SETUSER tempuser on nopass ~* &* +@all + + set rd1 [redis_deferring_client] + # Subscribe as alice first, then re-auth as tempuser and subscribe more + $rd1 HELLO 3 AUTH alice alice + $rd1 read + $rd1 SUBSCRIBE test1 + $rd1 read + $rd1 AUTH tempuser tempuser + $rd1 read + $rd1 SUBSCRIBE test2 + $rd1 read + $rd1 AUTH alice alice + $rd1 read + $rd1 CLIENT SETNAME multi-prov + $rd1 read + + # tempuser is not in user.acl, so ACL LOAD will delete it. + # Client has a provenance entry for the deleted user → must be killed. + r ACL LOAD + catch {$rd1 read} e + assert_match {*I/O error*} $e + assert_no_match {*multi-prov*} [r CLIENT LIST] + $rd1 close + } + + test {ACL LOAD kills default user subscriber when channel access revoked} { + reconnect + set rd1 [redis_deferring_client] + $rd1 CLIENT SETNAME default-sub + $rd1 read + $rd1 SUBSCRIBE secret + $rd1 read + + # Write a modified ACL file that restricts default's channels + set aclfile [file join $server_path user.acl] + set fd [open $aclfile w] + puts $fd "user alice on allcommands allkeys &* >alice" + puts $fd "user bob on -@all +@set +acl ~set* &* >bob" + puts $fd "user doug on resetchannels &test +@all ~* >doug" + puts $fd "user default on nopass ~* resetchannels &healthcheck +@all" + close $fd + + r ACL LOAD + + # default no longer has access to "secret" → client must be killed + catch {$rd1 read} e + assert_match {*I/O error*} $e + assert_no_match {*default-sub*} [r CLIENT LIST] + + # Restore the original ACL file + exec cp -f tests/assets/user.acl $server_path + r ACL LOAD + $rd1 close + } + + test {ACL LOAD default user subscriber survives when permissions unchanged} { + reconnect + set rd1 [redis_deferring_client] + $rd1 CLIENT SETNAME default-survivor + $rd1 read + $rd1 SUBSCRIBE test1 + $rd1 
read + + # Reload with identical permissions — default pointer is stable, + # subscriptions should survive. + r ACL LOAD + r PUBLISH test1 default-ok + + assert_match {*default-ok*} [$rd1 read] + assert_match {*default-survivor*} [r CLIENT LIST] + $rd1 close + } + + test {Pointer-key optimization: long username does not bloat subscription memory} { + reconnect + set longname [string repeat "A" 1000000] + r ACL SETUSER $longname on nopass ~* &* +@all + + set rd1 [redis_deferring_client] + $rd1 AUTH $longname $longname + $rd1 read + + set mem_before [s used_memory] + $rd1 SUBSCRIBE ch1 + $rd1 read + set mem_after [s used_memory] + + # The subscription should add dict overhead + pubsubUserSubs (~hundreds + # of bytes), not a copy of the 1MB username. Allow 64KB slack for other + # allocations that may happen concurrently. + set delta [expr {$mem_after - $mem_before}] + assert {$delta < 65536} + + $rd1 close + r ACL DELUSER $longname + } + test {ACL load and save} { r ACL setuser eve +get allkeys >eve on r ACL save diff --git a/tests/unit/memefficiency.tcl b/tests/unit/memefficiency.tcl index f488ca85f3..876b0ed397 100644 --- a/tests/unit/memefficiency.tcl +++ b/tests/unit/memefficiency.tcl @@ -582,6 +582,124 @@ run_solo {defrag} { $rd_pubsub close } + test "Active defrag pubsub multi-user subscriptions: $type" { + # This test verifies that active defrag correctly handles the + # two-level pubsub_subscriptions dict (outer dict keyed by user*, + # inner dicts for channels/patterns/shard_channels). Two clients + # subscribe under different ACL users so the outer dict has + # multiple entries, exercising the full defrag iteration path. 
+ + r flushdb + r config set hz 100 + r config set activedefrag no + wait_for_defrag_stop 500 100 + r config resetstat + r config set active-defrag-threshold-lower 5 + r config set active-defrag-cycle-min 65 + r config set active-defrag-cycle-max 75 + r config set active-defrag-ignore-bytes 1500kb + r config set maxmemory 0 + + # Create a second ACL user so we have two distinct user* keys + # in each client's pubsub_subscriptions dict. + r ACL SETUSER defraguser on nopass ~* &* +@all + + set n 25000 + set dummy_channel "[string repeat x 400]" + set rd_default [redis_deferring_client] + set rd_extra [redis_deferring_client] + $rd_extra AUTH defraguser defraguser + $rd_extra read + + # Subscribe to 25k channels, alternating between the two clients. + # After each subscription, create a filler key of similar size via + # SETBIT. This interleaves subscription allocations with filler + # allocations in memory, which is needed to create fragmentation + # when the fillers are deleted later. + set rd_filler [redis_deferring_client] + for {set j 0} {$j < $n} {incr j} { + set channel_name "$dummy_channel[format "%06d" $j]" + if {$j % 2 == 0} { + $rd_default subscribe $channel_name + $rd_default read + } else { + $rd_extra subscribe $channel_name + $rd_extra read + } + # Create a ~400 byte filler key interleaved with subscription allocs + $rd_filler setbit k$j [expr {[string length $channel_name] * 8}] 1 + $rd_filler read + } + + # Sanity: fragmentation should be low right after populating + after 120 ;# serverCron only updates the info once in 100ms + assert_lessthan [s allocator_frag_ratio] 1.1 + + # Delete all filler keys to punch holes in memory and create + # fragmentation. Use batching to avoid TCP deadlock. 
+ set batch_size 1000 + for {set j 0} {$j < $n} {incr j} { + $rd_filler del k$j + if {($j + 1) % $batch_size == 0} { + for {set i 0} {$i < $batch_size} {incr i} { + $rd_filler read + } + } + } + set remaining [expr {$n % $batch_size}] + for {set j 0} {$j < $remaining} {incr j} { $rd_filler read } + if {$type eq "cluster"} { + $rd_filler config resetstat + $rd_filler read + } + $rd_filler close + + # Verify fragmentation is high enough for defrag to kick in + after 120 ;# serverCron only updates the info once in 100ms + assert_morethan [s allocator_frag_ratio] 1.35 + + # Enable active defrag and wait for it to compact memory + catch {r config set activedefrag yes} e + if {[r config get activedefrag] eq "activedefrag yes"} { + # Wait for defrag to start working (decision once a second) + wait_for_condition 50 100 { + [s total_active_defrag_time] ne 0 + } else { + after 120 + puts [r info memory] + puts [r info stats] + puts [r memory malloc-stats] + fail "defrag not started." + } + + # Wait for defrag to finish and verify fragmentation dropped + wait_for_defrag_stop 500 100 1.1 + + after 120 ;# serverCron only updates the info once in 100ms + } + + # Verify data integrity: publish to every channel and confirm the + # correct client receives the message. If defrag corrupted any + # channel name, dict pointer, or subscription structure, this will + # fail or crash the server. 
+ for {set j 0} {$j < $n} {incr j} { + set channel "$dummy_channel[format "%06d" $j]" + r publish $channel "hello" + if {$j % 2 == 0} { + assert_equal "message $channel hello" [$rd_default read] + $rd_default unsubscribe $channel + $rd_default read + } else { + assert_equal "message $channel hello" [$rd_extra read] + $rd_extra unsubscribe $channel + $rd_extra read + } + } + $rd_default close + $rd_extra close + r ACL DELUSER defraguser + } + test "Active defrag IDMP streams: $type" { r flushdb r config set hz 100 diff --git a/tests/unit/tracking.tcl b/tests/unit/tracking.tcl index 174575eee9..14d9daab24 100644 --- a/tests/unit/tracking.tcl +++ b/tests/unit/tracking.tcl @@ -883,6 +883,139 @@ start_server {tags {"tracking network logreqres:skip"}} { assert_equal {PONG} [$rd read] } + test {BCAST ACL filtering - two clients same user see only permitted keys} { + clean_all + + r ACL SETUSER shareduser on >pass123 ~public:* +@all + set c1 [redis_deferring_client] + set c2 [redis_deferring_client] + + $c1 AUTH shareduser pass123 + $c1 read + + $c2 AUTH shareduser pass123 + $c2 read + + $c1 HELLO 3 + $c1 read + $c2 HELLO 3 + $c2 read + + $c1 CLIENT TRACKING on BCAST PREFIX public: PREFIX admin: + assert_match {*OK*} [$c1 read] + $c2 CLIENT TRACKING on BCAST PREFIX public: PREFIX admin: + assert_match {*OK*} [$c2 read] + + $rd_sg MSET public:a{t} 1 admin:b{t} 2 + + # Both clients should receive exactly {public:a{t}} for the + # public: prefix, and nothing for admin: (filtered out by ACL). + set c1_keys {} + set c2_keys {} + # Read invalidation messages: there are two prefixes, but only + # public: should have data for shareduser. + after 100 + # $rd_sg is synchronous, so modified keys are already recorded + # on the server by the time we send PING. BCAST invalidations + # are flushed in beforeSleep before PONG, so they precede it + # on the wire. Drain all push messages until we hit the PONG. 
+ $c1 PING + while 1 { + set resp [$c1 read] + if {[lindex $resp 0] eq "invalidate"} { + lappend c1_keys {*}[lindex $resp 1] + } else { + break + } + } + $c2 PING + while 1 { + set resp [$c2 read] + if {[lindex $resp 0] eq "invalidate"} { + lappend c2_keys {*}[lindex $resp 1] + } else { + break + } + } + + assert_equal [lsort $c1_keys] [list public:a{t}] + assert_equal [lsort $c2_keys] [list public:a{t}] + + $c1 CLIENT TRACKING off + $c1 read + $c2 CLIENT TRACKING off + $c2 read + $c1 close + $c2 close + r ACL DELUSER shareduser + } + + test {BCAST re-AUTH re-buckets correctly with ACL filtering} { + clean_all + + r ACL SETUSER usr_a on >passA ~a:* +@all + r ACL SETUSER usr_b on >passB ~b:* +@all + + set tc [redis_deferring_client] + $tc AUTH usr_a passA + $tc read + + $tc HELLO 3 + $tc read + + $tc CLIENT TRACKING on BCAST PREFIX a: PREFIX b: + assert_match {*OK*} [$tc read] + + # Write keys matching both prefixes. + $rd_sg SET a:1{t} val1 + $rd_sg SET b:1{t} val1 + + # Under usr_a, only a:* is visible. + # $rd_sg is synchronous, so modified keys are already recorded + # on the server by the time we send PING. BCAST invalidations + # are flushed in beforeSleep before PONG, so they precede it + # on the wire. Drain all push messages until we hit the PONG. + after 100 + $tc PING + set keys {} + while 1 { + set resp [$tc read] + if {[lindex $resp 0] eq "invalidate"} { + lappend keys {*}[lindex $resp 1] + } else { + break + } + } + assert_equal $keys [list a:1{t}] + + # Re-AUTH as usr_b. + $tc AUTH usr_b passB + $tc read + + # Write again. + $rd_sg SET a:2{t} val2 + $rd_sg SET b:2{t} val2 + + after 100 + $tc PING + set keys {} + while 1 { + set resp [$tc read] + if {[lindex $resp 0] eq "invalidate"} { + lappend keys {*}[lindex $resp 1] + } else { + break + } + } + assert_equal $keys [list b:2{t}] + + $tc CLIENT TRACKING off + $tc read + $tc close + r ACL DELUSER usr_a + r ACL DELUSER usr_b + } + $rd_redirection close $rd_sg close $rd close