/*
 * Decompiled with CFR 0.152.
 */
package org.picketlink.identity.federation.bindings.jboss.auth;

import java.security.KeyStore;
import java.security.Principal;
import java.security.PublicKey;
import java.security.acl.Group;
import java.security.cert.CertPath;
import java.security.cert.CertPathValidator;
import java.security.cert.CertPathValidatorResult;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.CertificateFactory;
import java.security.cert.CertificateNotYetValidException;
import java.security.cert.PKIXParameters;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.login.LoginException;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.namespace.NamespaceContext;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathFactory;
import org.apache.xml.security.Init;
import org.apache.xml.security.keys.KeyInfo;
import org.apache.xml.security.signature.XMLSignature;
import org.jboss.security.SimplePrincipal;
import org.jboss.security.auth.callback.ObjectCallback;
import org.picketlink.identity.federation.bindings.jboss.auth.SAMLTokenFromHttpRequestAbstractLoginModule;
import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkGroup;
import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkPrincipal;
import org.picketlink.identity.federation.core.exceptions.ConfigurationException;
import org.picketlink.identity.federation.core.exceptions.ProcessingException;
import org.picketlink.identity.federation.core.factories.JBossAuthCacheInvalidationFactory;
import org.picketlink.identity.federation.core.saml.v2.constants.JBossSAMLURIConstants;
import org.picketlink.identity.federation.core.saml.v2.util.AssertionUtil;
import org.picketlink.identity.federation.core.util.StringUtil;
import org.picketlink.identity.federation.core.wstrust.SamlCredential;
import org.picketlink.identity.federation.core.wstrust.auth.AbstractSTSLoginModule;
import org.picketlink.identity.federation.core.wstrust.plugins.saml.SAMLUtil;
import org.picketlink.identity.federation.saml.v2.assertion.AssertionType;
import org.picketlink.identity.federation.saml.v2.assertion.BaseIDAbstractType;
import org.picketlink.identity.federation.saml.v2.assertion.NameIDType;
import org.picketlink.identity.federation.saml.v2.assertion.SubjectType;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/**
 * Base JAAS login module that authenticates a caller with a SAML assertion and validates it locally.
 *
 * <p>Validation performed by {@link #login()}:
 * <ol>
 *   <li>the X.509 certificate carried in the assertion's XML signature must form a path to the
 *       trust store returned by {@link #getKeyStore()} (revocation checking disabled);</li>
 *   <li>the certificate must be within its validity period;</li>
 *   <li>the assertion signature must verify with that certificate's public key;</li>
 *   <li>the assertion must not have expired.</li>
 * </ol>
 *
 * <p>The assertion is taken from the HTTP request when {@code getSamlTokenHttpHeader()} is
 * non-null, otherwise from an {@link ObjectCallback}. Roles are obtained through
 * {@code AssertionUtil.getRoles} using the keys listed in the {@code roleKey} option.
 */
