diff --git a/src/acl.c b/src/acl.c
index 177077d45..25b0a7d4c 100644
--- a/src/acl.c
+++ b/src/acl.c
@@ -523,6 +523,14 @@ void ACLCopyUser(user *dst, user *src) {
     }
 }
 
+/* Set the user for a client, performing any necessary bookkeeping such as
+ * updating broadcast tracking state for the user switch. */
+void clientSetUser(client *c, user *new_user) {
+    user *old = c->user;
+    c->user = new_user;
+    trackingBroadcastPostUserSwitch(c, old);
+}
+
 /* Given a command ID, this function set by reference 'word' and 'bit'
  * so that user->allowed_commands[word] will address the right word
  * where the corresponding bit for the provided ID is stored, and
@@ -1497,7 +1505,7 @@ void addAuthErrReply(client *c, robj *err) {
 int checkPasswordBasedAuth(client *c, robj *username, robj *password) {
     if (ACLCheckUserCredentials(username,password) == C_OK) {
         c->authenticated = 1;
-        c->user = ACLGetUserByName(username->ptr,sdslen(username->ptr));
+        clientSetUser(c, ACLGetUserByName(username->ptr,sdslen(username->ptr)));
         moduleNotifyUserChanged(c);
         return AUTH_OK;
     } else {
@@ -2481,7 +2489,7 @@ sds ACLLoadFromFile(const char *filename) {
                 deauthenticateAndCloseClient(c);
                 continue;
             }
-            c->user = new;
+            clientSetUser(c, new);
         }
 
         if (user_channels)
@@ -3241,7 +3249,7 @@ static void internalAuth(client *c) {
     c->authenticated = 1;
     /* Set the user to the unrestricted user, if it is not already set (default). */
     if (c->user != NULL) {
-        c->user = NULL;
+        clientSetUser(c, NULL);
         moduleNotifyUserChanged(c);
     }
     addReply(c, shared.ok);
