001package runwar;
002
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.Security;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Arrays;
import java.util.Collection;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.xnio.IoUtils;

import runwar.logging.Logger;
039
040public class SSLUtil {
041        
042        private static Logger log = Logger.getLogger("RunwarLogger");
043        private static final String SERVER_KEY_STORE = "runwar/runwar.keystore";
044    private static final String SERVER_TRUST_STORE = "runwar/runwar.truststore";
045    private static final char[] STORE_PASSWORD = "password".toCharArray();
046    
047        public static SSLContext createSSLContext() throws IOException {
048                log.debug("Creating SSL context from: " + SERVER_KEY_STORE + " trust store: " + SERVER_TRUST_STORE);
049                return createSSLContext(loadKeyStore(SERVER_KEY_STORE), loadKeyStore(SERVER_TRUST_STORE));
050        }
051
052        public static SSLContext createSSLContext(File certfile, File keyfile, char[] passphrase) throws IOException {
053                log.debug("Creating SSL context from cert: " + certfile + " key: " + keyfile);
054                
055                SSLContext context = null;
056                try {
057                        context = createSSLContext(keystoreFromDERCertificate(certfile, keyfile, passphrase), loadKeyStore(SERVER_TRUST_STORE));
058                } catch (Exception e) {
059                        throw new IOException("Could not load certificate",e);
060                }
061                return context;
062        }
063        
064    private static SSLContext createSSLContext(final KeyStore keyStore, final KeyStore trustStore) throws IOException {
065        KeyManager[] keyManagers;
066        try {
067            KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
068            keyManagerFactory.init(keyStore, STORE_PASSWORD);
069            keyManagers = keyManagerFactory.getKeyManagers();
070        } catch (NoSuchAlgorithmException e) {
071            throw new IOException("Unable to initialise KeyManager[]", e);
072        } catch (UnrecoverableKeyException e) {
073            throw new IOException("Unable to initialise KeyManager[]", e);
074        } catch (KeyStoreException e) {
075            throw new IOException("Unable to initialise KeyManager[]", e);
076        }
077
078        TrustManager[] trustManagers = null;
079        try {
080            TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
081            trustManagerFactory.init(trustStore);
082            trustManagers = trustManagerFactory.getTrustManagers();
083        } catch (NoSuchAlgorithmException e) {
084            throw new IOException("Unable to initialise TrustManager[]", e);
085        } catch (KeyStoreException e) {
086            throw new IOException("Unable to initialise TrustManager[]", e);
087        }
088
089        SSLContext sslContext;
090        try {
091            sslContext = SSLContext.getInstance("TLS");
092            sslContext.init(keyManagers, trustManagers, null);
093        } catch (NoSuchAlgorithmException e) {
094            throw new IOException("Unable to create and initialise the SSLContext", e);
095        } catch (KeyManagementException e) {
096            throw new IOException("Unable to create and initialise the SSLContext", e);
097        }
098
099        return sslContext;
100    }
101
102        private static KeyStore loadKeyStore(final String name) throws IOException {
103        final InputStream stream = SSLUtil.class.getClassLoader().getResourceAsStream(name);
104        if(stream == null)
105            throw new IOException(String.format("Unable to load KeyStore from classpath %s", name));
106        try {
107            KeyStore loadedKeystore = KeyStore.getInstance("JKS");
108            loadedKeystore.load(stream, STORE_PASSWORD);
109            log.debug("loaded store: " + name);
110            return loadedKeystore;
111        } catch (Exception e) {
112            throw new IOException(String.format("Unable to load KeyStore %s", name), e);
113        } finally {
114            IoUtils.safeClose(stream);
115        }
116    }
117
118        public static KeyStore keystoreFromDERCertificate ( File certfile, File keyfile, char[] passphrase) throws Exception {
119        
120        String defaultalias = "serverkey";
121        PrivateKey ff;
122        KeyStore ks = KeyStore.getInstance("JKS", "SUN");
123        ks.load( null , passphrase);
124        try {
125                // try the pks8 java format first
126                ff = loadPKCS8PrivateKey(keyfile);
127        } catch (Exception e) {
128                // use the rsa format from openssl
129                ff = loadRSAPrivateKey(keyfile);
130        }
131
132        CertificateFactory cf = CertificateFactory.getInstance("X.509");
133        InputStream certstream = fullStream (certfile);
134        Collection<?> c = cf.generateCertificates(certstream) ;
135        Certificate[] certs = new Certificate[c.toArray().length];
136
137        if (c.size() == 1) {
138            certstream = fullStream (certfile);
139            log.debug("One certificate, no chain.");
140            Certificate cert = cf.generateCertificate(certstream) ;
141            certs[0] = cert;
142        } else {
143                log.debug("Certificate chain length: "+c.size());
144            certs = (Certificate[])c.toArray();
145        }
146        ks.setKeyEntry(defaultalias, ff, 
147                       passphrase,
148                       certs );
149        Arrays.fill(passphrase, '*');
150        return ks;
151    }
152        
153        private static PrivateKey loadPKCS8PrivateKey(File f) throws Exception {
154                    FileInputStream fis = new FileInputStream(f);
155                    DataInputStream dis = new DataInputStream(fis);
156                    byte[] keyBytes = new byte[(int) f.length()];
157                    dis.readFully(keyBytes);
158                    dis.close();
159                    PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
160                    KeyFactory kf = KeyFactory.getInstance("RSA");
161                    return kf.generatePrivate(spec);
162        }
163        
164        private static PrivateKey loadRSAPrivateKey(File f) throws Exception {
165                BufferedReader br = new BufferedReader(new FileReader(f));
166                Security.addProvider(new BouncyCastleProvider());
167                PEMParser pp = new PEMParser(br);
168                PrivateKeyInfo pemKeyPair = (PrivateKeyInfo) pp.readObject();
169                PrivateKey kp = new JcaPEMKeyConverter().getPrivateKey(pemKeyPair);
170                pp.close();
171                return kp;
172        }
173        
174        private static InputStream fullStream ( File fname ) throws IOException {
175        FileInputStream fis = new FileInputStream(fname);
176        DataInputStream dis = new DataInputStream(fis);
177        byte[] bytes = new byte[dis.available()];
178        dis.readFully(bytes);
179        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
180        IoUtils.safeClose(fis);
181        IoUtils.safeClose(dis);
182        return bais;
183    }   
184}