public abstract class SAMLTokenCertValidatingCommonLoginModule
extends SAMLTokenFromHttpRequestAbstractLoginModule {
    protected Principal principal;
    protected SamlCredential credential;
    protected AssertionType assertion;
    protected boolean enableCacheInvalidation = false;
    protected String securityDomain = null;
    protected String localValidationSecurityDomain;
    protected String roleKey = "Role";
    // Copy of the options with consumed keys ("cache.invalidation", security domain) removed.
    protected Map<String, Object> options = new HashMap<String, Object>();
    // Untouched copy of all options passed to initialize().
    protected Map<String, Object> rawOptions = new HashMap<String, Object>();
    public static final String STS_CONFIG_FILE = "configFile";
    public static final String ENDPOINT_ADDRESS = "endpointAddress";
    public static final String PORT_NAME = "portName";
    public static final String SERVICE_NAME = "serviceName";
    public static final String USERNAME_KEY = "username";
    public static final String PASSWORD_KEY = "password";
    protected boolean localTestingOnly = false;

    /**
     * Reads and validates the module options.
     *
     * <p>Recognised options:
     * <ul>
     *   <li>{@code cache.invalidation} — when set, also requires
     *       {@code jboss.security.security_domain};</li>
     *   <li>{@code roleKey} — comma/space separated attribute names used for roles
     *       (default {@code "Role"});</li>
     *   <li>{@code localValidationSecurityDomain} — mandatory; names not starting with
     *       {@code java:} are prefixed with the JAAS JNDI context.</li>
     * </ul>
     * Also initialises the Apache XML Security library.
     */
    @Override
    public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) {
        super.initialize(subject, callbackHandler, sharedState, options);
        this.options.putAll(options);
        this.rawOptions.putAll(options);
        if (logger.isTraceEnabled()) {
            logger.trace(options.toString());
        }
        String cacheInvalidation = (String) this.options.remove("cache.invalidation");
        if (cacheInvalidation != null && !cacheInvalidation.isEmpty()) {
            this.enableCacheInvalidation = Boolean.parseBoolean(cacheInvalidation);
            this.securityDomain = (String) this.options.remove("jboss.security.security_domain");
            if (this.securityDomain == null || this.securityDomain.isEmpty()) {
                throw logger.optionNotSet("jboss.security.security_domain");
            }
        }
        String roleKeyStr = (String) options.get("roleKey");
        if (StringUtil.isNotNull(roleKeyStr)) {
            this.roleKey = roleKeyStr.trim();
        }
        this.localValidationSecurityDomain = (String) options.get("localValidationSecurityDomain");
        if (this.localValidationSecurityDomain == null) {
            logger.error("PL00105: When using local validation 'localValidationSecurityDomain' must be specified.");
            throw logger.optionNotSet("localValidationSecurityDomain");
        }
        if (!this.localValidationSecurityDomain.startsWith("java:")) {
            // NOTE(review): the prefix ends in a double slash; confirm this matches the JNDI
            // name the container binds security domains under.
            this.localValidationSecurityDomain = "java:jboss/jaas//" + this.localValidationSecurityDomain;
        }
        Init.init();
    }

    /**
     * Authenticates the caller.
     *
     * <p>When the parent reports a first-pass login, the principal and credential are taken from
     * the shared state. Otherwise the SAML credential is obtained, parsed, and fully validated.
     *
     * @return {@code true} on success
     * @throws LoginException if the credential is missing, is not a {@link SamlCredential},
     *         cannot be parsed, or fails validation
     */
    @Override
    public boolean login() throws LoginException {
        if (super.login()) {
            Object sharedPrincipal = this.sharedState.get("javax.security.auth.login.name");
            if (sharedPrincipal instanceof Principal) {
                this.principal = (Principal) sharedPrincipal;
            } else {
                try {
                    // A null shared name fails here and is reported as a principal-creation error.
                    this.principal = this.createIdentity(sharedPrincipal.toString());
                } catch (Exception e) {
                    throw logger.authFailedToCreatePrincipal(e);
                }
            }
            Object sharedCredential = this.sharedState.get("javax.security.auth.login.password");
            if (!(sharedCredential instanceof SamlCredential)) {
                throw logger.authSharedCredentialIsNotSAMLCredential(typeName(sharedCredential));
            }
            this.credential = (SamlCredential) sharedCredential;
            return true;
        }

        Element assertionElement;
        try {
            if (this.getSamlTokenHttpHeader() != null) {
                this.credential = this.getCredentialFromHttpRequest();
            } else {
                ObjectCallback callback = new ObjectCallback(null);
                this.callbackHandler.handle(new Callback[]{callback});
                Object callbackCredential = callback.getCredential();
                if (!(callbackCredential instanceof SamlCredential)) {
                    throw logger.authSharedCredentialIsNotSAMLCredential(typeName(callbackCredential));
                }
                this.credential = (SamlCredential) callbackCredential;
            }
            assertionElement = this.credential.getAssertionAsElement();
        } catch (Exception e) {
            throw logger.authErrorHandlingCallback(e);
        }

        try {
            this.assertion = SAMLUtil.fromElement(assertionElement);
        } catch (Exception e) {
            throw logger.authFailedToParseSAMLAssertion(e);
        }

        try {
            this.validateSAMLCredential();
            this.establishPrincipalFromAssertion();
        } catch (LoginException e) {
            // Already a LoginException: keep its concrete type and stack trace.
            logger.error(e);
            throw e;
        } catch (Throwable e) {
            logger.error(e);
            throw newLoginException(e.getMessage(), e);
        }

        if (this.getUseFirstPass()) {
            this.sharedState.put("javax.security.auth.login.name", this.principal);
            this.sharedState.put("javax.security.auth.login.password", this.credential);
        }
        this.loginOk = true;
        return true;
    }

    /**
     * Sets {@link #principal} from the assertion subject's {@code NameID}, if there is one.
     * When cache invalidation is enabled, also registers the principal for expiry at the
     * assertion's expiration time.
     */
    private void establishPrincipalFromAssertion() throws Exception {
        SubjectType subject = this.assertion.getSubject();
        if (subject == null || subject.getSubType() == null) {
            return;
        }
        BaseIDAbstractType baseID = subject.getSubType().getBaseID();
        if (!(baseID instanceof NameIDType)) {
            return;
        }
        this.principal = new PicketLinkPrincipal(((NameIDType) baseID).getValue());
        if (!this.enableCacheInvalidation) {
            return;
        }
        JBossAuthCacheInvalidationFactory.TimeCacheExpiry cacheExpiry = this.getCacheExpiry();
        XMLGregorianCalendar expiry = AssertionUtil.getExpiration(this.assertion);
        if (expiry == null) {
            logger.samlAssertionWithoutExpiration(this.assertion.getID());
            return;
        }
        Date expiryDate = expiry.toGregorianCalendar().getTime();
        if (logger.isTraceEnabled()) {
            logger.trace("Creating Cache Entry for JBoss at [" + new Date() + "] , with expiration set to SAML expiry = " + expiryDate);
        }
        cacheExpiry.register(this.securityDomain, expiryDate, this.principal);
    }

    /**
     * Commits the login. On success, also adds the SAML credential to the subject's public
     * credentials.
     */
    @Override
    public boolean commit() throws LoginException {
        if (super.commit()) {
            boolean added = this.subject.getPublicCredentials().add(this.credential);
            if (added && logger.isTraceEnabled()) {
                logger.trace("Added Credential " + this.credential);
            }
            return true;
        }
        return false;
    }

    /** Removes SAML credentials from the subject, then delegates to the parent. Always returns {@code true}. */
    @Override
    public boolean abort() throws LoginException {
        this.clearState();
        super.abort();
        return true;
    }

    /** Removes SAML credentials from the subject, then delegates to the parent. Always returns {@code true}. */
    @Override
    public boolean logout() throws LoginException {
        this.clearState();
        super.logout();
        return true;
    }

    /** Removes every SAML credential from the subject and forgets the cached credential. */
    private void clearState() {
        AbstractSTSLoginModule.removeAllSamlCredentials(this.subject);
        this.credential = null;
    }

    @Override
    protected Principal getIdentity() {
        return this.principal;
    }

    /**
     * Builds the single {@code "Roles"} group from the assertion. The assertion is parsed from
     * the credential first if it is not already available.
     */
    @Override
    protected Group[] getRoleSets() throws LoginException {
        if (this.assertion == null) {
            try {
                this.assertion = SAMLUtil.fromElement(this.credential.getAssertionAsElement());
            } catch (Exception e) {
                throw logger.authFailedToParseSAMLAssertion(e);
            }
        }
        if (logger.isTraceEnabled()) {
            try {
                logger.trace("Assertion from where roles will be sought = " + AssertionUtil.asString(this.assertion));
            } catch (ProcessingException ignored) {
                // Trace output only; a serialization failure must not affect role resolution.
            }
        }
        List<String> roleKeys = new ArrayList<String>();
        if (StringUtil.isNotNull(this.roleKey)) {
            roleKeys.addAll(StringUtil.tokenize(this.roleKey));
        }
        PicketLinkGroup rolesGroup = new PicketLinkGroup("Roles");
        List<String> roles = AssertionUtil.getRoles(this.assertion, roleKeys);
        for (String role : roles) {
            rolesGroup.addMember(new SimplePrincipal(role));
        }
        return new Group[]{rolesGroup};
    }

    /** Returns the cache-expiry registry; overridable for testing. */
    protected JBossAuthCacheInvalidationFactory.TimeCacheExpiry getCacheExpiry() throws Exception {
        return JBossAuthCacheInvalidationFactory.getCacheExpiry();
    }

    /**
     * Validates the signing certificate, the assertion signature, and the assertion lifetime.
     *
     * @throws LoginException if the certificate path is not trusted, the signature is invalid,
     *         or the assertion has expired
     */
    private void validateSAMLCredential() throws LoginException, ConfigurationException, CertificateExpiredException, CertificateNotYetValidException {
        X509Certificate cert = this.getX509Certificate();
        this.validateCertPath(cert);
        cert.checkValidity();
        boolean sigValid = false;
        try {
            PublicKey publicKey = cert.getPublicKey();
            sigValid = AssertionUtil.isSignatureValid(this.credential.getAssertionAsElement(), publicKey);
        } catch (ProcessingException e) {
            // Logged and treated as an invalid signature below.
            logger.processingError(e);
        }
        if (!sigValid) {
            throw logger.authSAMLInvalidSignatureError();
        }
        if (AssertionUtil.hasExpired(this.assertion)) {
            throw logger.authSAMLAssertionExpiredError();
        }
    }

    /**
     * Extracts the X.509 certificate from the {@code KeyInfo} of the first XML-DSig
     * {@code Signature} element inside the assertion.
     *
     * @throws LoginException if there is no signature, no X509Data, or no certificate
     */
    private X509Certificate getX509Certificate() throws LoginException {
        try {
            Element sigElement = this.findSignatureElement(this.credential.getAssertionAsElement());
            if (sigElement == null) {
                logger.error("Cannot find Signature element");
                throw new LoginException("Cannot find Signature element");
            }
            XMLSignature signature = new XMLSignature(sigElement, "");
            if (logger.isTraceEnabled()) {
                logger.trace("sigElement=" + sigElement.getTextContent());
            }
            KeyInfo keyInfo = signature.getKeyInfo();
            if (keyInfo == null || !keyInfo.containsX509Data()) {
                logger.error("Cannot find X509Data element");
                throw new LoginException("Cannot find X509Data element");
            }
            X509Certificate certificate = keyInfo.getX509Certificate();
            if (certificate == null) {
                logger.error("Not able to extract x509 certificate");
                throw new LoginException("Not able to extract x509 certificate");
            }
            if (logger.isTraceEnabled()) {
                logger.trace("Got certificate=" + certificate.toString());
            }
            return certificate;
        } catch (LoginException e) {
            // Raised above with its own log entry; propagate unchanged.
            throw e;
        } catch (Exception e) {
            logger.error(e);
            throw newLoginException(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Returns the first XML-DSig {@code Signature} element among the descendants of
     * {@code assertionElement}, in document order, or {@code null} if there is none.
     *
     * <p>The lookup is by namespace URI, so it works whatever prefix the document uses,
     * including a default namespace with no prefix.
     */
    private Element findSignatureElement(Element assertionElement) {
        NodeList signatures = assertionElement.getElementsByTagNameNS(JBossSAMLURIConstants.XMLDSIG_NSURI.get(), "Signature");
        return signatures.getLength() > 0 ? (Element) signatures.item(0) : null;
    }

    /**
     * Validates a single-certificate path against the trust store from {@link #getKeyStore()}
     * using PKIX, with revocation checking disabled.
     *
     * @throws LoginException if the path cannot be built, the trust store is missing, or
     *         validation fails
     */
    protected void validateCertPath(X509Certificate certificate) throws LoginException {
        CertPath certPath;
        try {
            CertificateFactory certFact = CertificateFactory.getInstance("X.509");
            certPath = certFact.generateCertPath(Arrays.asList(certificate));
        } catch (CertificateException e) {
            // Also covers CertificateEncodingException, which is a subclass.
            logger.error(e.getMessage());
            throw newLoginException(e.getLocalizedMessage(), e);
        }
        if (logger.isTraceEnabled()) {
            logger.trace("Certificates from SAML token:");
            for (Certificate tokenCert : certPath.getCertificates()) {
                logger.trace("Type of certificate=" + tokenCert.getType());
                logger.trace(tokenCert.toString());
            }
        }
        try {
            KeyStore trustStore = this.getKeyStore();
            if (trustStore == null) {
                throw logger.authNullKeyStoreFromSecurityDomainError(this.localValidationSecurityDomain);
            }
            if (logger.isTraceEnabled()) {
                traceTrustStore(trustStore);
            }
            PKIXParameters pkixParameters = new PKIXParameters(trustStore);
            pkixParameters.setRevocationEnabled(false);
            CertPathValidator certPathValidator = CertPathValidator.getInstance(CertPathValidator.getDefaultType());
            if (logger.isTraceEnabled()) {
                logger.trace("certPathValidator is ready");
            }
            CertPathValidatorResult result = certPathValidator.validate(certPath, pkixParameters);
            if (logger.isTraceEnabled()) {
                logger.trace("CertPathValidatorResult=" + result);
            }
        } catch (Exception e) {
            logger.error(e);
            throw newLoginException(e.getLocalizedMessage(), e);
        }
    }

    /** Writes every alias in the trust store, with its chain and certificate, to the trace log. */
    private static void traceTrustStore(KeyStore trustStore) throws Exception {
        logger.trace("Certificates from truststore:");
        Enumeration<String> aliases = trustStore.aliases();
        while (aliases.hasMoreElements()) {
            String alias = aliases.nextElement();
            logger.trace("Alias=" + alias);
            Certificate[] chain = trustStore.getCertificateChain(alias);
            if (chain != null) {
                logger.trace(alias + " is a chain:");
                for (Certificate c : chain) {
                    logger.trace(c.toString());
                }
            }
            Certificate crt = trustStore.getCertificate(alias);
            if (crt != null) {
                logger.trace(alias + " is a certificate of type " + crt.getType());
                logger.trace(crt.toString());
            }
        }
    }

    /** Returns the runtime class name of {@code value}, or {@code "null"} when it is absent. */
    private static String typeName(Object value) {
        return value == null ? "null" : value.getClass().getName();
    }

    /**
     * Creates a {@link LoginException} with {@code cause} attached. {@code LoginException} has no
     * cause-taking constructor.
     */
    private static LoginException newLoginException(String message, Throwable cause) {
        LoginException le = new LoginException(message);
        le.initCause(cause);
        return le;
    }

    /** Returns the trust store used for certificate path validation; may not return {@code null}. */
    protected abstract KeyStore getKeyStore() throws Exception;
}