diff --git a/src/module.c b/src/module.c
index 50a594987..6221816b4 100644
--- a/src/module.c
+++ b/src/module.c
@@ -10809,8 +10809,8 @@ static int authenticateClientWithUser(RedisModuleCtx *ctx, user *user, RedisModu
 
     moduleNotifyUserChanged(ctx->client);
 
-    ctx->client->user = user;
     ctx->client->authenticated = 1;
+    clientSetUser(ctx->client, user);
 
     if (clientHasModuleAuthInProgress(ctx->client)) {
         ctx->client->flags |= CLIENT_MODULE_AUTH_HAS_RESULT;
diff --git a/src/networking.c b/src/networking.c
index 2f5384c3b..ebac816ab 100644
--- a/src/networking.c
+++ b/src/networking.c
@@ -103,7 +103,7 @@ void linkClient(client *c) {
 static void clientSetDefaultAuth(client *c) {
     /* If the default user does not require authentication, the user is
      * directly authenticated. */
-    c->user = DefaultUser;
+    clientSetUser(c, DefaultUser);
     c->authenticated = (c->user->flags & USER_FLAG_NOPASS) &&
                        !(c->user->flags & USER_FLAG_DISABLED);
 }
@@ -193,6 +193,7 @@ client *createClient(connection *conn) {
     c->ctime = c->lastinteraction = server.unixtime;
     c->io_lastinteraction = 0;
     c->duration = 0;
+    c->user = DefaultUser; /* Set a safe default value: clientSetDefaultAuth reads c->user. */
     clientSetDefaultAuth(c);
     c->replstate = REPL_STATE_NONE;
     c->repl_start_cmd_stream_on_ack = 0;
@@ -1614,8 +1615,8 @@ void clientAcceptHandler(connection *conn) {
         if (username != NULL) {
             user *u = ACLGetUserByName(username, sdslen(username));
             if (u && !(u->flags & USER_FLAG_DISABLED)) {
-                c->user = u;
                 c->authenticated = 1;
+                clientSetUser(c, u);
                 moduleNotifyUserChanged(c);
                 serverLog(LL_VERBOSE, "TLS: Auto-authenticated client as %s",
                           server.hide_user_data_from_log ? "*redacted*" : u->name);
@@ -2073,6 +2074,7 @@ void clearClientConnectionState(client *c) {
 }
 
 void deauthenticateAndCloseClient(client *c) {
+    disableTracking(c);
     c->user = DefaultUser;
     c->authenticated = 0;
     /* We will write replies to this client later, so we can't
diff --git a/src/server.h b/src/server.h
index 9318eec68..4c232afef 100644
--- a/src/server.h
+++ b/src/server.h
@@ -3361,6 +3361,8 @@ uint64_t trackingGetTotalItems(void);
 uint64_t trackingGetTotalKeys(void);
 uint64_t trackingGetTotalPrefixes(void);
 void trackingBroadcastInvalidationMessages(void);
+void trackingBroadcastPostUserSwitch(client *c, user *old_user);
+void clientSetUser(client *c, user *new_user);
 int checkPrefixCollisionsOrReply(client *c, robj **prefix, size_t numprefix);
 
 /* List data type */
diff --git a/src/tracking.c b/src/tracking.c
index c235d5812..615e1a8bf 100644
--- a/src/tracking.c
+++ b/src/tracking.c
@@ -9,6 +9,7 @@
  */
 
 #include "server.h"
+#include "vector.h"
 
 /* The tracking table is constituted by a radix tree of keys, each pointing
  * to a radix tree of client IDs, used to track the clients that may have
@@ -38,12 +39,47 @@ typedef struct bcastState {
                        prefix. */
 } bcastState;
 
+/* bs->clients maps a user pointer (the raw pointer bytes are the rax key;
+ * NULL is the unrestricted user) to a vec of the clients subscribed as that
+ * user. Return the vec for user 'u', or NULL when there is none. With
+ * 'create' set, a missing vec is allocated and registered instead. */
+static vec *trackingBcastGetUserClients(bcastState *bs, user *u, int create) {
+    void *result;
+    if (raxFind(bs->clients,(unsigned char*)&u,sizeof(u),&result))
+        return result;
+    if (!create) return NULL;
+
+    vec *user_clients = zmalloc(sizeof(*user_clients));
+    vecInit(user_clients, NULL, 0);
+    raxInsert(bs->clients,(unsigned char*)&u,sizeof(u),user_clients,NULL);
+    return user_clients;
+}
+
+/* Swap-remove client 'c' from the vec of user 'u' in 'bs', releasing the
+ * vec and its rax entry as soon as it becomes empty. The client must be
+ * registered under 'u'. */
+static void trackingBcastRemoveUserClient(bcastState *bs, user *u, client *c) {
+    vec *user_clients = trackingBcastGetUserClients(bs, u, 0);
+    serverAssert(user_clients != NULL);
+    ssize_t idx = vecIndexOf(user_clients, c);
+    serverAssert(idx >= 0);
+    vecSwapRemoveAt(user_clients, idx);
+    if (vecSize(user_clients) > 0) return;
+
+    vecRelease(user_clients);
+    zfree(user_clients);
+    raxRemove(bs->clients,(unsigned char*)&u,sizeof(u),NULL);
+}
+
-/* Remove the tracking state from the client 'c'. Note that there is not much
- * to do for us here, if not to decrement the counter of the clients in
- * tracking mode, because we just store the ID of the client in the tracking
- * table, so we'll remove the ID reference in a lazy way. Otherwise when a
- * client with many entries in the table is removed, it would cost a lot of
- * time to do the cleanup. */
+/* Remove the tracking state from the client 'c'.
+ *
+ * For BCAST mode, the client is immediately removed from its per-user
+ * vector in every prefix it subscribes to, and empty user/prefix entries
+ * are freed.
+ *
+ * For normal (non-BCAST) tracking, the client's ID references in the
+ * tracking table are removed lazily to avoid expensive cleanup when a
+ * client with many cached keys disconnects. */
 void disableTracking(client *c) {
     /* If this client is in broadcasting mode, we need to unsubscribe it
      * from all the prefixes it is registered to. */
@@ -56,7 +92,11 @@ void disableTracking(client *c) {
             int found = raxFind(PrefixTable,ri.key,ri.key_len,&result);
             serverAssert(found);
             bcastState *bs = result;
-            raxRemove(bs->clients,(unsigned char*)&c,sizeof(c),NULL);
+
+            /* Drop the client from its per-user bucket; an emptied bucket
+             * is released together with its rax entry. */
+            trackingBcastRemoveUserClient(bs, c->user, c);
+
             /* Was it the last client? Remove the prefix from the
              * table. */
             if (raxSize(bs->clients) == 0) {
@@ -134,7 +174,7 @@ int checkPrefixCollisionsOrReply(client *c, robj **prefixes, size_t numprefix) {
 
 /* Set the client 'c' to track the prefix 'prefix'. If the client 'c' is
  * already registered for the specified prefix, no operation is performed. */
-void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) {
+static void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) {
     void *result;
     bcastState *bs;
     /* If this is the first client subscribing to such prefix, create
@@ -147,7 +187,11 @@ void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) {
     } else {
         bs = result;
     }
-    if (raxTryInsert(bs->clients,(unsigned char*)&c,sizeof(c),NULL,NULL)) {
+
+    /* Find or create the per-user client vector. */
+    vec *user_clients = trackingBcastGetUserClients(bs, c->user, 1);
+    if (vecIndexOf(user_clients, c) < 0) {
+        vecPush(user_clients, c);
         if (c->client_tracking_prefixes == NULL)
             c->client_tracking_prefixes = raxNew();
         raxInsert(c->client_tracking_prefixes,
@@ -552,28 +596,30 @@ void trackingLimitUsedSlots(void) {
 }
 
 /* Generate Redis protocol for an array containing all the key names
- * in the 'keys' radix tree. If the client is not NULL, the list will not
- * include keys that were modified the last time by this client, in order
- * to implement the NOLOOP option.
+ * in the 'keys' radix tree, filtered by ACL permissions of user 'u' and
+ * optionally by NOLOOP (skipping keys last modified by client 'c').
+ *
+ * If 'u' is non-NULL, keys the user is not permitted to observe are excluded.
+ * If 'c' is non-NULL, keys whose last modifier (ri.data) matches
+ * that client are excluded.
  *
  * If the resulting array would be empty, NULL is returned instead. */
-sds trackingBuildBroadcastReply(client *c, rax *keys) {
+sds trackingBuildBroadcastReply(user *u, client *c, rax *keys) {
+    debugServerAssert(!c || c->flags & CLIENT_TRACKING_NOLOOP);
     raxIterator ri;
-    uint64_t count;
+    uint64_t count = 0;
 
-    if (c == NULL) {
-        count = raxSize(keys);
-    } else {
-        count = 0;
-        raxStart(&ri,keys);
-        raxSeek(&ri,"^",NULL,0);
-        while(raxNext(&ri)) {
-            if (ri.data != c) count++;
-        }
-        raxStop(&ri);
-
-        if (count == 0) return NULL;
+    raxStart(&ri, keys);
+    raxSeek(&ri, "^", NULL, 0);
+    while(raxNext(&ri)) {
+        if (c && ri.data == c) continue;
+        if (u && ACLUserCheckKeyPerm(u, (const char*)ri.key, ri.key_len,
+                                     ACL_READ_PERMISSION) != ACL_OK) continue;
+        count++;
     }
+    raxStop(&ri);
+
+    if (count == 0) return NULL;
 
     /* Create the array reply with the list of keys once, then send
      * it to all the clients subscribed to this prefix. */
@@ -588,6 +634,8 @@ sds trackingBuildBroadcastReply(client *c, rax *keys) {
     raxSeek(&ri,"^",NULL,0);
     while(raxNext(&ri)) {
         if (c && ri.data == c) continue;
+        if (u && ACLUserCheckKeyPerm(u, (const char*)ri.key, ri.key_len,
+                                     ACL_READ_PERMISSION) != ACL_OK) continue;
         len = ll2string(buf,sizeof(buf),ri.key_len);
         proto = sdscatlen(proto,"$",1);
         proto = sdscatlen(proto,buf,len);
@@ -599,11 +647,106 @@ sds trackingBuildBroadcastReply(client *c, rax *keys) {
     return proto;
 }
 
+/* Send pending BCAST invalidation messages for a single prefix's
+ * bcastState, then reset bs->keys. Keys are ACL-filtered per user bucket:
+ * clients of one user share a reply, NOLOOP clients get their own. */
+static void trackingBcastInvalidationsForPrefix(bcastState *bs) {
+    if (raxSize(bs->keys) == 0) return;
+
+    raxIterator ri;
+    raxStart(&ri, bs->clients);
+    raxSeek(&ri, "^", NULL, 0);
+    while(raxNext(&ri)) {
+        user *u;
+        memcpy(&u, ri.key, sizeof(u));
+        vec *user_clients = ri.data;
+
+        /* The reply shared by the non-NOLOOP clients of this user is built
+         * on first use, so a bucket holding only NOLOOP clients does not
+         * pay for an ACL pass whose result nobody would receive. */
+        sds proto = NULL;
+        int proto_built = 0;
+
+        for (size_t j = 0; j < vecSize(user_clients); j++) {
+            client *c = vecGet(user_clients, j);
+
+            if (c->flags & CLIENT_TRACKING_NOLOOP) {
+                /* This client may have certain keys excluded. */
+                sds adhoc = trackingBuildBroadcastReply(u, c, bs->keys);
+                if (!adhoc) continue;
+                sendTrackingMessage(c, adhoc, sdslen(adhoc), 1);
+                sdsfree(adhoc);
+                continue;
+            }
+
+            if (!proto_built) {
+                proto = trackingBuildBroadcastReply(u, NULL, bs->keys);
+                proto_built = 1;
+            }
+            if (proto) sendTrackingMessage(c, proto, sdslen(proto), 1);
+        }
+
+        sdsfree(proto);
+    }
+    raxStop(&ri);
+
+    raxFree(bs->keys);
+    bs->keys = raxNew();
+}
+
+/* Send pending BCAST invalidation messages for every prefix in
+ * 'prefixes' (a rax of prefix -> NULL, i.e. client_tracking_prefixes).
+ * This triggers the full broadcast cycle for each matching prefix. */
+static void trackingBcastSendInvalidationsForPrefixes(rax *prefixes) {
+    raxIterator ri;
+    raxStart(&ri, prefixes);
+    raxSeek(&ri, "^", NULL, 0);
+    while(raxNext(&ri)) {
+        void *result;
+        int found = raxFind(PrefixTable, ri.key, ri.key_len, &result);
+        serverAssert(found);
+        trackingBcastInvalidationsForPrefix(result);
+    }
+    raxStop(&ri);
+}
+
+/* Move client 'c' from its old user bucket (keyed by 'old_user') to
+ * the bucket for c->user in every bcastState the client subscribes to.
+ * Must be called AFTER c->user is updated. */
+static void trackingBcastMoveClient(client *c, user *old_user) {
+    raxIterator ri;
+    raxStart(&ri, c->client_tracking_prefixes);
+    raxSeek(&ri, "^", NULL, 0);
+    while(raxNext(&ri)) {
+        void *result;
+        int found = raxFind(PrefixTable, ri.key, ri.key_len, &result);
+        serverAssert(found);
+        bcastState *bs = result;
+
+        trackingBcastRemoveUserClient(bs, old_user, c);
+        vecPush(trackingBcastGetUserClients(bs, c->user, 1), c);
+    }
+    raxStop(&ri);
+}
+
+/* Handle a BCAST tracking client after a user change: flush all pending
+ * invalidation messages for its prefixes (so every subscriber receives
+ * them under the previous ACL identity), then move the client from the
+ * 'old_user' bucket to the bucket for c->user.
+ * Must be called AFTER c->user is updated. */
+void trackingBroadcastPostUserSwitch(client *c, user *old_user) {
+    if (!(c->flags & CLIENT_TRACKING_BCAST)) return;
+    if (c->user == old_user) return;
+
+    trackingBcastSendInvalidationsForPrefixes(c->client_tracking_prefixes);
+    trackingBcastMoveClient(c, old_user);
+}
+
 /* This function will run the prefixes of clients in BCAST mode and
  * keys that were modified about each prefix, and will send the
  * notifications to each client in each prefix. */
 void trackingBroadcastInvalidationMessages(void) {
-    raxIterator ri, ri2;
+    raxIterator ri;
 
     /* Return ASAP if there is nothing to do here. */
     if (TrackingTable == NULL || !server.tracking_clients) return;
@@ -611,41 +754,8 @@ void trackingBroadcastInvalidationMessages(void) {
     raxStart(&ri,PrefixTable);
     raxSeek(&ri,"^",NULL,0);
 
-    /* For each prefix... */
     while(raxNext(&ri)) {
-        bcastState *bs = ri.data;
-
-        if (raxSize(bs->keys)) {
-            /* Generate the common protocol for all the clients that are
-             * not using the NOLOOP option. */
-            sds proto = trackingBuildBroadcastReply(NULL,bs->keys);
-
-            /* Send this array of keys to every client in the list. */
-            raxStart(&ri2,bs->clients);
-            raxSeek(&ri2,"^",NULL,0);
-            while(raxNext(&ri2)) {
-                client *c;
-                memcpy(&c,ri2.key,sizeof(c));
-                if (c->flags & CLIENT_TRACKING_NOLOOP) {
-                    /* This client may have certain keys excluded. */
-                    sds adhoc = trackingBuildBroadcastReply(c,bs->keys);
-                    if (adhoc) {
-                        sendTrackingMessage(c,adhoc,sdslen(adhoc),1);
-                        sdsfree(adhoc);
-                    }
-                } else {
-                    sendTrackingMessage(c,proto,sdslen(proto),1);
-                }
-            }
-            raxStop(&ri2);
-
-            /* Clean up: we can remove everything from this state, because we
-             * want to only track the new keys that will be accumulated starting
-             * from now. */
-            sdsfree(proto);
-        }
-        raxFree(bs->keys);
-        bs->keys = raxNew();
+        trackingBcastInvalidationsForPrefix(ri.data);
     }
     raxStop(&ri);
 }
diff --git a/src/vector.c b/src/vector.c
index fc0ba13e1..70f080dbd 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -100,6 +100,21 @@ void vecPush(vec *v, void *value) {
     v->data[v->size++] = value;
 }
 
+/* Return the index of the first occurrence of 'elem', or -1 if not found. */
+ssize_t vecIndexOf(const vec *v, const void *elem) {
+    for (size_t i = 0; i < v->size; i++) {
+        if (v->data[i] == elem) return (ssize_t)i;
+    }
+    return -1;
+}
+
+/* Remove the element at 'index' by swapping with the last element.
+ * Does not invoke the free callback. Requires index < vecSize(v). */
+void vecSwapRemoveAt(vec *v, size_t index) {
+    assert(index < v->size);
+    v->data[index] = v->data[--v->size];
+}
+
 #ifdef REDIS_TEST
 #include <stdio.h>
 
@@ -221,6 +236,49 @@ int vectorTest(int argc, char **argv, int flags)
     vecRelease(&v);
     test_cond("vecRelease() free method is a no-op on empty vector",
               vecTestFreeCalls == 0);
+    /* vecIndexOf tests */
+    vecInit(&v, NULL, 0);
+    test_cond("vecIndexOf() returns -1 on empty vector",
+              vecIndexOf(&v, &one) == -1);
+    vecPush(&v, &one);
+    vecPush(&v, &two);
+    vecPush(&v, &three);
+    test_cond("vecIndexOf() finds first element",
+              vecIndexOf(&v, &one) == 0);
+    test_cond("vecIndexOf() finds middle element",
+              vecIndexOf(&v, &two) == 1);
+    test_cond("vecIndexOf() finds last element",
+              vecIndexOf(&v, &three) == 2);
+    test_cond("vecIndexOf() returns -1 for missing element",
+              vecIndexOf(&v, &four) == -1);
+    vecRelease(&v);
+
+    /* vecSwapRemoveAt tests */
+    vecInit(&v, NULL, 0);
+    vecPush(&v, &one);
+    vecPush(&v, &two);
+    vecPush(&v, &three);
+    vecSwapRemoveAt(&v, 1);
+    test_cond("vecSwapRemoveAt() removes middle element and swaps with last",
+              vecSize(&v) == 2 &&
+              vecGet(&v, 0) == &one && vecGet(&v, 1) == &three);
+    vecSwapRemoveAt(&v, 1);
+    test_cond("vecSwapRemoveAt() removes last element",
+              vecSize(&v) == 1 && vecGet(&v, 0) == &one);
+    vecSwapRemoveAt(&v, 0);
+    test_cond("vecSwapRemoveAt() removes sole element",
+              vecSize(&v) == 0);
+    vecRelease(&v);
+
+    vecInit(&v, NULL, 0);
+    vecPush(&v, &one);
+    vecPush(&v, &two);
+    vecPush(&v, &three);
+    vecSwapRemoveAt(&v, 0);
+    test_cond("vecSwapRemoveAt() removes first element and swaps with last",
+              vecSize(&v) == 2 &&
+              vecGet(&v, 0) == &three && vecGet(&v, 1) == &two);
+    vecRelease(&v);
 
     return 0;
 }
diff --git a/src/vector.h b/src/vector.h
index c89955c98..c89465567 100644
--- a/src/vector.h
+++ b/src/vector.h
@@ -2,6 +2,7 @@
 #define REDIS_VECTOR_H
 
 #include <stddef.h>
+#include <sys/types.h>
 
 /*
  * Simple append-only vector (dynamic array) of void * elements.
@@ -96,6 +97,13 @@ void vecReserve(vec *v, size_t mincap);
 /* Append one element, growing storage as needed. */
 void vecPush(vec *v, void *value);
 
+/* Return the index of the first occurrence of 'elem', or -1 if not found. */
+ssize_t vecIndexOf(const vec *v, const void *elem);
+
+/* Remove the element at 'index' by swapping with the last element.
+ * Does not invoke the free callback. Requires index < vecSize(v). */
+void vecSwapRemoveAt(vec *v, size_t index);
+
 #ifdef REDIS_TEST
 int vectorTest(int argc, char **argv, int flags);
 #endif
diff --git a/tests/unit/tracking.tcl b/tests/unit/tracking.tcl
index 174575eee..14d9daab2 100644
--- a/tests/unit/tracking.tcl
+++ b/tests/unit/tracking.tcl
@@ -883,6 +883,139 @@ start_server {tags {"tracking network logreqres:skip"}} {
         assert_equal {PONG} [$rd read]
     }
 
+    test {BCAST ACL filtering - two clients same user see only permitted keys} {
+        clean_all
+
+        r ACL SETUSER shareduser on >pass123 ~public:* +@all
+        set c1 [redis_deferring_client]
+        set c2 [redis_deferring_client]
+
+        $c1 AUTH shareduser pass123
+        $c1 read
+
+        $c2 AUTH shareduser pass123
+        $c2 read
+
+        $c1 HELLO 3
+        $c1 read
+        $c2 HELLO 3
+        $c2 read
+
+        $c1 CLIENT TRACKING on BCAST PREFIX public: PREFIX admin:
+        assert_match {*OK*} [$c1 read]
+        $c2 CLIENT TRACKING on BCAST PREFIX public: PREFIX admin:
+        assert_match {*OK*} [$c2 read]
+
+        $rd_sg MSET public:a{t} 1 admin:b{t} 2
+
+        # Both clients should receive exactly {public:a{t}} for the
+        # public: prefix, and nothing for admin: (filtered out by ACL).
+        set c1_keys {}
+        set c2_keys {}
+        # Read invalidation messages: there are two prefixes, but only
+        # public: should have data for shareduser.
+        after 100
+        # $rd_sg is synchronous, so modified keys are already recorded
+        # on the server by the time we send PING. BCAST invalidations
+        # are flushed in beforeSleep before PONG, so they precede it
+        # on the wire. Drain all push messages until we hit the PONG.
+        $c1 PING
+        while 1 {
+            set resp [$c1 read]
+            if {[lindex $resp 0] eq "invalidate"} {
+                lappend c1_keys {*}[lindex $resp 1]
+            } else {
+                break
+            }
+        }
+        $c2 PING
+        while 1 {
+            set resp [$c2 read]
+            if {[lindex $resp 0] eq "invalidate"} {
+                lappend c2_keys {*}[lindex $resp 1]
+            } else {
+                break
+            }
+        }
+
+        assert_equal [lsort $c1_keys] [list public:a{t}]
+        assert_equal [lsort $c2_keys] [list public:a{t}]
+
+        $c1 CLIENT TRACKING off
+        $c1 read
+        $c2 CLIENT TRACKING off
+        $c2 read
+        $c1 close
+        $c2 close
+        r ACL DELUSER shareduser
+    }
+
+    test {BCAST re-AUTH re-buckets correctly with ACL filtering} {
+        clean_all
+
+        r ACL SETUSER usr_a on >passA ~a:* +@all
+        r ACL SETUSER usr_b on >passB ~b:* +@all
+
+        set tc [redis_deferring_client]
+        $tc AUTH usr_a passA
+        $tc read
+
+        $tc HELLO 3
+        $tc read
+
+        $tc CLIENT TRACKING on BCAST PREFIX a: PREFIX b:
+        assert_match {*OK*} [$tc read]
+
+        # Write keys matching both prefixes.
+        $rd_sg SET a:1{t} val1
+        $rd_sg SET b:1{t} val1
+
+        # Under usr_a, only a:* is visible.
+        # $rd_sg is synchronous, so modified keys are already recorded
+        # on the server by the time we send PING. BCAST invalidations
+        # are flushed in beforeSleep before PONG, so they precede it
+        # on the wire. Drain all push messages until we hit the PONG.
+        after 100
+        $tc PING
+        set keys {}
+        while 1 {
+            set resp [$tc read]
+            if {[lindex $resp 0] eq "invalidate"} {
+                lappend keys {*}[lindex $resp 1]
+            } else {
+                break
+            }
+        }
+        assert_equal $keys [list a:1{t}]
+
+        # Re-AUTH as usr_b.
+        $tc AUTH usr_b passB
+        $tc read
+
+        # Write again.
+        $rd_sg SET a:2{t} val2
+        $rd_sg SET b:2{t} val2
+
+        after 100
+        $tc PING
+        set keys {}
+        while 1 {
+            set resp [$tc read]
+            if {[lindex $resp 0] eq "invalidate"} {
+                lappend keys {*}[lindex $resp 1]
+            } else {
+                break
+            }
+        }
+        assert_equal $keys [list b:2{t}]
+
+        $tc CLIENT TRACKING off
+        $tc read
+        $tc close
+        r ACL DELUSER usr_a
+        r ACL DELUSER usr_b
+    }
+
     $rd_redirection close
     $rd_sg close
     $rd close