package org.keycloak.keys;

import java.security.PublicKey;
import java.security.cert.Certificate;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.crypto.SecretKey;
import org.jboss.logging.Logger;
import org.keycloak.component.ComponentModel;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.models.KeyManager;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;

/**
 * Default {@link KeyManager} that resolves realm keys from the realm's {@code KeyProvider}
 * components.
 *
 * <p>Providers are instantiated lazily per realm, ordered by descending priority (ties broken by
 * component id), and cached by realm id for the lifetime of this instance. The cache is a plain
 * {@link HashMap}; NOTE(review): this presumes the manager is confined to a single session — confirm.
 */
public class DefaultKeyManager implements KeyManager {
    private static final Logger logger = Logger.getLogger(DefaultKeyManager.class);
    private final KeycloakSession session;
    /** Realm id -> key providers for that realm, sorted by {@link ProviderComparator}. */
    private final Map<String, List<KeyProvider>> providersMap = new HashMap<>();

    /**
     * Orders key components by descending {@link Attributes#PRIORITY_KEY} (missing = 0), breaking
     * ties by ascending component id so the order is deterministic.
     */
    private static class ProviderComparator implements Comparator<ComponentModel> {
        @Override
        public int compare(ComponentModel first, ComponentModel second) {
            // Arguments swapped on purpose: higher priority sorts first.
            int byPriority = Long.compare(
                    second.get(Attributes.PRIORITY_KEY, 0L), first.get(Attributes.PRIORITY_KEY, 0L));
            return byPriority != 0 ? byPriority : first.getId().compareTo(second.getId());
        }
    }

    /**
     * @param session the session used to look up provider factories and to register created
     *                providers for closing
     */
    public DefaultKeyManager(KeycloakSession session) {
        this.session = session;
    }

    /**
     * Returns the highest-priority active key of the realm for the given use and algorithm.
     *
     * <p>If none exists, every registered {@code KeyProvider} factory is asked, via
     * {@code createFallbackKeys}, to create one. This stops at the first factory that reports
     * success. The provider cache for the realm is then dropped, and the lookup is retried.
     *
     * @param realm     realm whose keys are searched
     * @param use       required key use
     * @param algorithm required algorithm (compared with {@code getAlgorithmOrDefault()})
     * @return the active key, never {@code null}
     * @throws RuntimeException if no active key exists and no fallback key could be created
     */
    public KeyWrapper getActiveKey(RealmModel realm, KeyUse use, String algorithm) {
        KeyWrapper activeKey = findActiveKey(getProviders(realm), realm, use, algorithm);
        if (activeKey != null) {
            return activeKey;
        }

        logger.debugv("Failed to find active key for realm, trying fallback: realm={0} algorithm={1} use={2}",
                realm.getName(), algorithm, use.name());

        boolean fallbackCreated = session.getKeycloakSessionFactory()
                .getProviderFactoriesStream(KeyProvider.class)
                .map(KeyProviderFactory.class::cast)
                .anyMatch(factory -> factory.createFallbackKeys(session, use, algorithm));

        if (fallbackCreated) {
            // The fallback created a new key component; drop the cached providers so it is loaded.
            providersMap.remove(realm.getId());
            KeyWrapper fallbackKey = findActiveKey(getProviders(realm), realm, use, algorithm);
            if (fallbackKey != null) {
                logger.infov("No keys found for realm={0} and algorithm={1} for use={2}. Generating keys.",
                        realm.getName(), algorithm, use.name());
                return fallbackKey;
            }
        }

        logger.errorv("Failed to create fallback key for realm: realm={0} algorithm={1} use={2}",
                realm.getName(), algorithm, use.name());
        throw new RuntimeException("Failed to find key: realm=" + realm.getName() + " algorithm=" + algorithm
                + " use=" + use.name());
    }

    /**
     * Scans providers in order and returns the first active key that matches use and algorithm.
     *
     * @return the matching key, or {@code null} if no provider has one
     */
    private KeyWrapper findActiveKey(List<KeyProvider> providers, RealmModel realm, KeyUse use, String algorithm) {
        for (KeyProvider provider : providers) {
            Optional<KeyWrapper> match = provider.getKeysStream()
                    .filter(key -> key.getStatus().isActive() && matches(key, use, algorithm))
                    .findFirst();
            if (match.isPresent()) {
                KeyWrapper key = match.get();
                if (logger.isTraceEnabled()) {
                    logger.tracev("Active key found: realm={0} kid={1} algorithm={2} use={3}",
                            realm.getName(), key.getKid(), algorithm, use.name());
                }
                return key;
            }
        }
        return null;
    }

