/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.keys;

import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.crypto.SecretKey;
import org.jboss.logging.Logger;
import org.keycloak.component.ComponentModel;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.keys.KeyProvider;
import org.keycloak.keys.KeyProviderFactory;
import org.keycloak.keys.RsaKeyMetadata;
import org.keycloak.keys.SecretKeyMetadata;
import org.keycloak.models.KeyManager;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.provider.Provider;
import org.keycloak.provider.ProviderFactory;

/**
 * {@link KeyManager} implementation that looks keys up in the {@link KeyProvider} components
 * configured on a realm.
 *
 * <p>Providers are instantiated lazily the first time a realm is queried, ordered by descending
 * {@code priority} configuration value (ties broken by component id), enlisted with the session
 * for closing, and then cached in this instance under the realm id.
 */
public class DefaultKeyManager
implements KeyManager {
    private static final Logger logger = Logger.getLogger(DefaultKeyManager.class);
    private final KeycloakSession session;
    /** Providers already created for a realm, keyed by realm id. */
    private final Map<String, List<KeyProvider>> providersMap = new HashMap<String, List<KeyProvider>>();

    /**
     * @param session session used to look up provider factories and to enlist created providers
     */
    public DefaultKeyManager(KeycloakSession session) {
        this.session = session;
    }

    /**
     * Returns the first active key of the realm matching the given use and algorithm.
     *
     * <p>If no provider of the realm offers such a key, every registered {@link KeyProviderFactory}
     * is asked to create fallback keys. On the first factory that reports success the cached
     * providers of the realm are discarded and the lookup is repeated once; if that lookup still
     * finds nothing no further factory is tried.
     *
     * @param realm     realm whose providers are searched
     * @param use       required key use
     * @param algorithm required algorithm
     * @return the matching active key, never {@code null}
     * @throws RuntimeException if no active key exists and no fallback key could be obtained
     */
    public KeyWrapper getActiveKey(RealmModel realm, KeyUse use, String algorithm) {
        KeyWrapper activeKey = this.getActiveKey(this.getProviders(realm), realm, use, algorithm);
        if (activeKey != null) {
            return activeKey;
        }
        logger.debugv("Failed to find active key for realm, trying fallback: realm={0} algorithm={1} use={2}", realm.getName(), algorithm, use.name());
        for (ProviderFactory f : this.session.getKeycloakSessionFactory().getProviderFactories(KeyProvider.class)) {
            KeyProviderFactory kf = (KeyProviderFactory)f;
            if (!kf.createFallbackKeys(this.session, use, algorithm)) {
                continue;
            }
            // Drop the cached providers so the next lookup rebuilds them from the realm's components.
            this.providersMap.remove(realm.getId());
            activeKey = this.getActiveKey(this.getProviders(realm), realm, use, algorithm);
            if (activeKey == null) {
                break;
            }
            logger.warnv("Fallback key created: realm={0} algorithm={1} use={2}", realm.getName(), algorithm, use.name());
            return activeKey;
        }
        // The pattern must have balanced braces: the *v logger methods format with
        // java.text.MessageFormat, which rejects an unterminated "{2".
        logger.errorv("Failed to create fallback key for realm: realm={0} algorithm={1} use={2}", realm.getName(), algorithm, use.name());
        throw new RuntimeException("Failed to find key: realm=" + realm.getName() + " algorithm=" + algorithm + " use=" + use.name());
    }

    /**
     * Scans the given providers in order and returns the first key that is active and matches
     * the requested use and algorithm, or {@code null} when there is none.
     */
    private KeyWrapper getActiveKey(List<KeyProvider> providers, RealmModel realm, KeyUse use, String algorithm) {
        for (KeyProvider p : providers) {
            for (KeyWrapper key : p.getKeys()) {
                if (!key.getStatus().isActive() || !this.matches(key, use, algorithm)) {
                    continue;
                }
                if (logger.isTraceEnabled()) {
                    logger.tracev("Active key found: realm={0} kid={1} algorithm={2} use={3}", new Object[]{realm.getName(), key.getKid(), algorithm, use.name()});
                }
                return key;
            }
        }
        return null;
    }

    /**
     * Looks up an enabled key of the realm by key id, use and algorithm.
     *
     * @param realm     realm whose providers are searched
     * @param kid       key id to look for; a {@code null} value yields {@code null}
     * @param use       required key use
     * @param algorithm required algorithm
     * @return the matching key, or {@code null} if {@code kid} is {@code null} or nothing matches
     */
    public KeyWrapper getKey(RealmModel realm, String kid, KeyUse use, String algorithm) {
        if (kid == null) {
            // The apostrophe is doubled because MessageFormat treats a single quote as a quoting character.
            logger.warnv("kid is null, can''t find public key: realm={0}", realm.getName());
            return null;
        }
        for (KeyProvider p : this.getProviders(realm)) {
            for (KeyWrapper key : p.getKeys()) {
                // kid is known to be non-null here, so compare from its side to tolerate keys without a kid.
                if (!kid.equals(key.getKid()) || !key.getStatus().isEnabled() || !this.matches(key, use, algorithm)) {
                    continue;
                }
                if (logger.isTraceEnabled()) {
                    logger.tracev("Found key: realm={0} kid={1} algorithm={2} use={3}", new Object[]{realm.getName(), key.getKid(), algorithm, use.name()});
                }
                return key;
            }
        }
        if (logger.isTraceEnabled()) {
            logger.tracev("Failed to find public key: realm={0} kid={1} algorithm={2} use={3}", new Object[]{realm.getName(), kid, algorithm, use.name()});
        }
        return null;
    }

    /**
     * Returns all enabled keys of the realm that match the given use and algorithm, in provider order.
     *
     * @return a new, possibly empty list
     */
    public List<KeyWrapper> getKeys(RealmModel realm, KeyUse use, String algorithm) {
        List<KeyWrapper> keys = new LinkedList<KeyWrapper>();
        for (KeyProvider p : this.getProviders(realm)) {
            for (KeyWrapper key : p.getKeys()) {
                if (key.getStatus().isEnabled() && this.matches(key, use, algorithm)) {
                    keys.add(key);
                }
            }
        }
        return keys;
    }

    /**
     * Returns every key offered by the realm's providers, regardless of status, use or algorithm.
     *
     * @return a new, possibly empty list
     */
    public List<KeyWrapper> getKeys(RealmModel realm) {
        List<KeyWrapper> keys = new LinkedList<KeyWrapper>();
        for (KeyProvider p : this.getProviders(realm)) {
            keys.addAll(p.getKeys());
        }
        return keys;
    }

    /**
     * Returns the active {@code RS256} signature key of the realm.
     *
     * @throws RuntimeException if no such key can be found or created
     * @deprecated use {@link #getActiveKey(RealmModel, KeyUse, String)}
     */
    @Deprecated
    public KeyManager.ActiveRsaKey getActiveRsaKey(RealmModel realm) {
        KeyWrapper key = this.getActiveKey(realm, KeyUse.SIG, "RS256");
        return new KeyManager.ActiveRsaKey(key.getKid(), (PrivateKey)key.getSignKey(), (PublicKey)key.getVerifyKey(), key.getCertificate());
    }

    /**
     * Returns the active {@code HS256} signature key of the realm.
     *
     * @throws RuntimeException if no such key can be found or created
     * @deprecated use {@link #getActiveKey(RealmModel, KeyUse, String)}
     */
    @Deprecated
    public KeyManager.ActiveHmacKey getActiveHmacKey(RealmModel realm) {
        KeyWrapper key = this.getActiveKey(realm, KeyUse.SIG, "HS256");
        return new KeyManager.ActiveHmacKey(key.getKid(), key.getSecretKey());
    }

    /**
     * Returns the active {@code AES} encryption key of the realm.
     *
     * @throws RuntimeException if no such key can be found or created
     * @deprecated use {@link #getActiveKey(RealmModel, KeyUse, String)}
     */
    @Deprecated
    public KeyManager.ActiveAesKey getActiveAesKey(RealmModel realm) {
        KeyWrapper key = this.getActiveKey(realm, KeyUse.ENC, "AES");
        return new KeyManager.ActiveAesKey(key.getKid(), key.getSecretKey());
    }

    /**
     * @return the verification key of the enabled {@code RS256} key with the given id, or {@code null} if not found
     * @deprecated use {@link #getKey(RealmModel, String, KeyUse, String)}
     */
    @Deprecated
    public PublicKey getRsaPublicKey(RealmModel realm, String kid) {
        KeyWrapper key = this.getKey(realm, kid, KeyUse.SIG, "RS256");
        return key != null ? (PublicKey)key.getVerifyKey() : null;
    }

    /**
     * @return the certificate of the enabled {@code RS256} key with the given id, or {@code null} if not found
     * @deprecated use {@link #getKey(RealmModel, String, KeyUse, String)}
     */
    @Deprecated
    public Certificate getRsaCertificate(RealmModel realm, String kid) {
        KeyWrapper key = this.getKey(realm, kid, KeyUse.SIG, "RS256");
        return key != null ? key.getCertificate() : null;
    }

    /**
     * @return the secret of the enabled {@code HS256} key with the given id, or {@code null} if not found
     * @deprecated use {@link #getKey(RealmModel, String, KeyUse, String)}
     */
    @Deprecated
    public SecretKey getHmacSecretKey(RealmModel realm, String kid) {
        KeyWrapper key = this.getKey(realm, kid, KeyUse.SIG, "HS256");
        return key != null ? key.getSecretKey() : null;
    }

    /**
     * @return the secret of the enabled {@code AES} key with the given id, or {@code null} if not found
     * @deprecated use {@link #getKey(RealmModel, String, KeyUse, String)}
     */
    @Deprecated
    public SecretKey getAesSecretKey(RealmModel realm, String kid) {
        KeyWrapper key = this.getKey(realm, kid, KeyUse.ENC, "AES");
        // getKey returns null for an unknown or null kid; answer null like the sibling getters do.
        return key != null ? key.getSecretKey() : null;
    }

    /**
     * Describes all enabled {@code RS256} signature keys of the realm.
     *
     * @return a new, possibly empty list of metadata entries
     * @deprecated use {@link #getKeys(RealmModel, KeyUse, String)}
     */
    @Deprecated
    public List<RsaKeyMetadata> getRsaKeys(RealmModel realm) {
        List<RsaKeyMetadata> keys = new LinkedList<RsaKeyMetadata>();
        for (KeyWrapper key : this.getKeys(realm, KeyUse.SIG, "RS256")) {
            RsaKeyMetadata m = new RsaKeyMetadata();
            m.setCertificate((Certificate)key.getCertificate());
            m.setPublicKey((PublicKey)key.getVerifyKey());
            m.setKid(key.getKid());
            m.setProviderId(key.getProviderId());
            m.setProviderPriority(key.getProviderPriority());
            m.setStatus(key.getStatus());
            keys.add(m);
        }
        return keys;
    }

    /**
     * Describes all enabled {@code HS256} signature keys of the realm.
     *
     * @return a new, possibly empty list of metadata entries
     */
    public List<SecretKeyMetadata> getHmacKeys(RealmModel realm) {
        return this.describeSecretKeys(realm, KeyUse.SIG, "HS256");
    }

    /**
     * Describes all enabled {@code AES} encryption keys of the realm.
     *
     * @return a new, possibly empty list of metadata entries
     */
    public List<SecretKeyMetadata> getAesKeys(RealmModel realm) {
        return this.describeSecretKeys(realm, KeyUse.ENC, "AES");
    }

    /** Builds a {@link SecretKeyMetadata} entry for every enabled key matching the use and algorithm. */
    private List<SecretKeyMetadata> describeSecretKeys(RealmModel realm, KeyUse use, String algorithm) {
        List<SecretKeyMetadata> keys = new LinkedList<SecretKeyMetadata>();
        for (KeyWrapper key : this.getKeys(realm, use, algorithm)) {
            SecretKeyMetadata m = new SecretKeyMetadata();
            m.setKid(key.getKid());
            m.setProviderId(key.getProviderId());
            m.setProviderPriority(key.getProviderPriority());
            m.setStatus(key.getStatus());
            keys.add(m);
        }
        return keys;
    }

    /**
     * Tells whether the key has the requested use and algorithm. A key without an algorithm, or a
     * {@code null} requested algorithm, never matches.
     */
    private boolean matches(KeyWrapper key, KeyUse use, String algorithm) {
        return use.equals(key.getUse()) && algorithm != null && algorithm.equals(key.getAlgorithm());
    }

    /**
     * Returns the providers of the realm, creating and caching them on first access.
     *
     * <p>A component whose provider cannot be created is logged and skipped so that the remaining
     * providers stay usable.
     */
    private List<KeyProvider> getProviders(RealmModel realm) {
        List<KeyProvider> providers = this.providersMap.get(realm.getId());
        if (providers != null) {
            return providers;
        }
        providers = new LinkedList<KeyProvider>();
        // Sort a copy; the list handed out by the realm is left untouched.
        List<ComponentModel> components = new LinkedList<ComponentModel>(realm.getComponents(realm.getId(), KeyProvider.class.getName()));
        components.sort(new ProviderComparator());
        for (ComponentModel c : components) {
            try {
                ProviderFactory f = this.session.getKeycloakSessionFactory().getProviderFactory(KeyProvider.class, c.getProviderId());
                KeyProviderFactory factory = (KeyProviderFactory)f;
                KeyProvider provider = factory.create(this.session, c);
                this.session.enlistForClose((Provider)provider);
                providers.add(provider);
            }
            catch (Throwable t) {
                logger.errorv(t, "Failed to load provider {0}", c.getId());
            }
        }
        this.providersMap.put(realm.getId(), providers);
        return providers;
    }

    /**
     * Orders components by descending {@code priority} configuration value (missing values count
     * as 0) and, for equal priorities, by ascending component id.
     */
    private static class ProviderComparator
    implements Comparator<ComponentModel> {
        private ProviderComparator() {
        }

        @Override
        public int compare(ComponentModel o1, ComponentModel o2) {
            // Arguments are swapped on purpose: the higher priority sorts first.
            int i = Long.compare(o2.get("priority", 0L), o1.get("priority", 0L));
            return i != 0 ? i : o1.getId().compareTo(o2.getId());
        }
    }
}

