diff --git a/core/src/main/java/org/keycloak/crypto/PublicKeysWrapper.java b/core/src/main/java/org/keycloak/crypto/PublicKeysWrapper.java
new file mode 100644
index 00000000000..d2cab90c5d1
--- /dev/null
+++ b/core/src/main/java/org/keycloak/crypto/PublicKeysWrapper.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2022 Red Hat, Inc. and/or its affiliates
+ * and other contributors as indicated by the @author tags.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.keycloak.crypto;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * @author Marek Posolda
+ */
+public class PublicKeysWrapper {
+
+ private final List keys;
+
+ public static final PublicKeysWrapper EMPTY = new PublicKeysWrapper(Collections.emptyList());
+
+ public PublicKeysWrapper(List keys) {
+ this.keys = keys;
+ }
+
+ public List getKeys() {
+ return keys;
+ }
+
+ public List getKids() {
+ return keys.stream()
+ .map(KeyWrapper::getKid)
+ .collect(Collectors.toList());
+ }
+
+ public KeyWrapper getKeyByKidAndAlg(String kid, String alg) {
+ return keys.stream()
+ .filter(keyWrapper -> kid == null || kid.equals(keyWrapper.getKid()))
+ .filter(keyWrapper -> alg == null || alg.equals(keyWrapper.getAlgorithmOrDefault()) || (keyWrapper.getAlgorithm() == null && kid != null))
+ .findFirst().orElse(null);
+ }
+}
diff --git a/core/src/main/java/org/keycloak/util/JWKSUtils.java b/core/src/main/java/org/keycloak/util/JWKSUtils.java
index fc352bae4ac..60b9ea83011 100644
--- a/core/src/main/java/org/keycloak/util/JWKSUtils.java
+++ b/core/src/main/java/org/keycloak/util/JWKSUtils.java
@@ -17,17 +17,19 @@
package org.keycloak.util;
+import org.jboss.logging.Logger;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
+import org.keycloak.crypto.PublicKeysWrapper;
import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.jose.jwk.JWKParser;
import java.security.PublicKey;
-import java.util.HashMap;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
-import java.util.logging.Level;
-import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* @author Marek Posolda
@@ -36,27 +38,22 @@ public class JWKSUtils {
private static final Logger logger = Logger.getLogger(JWKSUtils.class.getName());
+ /**
+ * @deprecated Use {@link #getKeyWrappersForUse(JSONWebKeySet, JWK.Use)}
+ **/
+ @Deprecated
public static Map getKeysForUse(JSONWebKeySet keySet, JWK.Use requestedUse) {
- Map result = new HashMap<>();
-
- for (JWK jwk : keySet.getKeys()) {
- JWKParser parser = JWKParser.create(jwk);
- if (jwk.getPublicKeyUse() == null) {
- logger.log(Level.FINE, "Ignoring JWK key '%s'. Missing required field 'use'.", jwk.getKeyId());
- } else if (requestedUse.asString().equals(jwk.getPublicKeyUse()) && parser.isKeyTypeSupported(jwk.getKeyType())) {
- result.put(jwk.getKeyId(), parser.toPublicKey());
- }
- }
-
- return result;
+ return getKeyWrappersForUse(keySet, requestedUse).getKeys()
+ .stream()
+ .collect(Collectors.toMap(KeyWrapper::getKid, keyWrapper -> (PublicKey) keyWrapper.getPublicKey()));
}
- public static Map getKeyWrappersForUse(JSONWebKeySet keySet, JWK.Use requestedUse) {
- Map result = new HashMap<>();
+ public static PublicKeysWrapper getKeyWrappersForUse(JSONWebKeySet keySet, JWK.Use requestedUse) {
+ List result = new ArrayList<>();
for (JWK jwk : keySet.getKeys()) {
JWKParser parser = JWKParser.create(jwk);
if (jwk.getPublicKeyUse() == null) {
- logger.log(Level.FINE, "Ignoring JWK key '%s'. Missing required field 'use'.", jwk.getKeyId());
+ logger.debugf("Ignoring JWK key '%s'. Missing required field 'use'.", jwk.getKeyId());
} else if (requestedUse.asString().equals(jwk.getPublicKeyUse()) && parser.isKeyTypeSupported(jwk.getKeyType())) {
KeyWrapper keyWrapper = new KeyWrapper();
keyWrapper.setKid(jwk.getKeyId());
@@ -66,10 +63,10 @@ public class JWKSUtils {
keyWrapper.setType(jwk.getKeyType());
keyWrapper.setUse(getKeyUse(jwk.getPublicKeyUse()));
keyWrapper.setPublicKey(parser.toPublicKey());
- result.put(keyWrapper.getKid(), keyWrapper);
+ result.add(keyWrapper);
}
}
- return result;
+ return new PublicKeysWrapper(result);
}
private static KeyUse getKeyUse(String keyUse) {
@@ -87,7 +84,7 @@ public class JWKSUtils {
for (JWK jwk : keySet.getKeys()) {
JWKParser parser = JWKParser.create(jwk);
if (jwk.getPublicKeyUse() == null) {
- logger.log(Level.FINE, "Ignoring JWK key '%s'. Missing required field 'use'.", jwk.getKeyId());
+ logger.debugf("Ignoring JWK key '%s'. Missing required field 'use'.", jwk.getKeyId());
} else if (requestedUse.asString().equals(parser.getJwk().getPublicKeyUse()) && parser.isKeyTypeSupported(jwk.getKeyType())) {
return jwk;
}
diff --git a/core/src/test/java/org/keycloak/util/JWKSUtilsTest.java b/core/src/test/java/org/keycloak/util/JWKSUtilsTest.java
index 86785251a5e..983a391accd 100644
--- a/core/src/test/java/org/keycloak/util/JWKSUtilsTest.java
+++ b/core/src/test/java/org/keycloak/util/JWKSUtilsTest.java
@@ -21,12 +21,14 @@ import org.junit.ClassRule;
import org.junit.Test;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
-import org.keycloak.jose.jwk.*;
+import org.keycloak.crypto.PublicKeysWrapper;
+import org.keycloak.jose.jwk.JSONWebKeySet;
+import org.keycloak.jose.jwk.JWK;
import org.keycloak.rule.CryptoInitRule;
-import java.util.Map;
-
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
public abstract class JWKSUtilsTest {
@@ -61,6 +63,14 @@ public abstract class JWKSUtilsTest {
" }," +
" {" +
" \"kty\": \"RSA\"," +
+ " \"alg\": \"RS512\"," +
+ " \"use\": \"sig\"," +
+ " \"kid\": \"" + kidRsa1 + "\"," +
+ " \"n\": \"soFDjoZ5mQ8XAA7reQAFg90inKAHk0DXMTizo4JuOsgzUbhcplIeZ7ks83hsEjm8mP8lUVaHMPMAHEIp3gu6Xxsg-s73ofx1dtt_Fo7aj8j383MFQGl8-FvixTVobNeGeC0XBBQjN8lEl-lIwOa4ZoERNAShplTej0ntDp7TQm0=\"," +
+ " \"e\": \"AQAB\"" +
+ " }," +
+ " {" +
+ " \"kty\": \"RSA\"," +
" \"kid\": \"" + kidInvalidKey + "\"," +
" \"n\": \"soFDjoZ5mQ8XAA7reQAFg90inKAHk0DXMTizo4JuOsgzUbhcplIeZ7ks83hsEjm8mP8lUVaHMPMAHEIp3gu6Xxsg-s73ofx1dtt_Fo7aj8j383MFQGl8-FvixTVobNeGeC0XBBQjN8lEl-lIwOa4ZoERNAShplTej0ntDp7TQm0=\"," +
" \"e\": \"AQAB\"" +
@@ -84,36 +94,61 @@ public abstract class JWKSUtilsTest {
" }" +
"] }";
JSONWebKeySet jsonWebKeySet = JsonSerialization.readValue(jwksJson, JSONWebKeySet.class);
- Map keyWrappersForUse = JWKSUtils.getKeyWrappersForUse(jsonWebKeySet, JWK.Use.SIG);
- assertEquals(4, keyWrappersForUse.size());
+ PublicKeysWrapper keyWrappersForUse = JWKSUtils.getKeyWrappersForUse(jsonWebKeySet, JWK.Use.SIG);
+ assertEquals(5, keyWrappersForUse.getKeys().size());
- KeyWrapper key = keyWrappersForUse.get(kidRsa1);
+ // get by both kid and alg
+ KeyWrapper key = keyWrappersForUse.getKeyByKidAndAlg(kidRsa1, "RS256");
assertNotNull(key);
assertEquals("RS256", key.getAlgorithmOrDefault());
assertEquals(KeyUse.SIG, key.getUse());
assertEquals(kidRsa1, key.getKid());
assertEquals("RSA", key.getType());
- key = keyWrappersForUse.get(kidRsa2);
+ // get by both kid and alg with RS512. It is same 'kid' as the previous, but should choose "RS512" key now
+ key = keyWrappersForUse.getKeyByKidAndAlg(kidRsa1, "RS512");
+ assertNotNull(key);
+ assertEquals("RS512", key.getAlgorithmOrDefault());
+ assertEquals(KeyUse.SIG, key.getUse());
+ assertEquals(kidRsa1, key.getKid());
+ assertEquals("RSA", key.getType());
+
+ // Get by kid only. Should choose default algorithm, so RS256
+ key = keyWrappersForUse.getKeyByKidAndAlg(kidRsa1, null);
+ assertNotNull(key);
+ assertEquals("RS256", key.getAlgorithmOrDefault());
+ assertEquals(KeyUse.SIG, key.getUse());
+ assertEquals(kidRsa1, key.getKid());
+ assertEquals("RSA", key.getType());
+
+ key = keyWrappersForUse.getKeyByKidAndAlg(kidRsa2, null);
assertNotNull(key);
assertEquals("RS256", key.getAlgorithmOrDefault());
assertEquals(KeyUse.SIG, key.getUse());
assertEquals(kidRsa2, key.getKid());
assertEquals("RSA", key.getType());
- key = keyWrappersForUse.get(kidEC1);
+ key = keyWrappersForUse.getKeyByKidAndAlg(kidEC1, null);
assertNotNull(key);
assertEquals("ES384", key.getAlgorithmOrDefault());
assertEquals(KeyUse.SIG, key.getUse());
assertEquals(kidEC1, key.getKid());
assertEquals("EC", key.getType());
- key = keyWrappersForUse.get(kidEC2);
+ key = keyWrappersForUse.getKeyByKidAndAlg(kidEC2, null);
assertNotNull(key);
assertNull(key.getAlgorithmOrDefault());
assertEquals(KeyUse.SIG, key.getUse());
assertEquals(kidEC2, key.getKid());
assertEquals("EC", key.getType());
+
+ // Search by alg only
+ key = keyWrappersForUse.getKeyByKidAndAlg(null, "ES384");
+ assertNotNull(key);
+ assertEquals("ES384", key.getAlgorithmOrDefault());
+ assertEquals(KeyUse.SIG, key.getUse());
+ assertEquals(kidEC1, key.getKid());
+ assertEquals("EC", key.getType());
}
diff --git a/model/infinispan/src/main/java/org/keycloak/keys/infinispan/InfinispanPublicKeyStorageProvider.java b/model/infinispan/src/main/java/org/keycloak/keys/infinispan/InfinispanPublicKeyStorageProvider.java
index d51e648ad02..6311dc58553 100644
--- a/model/infinispan/src/main/java/org/keycloak/keys/infinispan/InfinispanPublicKeyStorageProvider.java
+++ b/model/infinispan/src/main/java/org/keycloak/keys/infinispan/InfinispanPublicKeyStorageProvider.java
@@ -19,6 +19,7 @@ package org.keycloak.keys.infinispan;
import java.util.Collections;
import java.util.HashSet;
+import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
@@ -30,6 +31,7 @@ import org.jboss.logging.Logger;
import org.keycloak.cluster.ClusterProvider;
import org.keycloak.common.util.Time;
import org.keycloak.crypto.KeyWrapper;
+import org.keycloak.crypto.PublicKeysWrapper;
import org.keycloak.keys.PublicKeyLoader;
import org.keycloak.keys.PublicKeyStorageProvider;
import org.keycloak.models.KeycloakSession;
@@ -115,22 +117,17 @@ public class InfinispanPublicKeyStorageProvider implements PublicKeyStorageProvi
}
}
-
- @Override
- public KeyWrapper getPublicKey(String modelKey, String kid, PublicKeyLoader loader) {
- return getPublicKey(modelKey, kid, null, loader);
- }
-
@Override
public KeyWrapper getFirstPublicKey(String modelKey, String algorithm, PublicKeyLoader loader) {
return getPublicKey(modelKey, null, algorithm, loader);
}
- private KeyWrapper getPublicKey(String modelKey, String kid, String algorithm, PublicKeyLoader loader) {
+ @Override
+ public KeyWrapper getPublicKey(String modelKey, String kid, String algorithm, PublicKeyLoader loader) {
// Check if key is in cache
PublicKeysEntry entry = keys.get(modelKey);
if (entry != null) {
- KeyWrapper publicKey = algorithm != null ? getPublicKeyByAlg(entry.getCurrentKeys(), algorithm) : getPublicKey(entry.getCurrentKeys(), kid);
+ KeyWrapper publicKey = entry.getCurrentKeys().getKeyByKidAndAlg(kid, algorithm);
if (publicKey != null) {
// return a copy of the key to not modify the cached one
return publicKey.cloneKey();
@@ -157,7 +154,7 @@ public class InfinispanPublicKeyStorageProvider implements PublicKeyStorageProvi
entry = task.get();
// Computation finished. Let's see if key is available
- KeyWrapper publicKey = algorithm != null ? getPublicKeyByAlg(entry.getCurrentKeys(), algorithm) : getPublicKey(entry.getCurrentKeys(), kid);
+ KeyWrapper publicKey = entry.getCurrentKeys().getKeyByKidAndAlg(kid, algorithm);
if (publicKey != null) {
// return a copy of the key to not modify the cached one
return publicKey.cloneKey();
@@ -177,28 +174,12 @@ public class InfinispanPublicKeyStorageProvider implements PublicKeyStorageProvi
log.warnf("Won't load the keys for model '%s' . Last request time was %d", modelKey, lastRequestTime);
}
- Set availableKids = entry==null ? Collections.emptySet() : entry.getCurrentKeys().keySet();
+ List availableKids = entry==null ? Collections.emptyList() : entry.getCurrentKeys().getKids();
log.warnf("PublicKey wasn't found in the storage. Requested kid: '%s' . Available kids: '%s'", kid, availableKids);
return null;
}
- private KeyWrapper getPublicKey(Map publicKeys, String kid) {
- // Backwards compatibility
- if (kid == null && !publicKeys.isEmpty()) {
- return publicKeys.values().iterator().next();
- } else {
- return publicKeys.get(kid);
- }
- }
-
- private KeyWrapper getPublicKeyByAlg(Map publicKeys, String algorithm) {
- if (algorithm == null) return null;
- for(KeyWrapper keyWrapper : publicKeys.values())
- if (algorithm.equals(keyWrapper.getAlgorithmOrDefault())) return keyWrapper;
- return null;
- }
-
@Override
public void close() {
@@ -224,10 +205,10 @@ public class InfinispanPublicKeyStorageProvider implements PublicKeyStorageProvi
// Check again if we are allowed to send request. There is a chance other task was already finished and removed from tasksInProgress in the meantime.
if (currentTime > lastRequestTime + minTimeBetweenRequests) {
- Map publicKeys = delegate.loadKeys();
+ PublicKeysWrapper publicKeys = delegate.loadKeys();
if (log.isDebugEnabled()) {
- log.debugf("Public keys retrieved successfully for model %s. New kids: %s", modelKey, publicKeys.keySet().toString());
+ log.debugf("Public keys retrieved successfully for model %s. New kids: %s", modelKey, publicKeys.getKids());
}
entry = new PublicKeysEntry(currentTime, publicKeys);
diff --git a/model/infinispan/src/main/java/org/keycloak/keys/infinispan/PublicKeysEntry.java b/model/infinispan/src/main/java/org/keycloak/keys/infinispan/PublicKeysEntry.java
index 2f2d807de74..a26f4b46d11 100644
--- a/model/infinispan/src/main/java/org/keycloak/keys/infinispan/PublicKeysEntry.java
+++ b/model/infinispan/src/main/java/org/keycloak/keys/infinispan/PublicKeysEntry.java
@@ -18,9 +18,7 @@
package org.keycloak.keys.infinispan;
import java.io.Serializable;
-import java.util.Map;
-
-import org.keycloak.crypto.KeyWrapper;
+import org.keycloak.crypto.PublicKeysWrapper;
/**
* @author Marek Posolda
@@ -29,9 +27,9 @@ public class PublicKeysEntry implements Serializable {
private final int lastRequestTime;
- private final Map currentKeys;
+ private final PublicKeysWrapper currentKeys;
- public PublicKeysEntry(int lastRequestTime, Map currentKeys) {
+ public PublicKeysEntry(int lastRequestTime, PublicKeysWrapper currentKeys) {
this.lastRequestTime = lastRequestTime;
this.currentKeys = currentKeys;
}
@@ -40,7 +38,7 @@ public class PublicKeysEntry implements Serializable {
return lastRequestTime;
}
- public Map getCurrentKeys() {
+ public PublicKeysWrapper getCurrentKeys() {
return currentKeys;
}
}
diff --git a/model/infinispan/src/test/java/org/keycloak/keys/infinispan/InfinispanKeyStorageProviderTest.java b/model/infinispan/src/test/java/org/keycloak/keys/infinispan/InfinispanKeyStorageProviderTest.java
index 4c0bca3b81a..efdecc8f80c 100644
--- a/model/infinispan/src/test/java/org/keycloak/keys/infinispan/InfinispanKeyStorageProviderTest.java
+++ b/model/infinispan/src/test/java/org/keycloak/keys/infinispan/InfinispanKeyStorageProviderTest.java
@@ -17,7 +17,6 @@
package org.keycloak.keys.infinispan;
-import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -38,7 +37,7 @@ import org.junit.Before;
import org.junit.Test;
import org.keycloak.common.util.Time;
import org.keycloak.connections.infinispan.InfinispanConnectionProvider;
-import org.keycloak.crypto.KeyWrapper;
+import org.keycloak.crypto.PublicKeysWrapper;
import org.keycloak.keys.PublicKeyLoader;
/**
@@ -129,7 +128,7 @@ public class InfinispanKeyStorageProviderTest {
@Override
public void run() {
InfinispanPublicKeyStorageProvider provider = new InfinispanPublicKeyStorageProvider(null, keys, tasksInProgress, minTimeBetweenRequests);
- provider.getPublicKey(modelKey, "kid1", new SampleLoader(modelKey));
+ provider.getPublicKey(modelKey, "kid1", null, new SampleLoader(modelKey));
}
}
@@ -144,12 +143,12 @@ public class InfinispanKeyStorageProviderTest {
}
@Override
- public Map loadKeys() throws Exception {
+ public PublicKeysWrapper loadKeys() throws Exception {
counters.putIfAbsent(modelKey, new AtomicInteger(0));
AtomicInteger currentCounter = counters.get(modelKey);
currentCounter.incrementAndGet();
- return Collections.emptyMap();
+ return PublicKeysWrapper.EMPTY;
}
}
diff --git a/model/map/src/main/java/org/keycloak/models/map/keys/MapPublicKeyStorageProvider.java b/model/map/src/main/java/org/keycloak/models/map/keys/MapPublicKeyStorageProvider.java
index 2926ab3b725..9f229f31f93 100644
--- a/model/map/src/main/java/org/keycloak/models/map/keys/MapPublicKeyStorageProvider.java
+++ b/model/map/src/main/java/org/keycloak/models/map/keys/MapPublicKeyStorageProvider.java
@@ -19,13 +19,14 @@ package org.keycloak.models.map.keys;
import org.jboss.logging.Logger;
import org.keycloak.crypto.KeyWrapper;
+import org.keycloak.crypto.PublicKeysWrapper;
import org.keycloak.keys.PublicKeyLoader;
import org.keycloak.keys.PublicKeyStorageProvider;
import org.keycloak.models.KeycloakSession;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
-import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
@@ -36,28 +37,24 @@ public class MapPublicKeyStorageProvider implements PublicKeyStorageProvider {
private final KeycloakSession session;
- private final Map>> tasksInProgress;
+ private final Map> tasksInProgress;
- public MapPublicKeyStorageProvider(KeycloakSession session, Map>> tasksInProgress) {
+ public MapPublicKeyStorageProvider(KeycloakSession session, Map> tasksInProgress) {
this.session = session;
this.tasksInProgress = tasksInProgress;
}
- @Override
- public KeyWrapper getPublicKey(String modelKey, String kid, PublicKeyLoader loader) {
- return getPublicKey(modelKey, kid, null, loader);
- }
-
@Override
public KeyWrapper getFirstPublicKey(String modelKey, String algorithm, PublicKeyLoader loader) {
return getPublicKey(modelKey, null, algorithm, loader);
}
- private KeyWrapper getPublicKey(String modelKey, String kid, String algorithm, PublicKeyLoader loader) {
+ @Override
+ public KeyWrapper getPublicKey(String modelKey, String kid, String algorithm, PublicKeyLoader loader) {
WrapperCallable wrapperCallable = new WrapperCallable(modelKey, loader);
- FutureTask