    /**
     * Looks up an enabled key by kid that also matches the given use and algorithm.
     *
     * @param realm     realm whose keys are searched
     * @param kid       key id; {@code null} is logged as a warning and yields {@code null}
     * @param use       required key use
     * @param algorithm required algorithm
     * @return the first matching key in provider-priority order, or {@code null} if none
     */
    public KeyWrapper getKey(RealmModel realm, String kid, KeyUse use, String algorithm) {
        if (kid == null) {
            logger.warnv("kid is null, can't find public key: realm={0}", realm.getName());
            return null;
        }

        for (KeyProvider provider : getProviders(realm)) {
            Optional<KeyWrapper> match = provider.getKeysStream()
                    .filter(key -> Objects.equals(key.getKid(), kid)
                            && key.getStatus().isEnabled()
                            && matches(key, use, algorithm))
                    .findFirst();
            if (match.isPresent()) {
                KeyWrapper key = match.get();
                if (logger.isTraceEnabled()) {
                    logger.tracev("Found key: realm={0} kid={1} algorithm={2} use={3}",
                            realm.getName(), key.getKid(), algorithm, use.name());
                }
                return key;
            }
        }

        if (logger.isTraceEnabled()) {
            logger.tracev("Failed to find public key: realm={0} kid={1} algorithm={2} use={3}",
                    realm.getName(), kid, algorithm, use.name());
        }
        return null;
    }

    /** Returns all enabled keys of the realm that match the given use and algorithm, in provider order. */
    public Stream<KeyWrapper> getKeysStream(RealmModel realm, KeyUse use, String algorithm) {
        return getProviders(realm).stream()
                .flatMap(provider -> provider.getKeysStream()
                        .filter(key -> key.getStatus().isEnabled() && matches(key, use, algorithm)));
    }

    /** Returns every key of every provider of the realm, regardless of status, use or algorithm. */
    public Stream<KeyWrapper> getKeysStream(RealmModel realm) {
        return getProviders(realm).stream().flatMap(KeyProvider::getKeysStream);
    }

    /** @deprecated use {@link #getActiveKey(RealmModel, KeyUse, String)} with {@code SIG}/{@code RS256}. */
    @Deprecated
    public KeyManager.ActiveRsaKey getActiveRsaKey(RealmModel realm) {
        return new KeyManager.ActiveRsaKey(getActiveKey(realm, KeyUse.SIG, "RS256"));
    }

    /** @deprecated use {@link #getActiveKey(RealmModel, KeyUse, String)} with {@code SIG}/{@code HS256}. */
    @Deprecated
    public KeyManager.ActiveHmacKey getActiveHmacKey(RealmModel realm) {
        // getActiveKey throws rather than returning null, so no null check is needed.
        KeyWrapper activeKey = getActiveKey(realm, KeyUse.SIG, "HS256");
        return new KeyManager.ActiveHmacKey(activeKey.getKid(), activeKey.getSecretKey());
    }

    /** @deprecated use {@link #getActiveKey(RealmModel, KeyUse, String)} with {@code ENC}/{@code AES}. */
    @Deprecated
    public KeyManager.ActiveAesKey getActiveAesKey(RealmModel realm) {
        KeyWrapper activeKey = getActiveKey(realm, KeyUse.ENC, "AES");
        return new KeyManager.ActiveAesKey(activeKey.getKid(), activeKey.getSecretKey());
    }

    /** @deprecated use {@link #getKey}; returns {@code null} if no matching RS256 signing key exists. */
    @Deprecated
    public PublicKey getRsaPublicKey(RealmModel realm, String kid) {
        KeyWrapper key = getKey(realm, kid, KeyUse.SIG, "RS256");
        return key != null ? (PublicKey) key.getPublicKey() : null;
    }

    /** @deprecated use {@link #getKey}; returns {@code null} if no matching RS256 signing key exists. */
    @Deprecated
    public Certificate getRsaCertificate(RealmModel realm, String kid) {
        KeyWrapper key = getKey(realm, kid, KeyUse.SIG, "RS256");
        return key != null ? key.getCertificate() : null;
    }

    /** @deprecated use {@link #getKey}; returns {@code null} if no matching HS256 signing key exists. */
    @Deprecated
    public SecretKey getHmacSecretKey(RealmModel realm, String kid) {
        KeyWrapper key = getKey(realm, kid, KeyUse.SIG, "HS256");
        return key != null ? key.getSecretKey() : null;
    }

    /**
     * @deprecated use {@link #getKey}; returns {@code null} if no matching AES encryption key exists
     * (previously this threw a {@code NullPointerException}, unlike the sibling lookups).
     */
    @Deprecated
    public SecretKey getAesSecretKey(RealmModel realm, String kid) {
        KeyWrapper key = getKey(realm, kid, KeyUse.ENC, "AES");
        return key != null ? key.getSecretKey() : null;
    }

    /** @deprecated use {@link #getKeysStream(RealmModel, KeyUse, String)}; lists enabled RS256 signing keys. */
    @Deprecated
    public List<RsaKeyMetadata> getRsaKeys(RealmModel realm) {
        return getKeysStream(realm, KeyUse.SIG, "RS256")
                .map(DefaultKeyManager::toRsaKeyMetadata)
                .collect(Collectors.toList());
    }

    /** Lists metadata of the realm's enabled HS256 signing keys. */
    public List<SecretKeyMetadata> getHmacKeys(RealmModel realm) {
        return getKeysStream(realm, KeyUse.SIG, "HS256")
                .map(DefaultKeyManager::toSecretKeyMetadata)
                .collect(Collectors.toList());
    }

    /** Lists metadata of the realm's enabled AES encryption keys. */
    public List<SecretKeyMetadata> getAesKeys(RealmModel realm) {
        return getKeysStream(realm, KeyUse.ENC, "AES")
                .map(DefaultKeyManager::toSecretKeyMetadata)
                .collect(Collectors.toList());
    }

    /** Copies the RSA-relevant fields of {@code key} into a new {@link RsaKeyMetadata}. */
    private static RsaKeyMetadata toRsaKeyMetadata(KeyWrapper key) {
        RsaKeyMetadata metadata = new RsaKeyMetadata();
        metadata.setCertificate(key.getCertificate());
        metadata.setPublicKey((PublicKey) key.getPublicKey());
        metadata.setKid(key.getKid());
        metadata.setProviderId(key.getProviderId());
        metadata.setProviderPriority(key.getProviderPriority());
        metadata.setStatus(key.getStatus());
        return metadata;
    }

    /** Copies kid, provider id/priority and status of {@code key} into a new {@link SecretKeyMetadata}. */
    private static SecretKeyMetadata toSecretKeyMetadata(KeyWrapper key) {
        SecretKeyMetadata metadata = new SecretKeyMetadata();
        metadata.setKid(key.getKid());
        metadata.setProviderId(key.getProviderId());
        metadata.setProviderPriority(key.getProviderPriority());
        metadata.setStatus(key.getStatus());
        return metadata;
    }

    /** True when {@code key} has exactly the given use and its (default-resolved) algorithm equals {@code algorithm}. */
    private static boolean matches(KeyWrapper key, KeyUse use, String algorithm) {
        return use.equals(key.getUse()) && key.getAlgorithmOrDefault().equals(algorithm);
    }

    /**
     * Returns the realm's key providers sorted by {@link ProviderComparator}. They are built on
     * first access and cached by realm id. Components whose provider cannot be created are
     * skipped (see {@link #createProvider}).
     */
    private List<KeyProvider> getProviders(RealmModel realm) {
        List<KeyProvider> providers = providersMap.get(realm.getId());
        if (providers == null) {
            providers = realm.getComponentsStream(realm.getId(), KeyProvider.class.getName())
                    .sorted(new ProviderComparator())
                    .map(this::createProvider)
                    .filter(Objects::nonNull)
                    .collect(Collectors.toList());
            providersMap.put(realm.getId(), providers);
        }
        return providers;
    }

    /**
     * Instantiates the key provider for one component and registers it with the session for closing.
     *
     * <p>This is best-effort on purpose: any failure, including a missing factory, is logged and
     * turned into {@code null}. That way one broken component does not prevent the realm's other
     * keys from loading.
     *
     * @return the created provider, or {@code null} if creation failed
     */
    @SuppressWarnings("rawtypes") // KeyProviderFactory is looked up untyped from the session factory
    private KeyProvider createProvider(ComponentModel component) {
        try {
            KeyProviderFactory factory = (KeyProviderFactory) session.getKeycloakSessionFactory()
                    .getProviderFactory(KeyProvider.class, component.getProviderId());
            KeyProvider provider = (KeyProvider) factory.create(session, component);
            session.enlistForClose(provider);
            return provider;
        } catch (Throwable t) {
            logger.errorv(t, "Failed to load provider {0}", component.getId());
            return null;
        }
    }
}
