/*
 * Decompiled with CFR 0.152.
 */
package org.wildfly.security.sasl.scram;

import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.SaslException;
import org.wildfly.security.sasl.scram.Scram;
import org.wildfly.security.sasl.scram.ScramUtil;
import org.wildfly.security.sasl.util.AbstractSaslClient;
import org.wildfly.security.sasl.util.StringPrep;
import org.wildfly.security.util.ByteIterator;
import org.wildfly.security.util.ByteStringBuilder;

/**
 * Client side of a SCRAM SASL exchange (RFC 5802), with optional channel binding.
 *
 * <p>The exchange is driven by the negotiation state:
 * {@link #ST_NEW} produces the client-first message, {@link #ST_R1_SENT} consumes the
 * server-first message and produces the client-final message (carrying the client proof), and
 * {@link #ST_R2_SENT} verifies the server signature in the server-final message.
 */
class ScramSaslClient extends AbstractSaslClient {
    private static final int ST_NEW = 1;
    private static final int ST_R1_SENT = 2;
    private static final int ST_R2_SENT = 3;
    /** Negotiation state set once the server signature has been verified. */
    private static final int ST_COMPLETE = 0;
    /** Negotiation state set on disposal and on failed verification. */
    private static final int ST_FAILED = -1;

    /** Size argument passed to {@code ScramUtil.generateNonce} for the client nonce. */
    private static final int CLIENT_NONCE_LENGTH = 48;
    /** Minimum number of bytes the server must append after the echoed client nonce. */
    private static final int MIN_SERVER_NONCE_LENGTH = 18;
    /**
     * StringPrep flag set used when encoding the user name and authorization id.
     * NOTE(review): decompiler-folded constant (0x40003FFF); confirm which StringPrep flags it combines.
     */
    private static final long STRINGPREP_FLAGS = 1073758207L;
    /** Enables hex dumps of non-secret protocol data on stdout. Must stay {@code false} in production. */
    private static final boolean DEBUG = false;

    private final int minimumIterationCount;
    private final int maximumIterationCount;
    private final MessageDigest messageDigest;
    private final Mac mac;
    private final SecureRandom secureRandom;
    private final boolean plus;
    private final byte[] bindingData;
    private final String bindingType;
    /** Complete client-first message; bytes {@code [0, bareStart)} are the GS2 header. */
    private byte[] clientFirstMessage;
    /** Offset in {@link #clientFirstMessage} where client-first-message-bare starts. */
    private int bareStart;
    /** Complete client-final message; bytes {@code [0, proofStart)} exclude the proof. */
    private byte[] clientFinalMessage;
    private byte[] nonce;
    /** Holds the password between step 1 and step 2; cleared as soon as it has been used. */
    private PasswordCallback passwordCallback;
    private int proofStart;
    private byte[] saltedPassword;
    private byte[] serverFirstMessage;

    /**
     * Creates a SCRAM client.
     *
     * @param mechanismName   the SASL mechanism name
     * @param messageDigest   digest used to derive StoredKey from ClientKey
     * @param mac             HMAC used for Hi() and all key/signature derivations
     * @param secureRandom    nonce source, or {@code null} to fall back to {@link ThreadLocalRandom}
     * @param protocol        the protocol name passed to the superclass
     * @param serverName      the server name passed to the superclass
     * @param callbackHandler handler used to obtain the user name and password
     * @param authorizationId optional authorization id (may be {@code null})
     * @param props           properties; reads {@code wildfly.sasl.scram.min-iteration-count} (default 4096)
     *                        and {@code wildfly.sasl.scram.max-iteration-count} (default 32768)
     * @param plus            whether channel binding is actually used ("-PLUS" variant)
     * @param bindingType     channel-binding type name sent in the {@code p=} flag
     * @param bindingData     channel-binding data, or {@code null} if none is available
     */
    ScramSaslClient(String mechanismName, MessageDigest messageDigest, Mac mac, SecureRandom secureRandom, String protocol, String serverName, CallbackHandler callbackHandler, String authorizationId, Map<String, ?> props, boolean plus, String bindingType, byte[] bindingData) {
        super(mechanismName, protocol, serverName, callbackHandler, authorizationId, true);
        this.bindingType = bindingType;
        this.minimumIterationCount = this.getIntProperty(props, "wildfly.sasl.scram.min-iteration-count", 4096);
        this.maximumIterationCount = this.getIntProperty(props, "wildfly.sasl.scram.max-iteration-count", 32768);
        this.secureRandom = secureRandom;
        this.messageDigest = messageDigest;
        this.mac = mac;
        this.plus = plus;
        // defensive copy: the caller's array must not be able to change the binding data mid-exchange
        this.bindingData = bindingData == null ? null : bindingData.clone();
    }

    MessageDigest getMessageDigest() {
        return this.messageDigest;
    }

    /**
     * Releases crypto state and wipes any password material still held, then marks the
     * negotiation as failed.
     */
    @Override
    public void dispose() throws SaslException {
        this.messageDigest.reset();
        this.mac.reset();
        if (this.saltedPassword != null) {
            Arrays.fill(this.saltedPassword, (byte) 0);
            this.saltedPassword = null;
        }
        if (this.passwordCallback != null) {
            this.passwordCallback.clearPassword();
            this.passwordCallback = null;
        }
        this.setNegotiationState(ST_FAILED);
    }

    @Override
    public void init() {
        this.setNegotiationState(ST_NEW);
    }

    /**
     * Dispatches one step of the SCRAM exchange according to the current negotiation state.
     *
     * @param state     current negotiation state
     * @param challenge the server's challenge for this step
     * @return the response to send, or {@code null} once authentication is complete
     * @throws SaslException if the server message is invalid or authentication fails
     * @throws IllegalStateException if {@code state} is not one of the client steps
     */
    @Override
    protected byte[] evaluateMessage(int state, byte[] challenge) throws SaslException {
        switch (state) {
            case ST_NEW:
                return this.createClientFirstMessage(challenge);
            case ST_R1_SENT:
                return this.processServerFirstMessage(challenge);
            case ST_R2_SENT:
                return this.processServerFinalMessage(challenge);
            default:
                throw new IllegalStateException();
        }
    }

    /** Step 1: obtains credentials and builds {@code gs2-header client-first-message-bare}. */
    private byte[] createClientFirstMessage(byte[] challenge) throws SaslException {
        if (challenge.length != 0) {
            throw new SaslException("Initial challenge must be empty");
        }
        final String authorizationId = this.getAuthorizationId();
        final NameCallback nameCallback = authorizationId == null
                ? new NameCallback("User name")
                : new NameCallback("User name", authorizationId);
        this.passwordCallback = new PasswordCallback("Password", false);
        this.handleCallbacks(new Callback[] { nameCallback, this.passwordCallback });

        final ByteStringBuilder b = new ByteStringBuilder();
        // GS2 header: channel-binding flag, optional authzid, terminating comma
        if (this.bindingData != null) {
            if (this.plus) {
                b.append("p=");
                b.append(this.bindingType);
                b.append(',');
            } else {
                b.append("y,");
            }
        } else {
            b.append("n,");
        }
        if (authorizationId != null) {
            b.append('a').append('=');
            StringPrep.encode(authorizationId, b, STRINGPREP_FLAGS);
        }
        b.append(',');
        this.bareStart = b.length();
        // client-first-message-bare: n=<user>,r=<client nonce>
        b.append('n').append('=');
        StringPrep.encode(nameCallback.getName(), b, STRINGPREP_FLAGS);
        b.append(',').append('r').append('=');
        final Random random = this.secureRandom != null ? this.secureRandom : ThreadLocalRandom.current();
        this.nonce = ScramUtil.generateNonce(CLIENT_NONCE_LENGTH, random);
        b.append(this.nonce);
        debugHex("[C] Client nonce: %s%n", this.nonce);
        this.setNegotiationState(ST_R1_SENT);
        this.clientFirstMessage = b.toArray();
        debugHex("[C] Client first message: %s%n", this.clientFirstMessage);
        return this.clientFirstMessage;
    }

    /** Step 2: validates {@code r=,s=,i=} from the server and builds the client-final message. */
    private byte[] processServerFirstMessage(byte[] challenge) throws SaslException {
        debugHex("[C] Server first message: %s%n", challenge);
        this.serverFirstMessage = challenge.clone();
        final ByteIterator bi = ByteIterator.ofBytes(this.serverFirstMessage);
        final ByteIterator di = bi.delimitedBy(',');
        final Mac mac = this.mac;
        final MessageDigest messageDigest = this.messageDigest;
        try {
            // r=<client nonce><server nonce>
            if (!consumeAttribute(bi, 'r')) {
                throw new SaslException("Invalid server message");
            }
            if (!di.limitedTo(this.nonce.length).contentEquals(ByteIterator.ofBytes(this.nonce))) {
                throw new SaslException("Nonces do not match");
            }
            final byte[] serverNonce = di.drain();
            if (serverNonce.length < MIN_SERVER_NONCE_LENGTH) {
                throw new SaslException("Server nonce is too short");
            }
            bi.next(); // skip ','
            // s=<base64 salt>
            if (!consumeAttribute(bi, 's')) {
                throw new SaslException("Invalid server message");
            }
            final byte[] salt = di.base64Decode().drain();
            debugHex("[C] Server sent salt: %s%n", salt);
            bi.next(); // skip ','
            // i=<iteration count>
            if (!consumeAttribute(bi, 'i')) {
                throw new SaslException("Invalid server message");
            }
            final int iterationCount = ScramUtil.parsePosInt(di);
            if (iterationCount < this.minimumIterationCount) {
                throw new SaslException("Iteration count is too low");
            }
            if (iterationCount > this.maximumIterationCount) {
                throw new SaslException("Iteration count is too high");
            }
            if (bi.hasNext()) {
                if (bi.next() == ',') {
                    throw new SaslException("Extensions unsupported");
                }
                throw new SaslException("Invalid server message");
            }

            // client-final-message-without-proof: c=<base64 cbind-input>,r=<full nonce>
            final ByteStringBuilder b = new ByteStringBuilder();
            b.append('c').append('=');
            b.appendLatin1(this.channelBindingInput().iterate().base64Encode());
            b.append(',').append('r').append('=').append(this.nonce).append(serverNonce);

            this.saltedPassword = this.computeSaltedPassword(salt, iterationCount);
            mac.init(new SecretKeySpec(this.saltedPassword, mac.getAlgorithm()));
            final byte[] clientKey = mac.doFinal(Scram.CLIENT_KEY_BYTES);
            final byte[] storedKey = messageDigest.digest(clientKey);
            mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm()));
            // AuthMessage = client-first-message-bare "," server-first-message "," client-final-message-without-proof
            mac.update(this.clientFirstMessage, this.bareStart, this.clientFirstMessage.length - this.bareStart);
            mac.update((byte) ',');
            mac.update(this.serverFirstMessage);
            mac.update((byte) ',');
            b.updateMac(mac);
            final byte[] clientProof = mac.doFinal(); // ClientSignature ...
            ScramUtil.xor(clientProof, clientKey);     // ... turned in place into ClientProof
            Arrays.fill(clientKey, (byte) 0);
            Arrays.fill(storedKey, (byte) 0);

            this.proofStart = b.length();
            b.append(',').append('p').append('=');
            b.appendLatin1(ByteIterator.ofBytes(clientProof).base64Encode());
            this.setNegotiationState(ST_R2_SENT);
            this.clientFinalMessage = b.toArray();
            debugHex("[C] Client final message: %s%n", this.clientFinalMessage);
            return this.clientFinalMessage;
        } catch (ArrayIndexOutOfBoundsException | IllegalArgumentException | NoSuchElementException | InvalidKeyException e) {
            // malformed/truncated input; NoSuchElementException is assumed to be what an exhausted
            // ByteIterator throws — NOTE(review): confirm against ByteIterator.next()
            throw new SaslException("Invalid server message", e);
        } finally {
            messageDigest.reset();
            mac.reset();
        }
    }

    /** Step 3: checks {@code e=} (rejection) or verifies the {@code v=} server signature. */
    private byte[] processServerFinalMessage(byte[] challenge) throws SaslException {
        if (DEBUG) {
            System.out.printf("[C] Server final message: %s%n", new String(challenge, StandardCharsets.UTF_8));
        }
        final Mac mac = this.mac;
        final MessageDigest messageDigest = this.messageDigest;
        final ByteIterator bi = ByteIterator.ofBytes(challenge);
        final ByteIterator di = bi.delimitedBy(',');
        Exception cause = null;
        try {
            final int c = bi.next();
            if (c == 'e') {
                if (bi.next() == '=') {
                    throw new SaslException("Server rejected authentication: " + di.asUtf8String().drainToString());
                }
                throw new SaslException("Server rejected authentication");
            }
            if (c == 'v' && bi.next() == '=') {
                mac.init(new SecretKeySpec(this.saltedPassword, mac.getAlgorithm()));
                final byte[] serverKey = mac.doFinal(Scram.SERVER_KEY_BYTES);
                mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm()));
                Arrays.fill(serverKey, (byte) 0);
                // ServerSignature = HMAC(ServerKey, AuthMessage)
                mac.update(this.clientFirstMessage, this.bareStart, this.clientFirstMessage.length - this.bareStart);
                mac.update((byte) ',');
                mac.update(this.serverFirstMessage);
                mac.update((byte) ',');
                mac.update(this.clientFinalMessage, 0, this.proofStart);
                final byte[] serverSignature = mac.doFinal();
                final byte[] receivedSignature = di.base64Decode().drain();
                // constant-time comparison
                if (!MessageDigest.isEqual(receivedSignature, serverSignature)) {
                    this.setNegotiationState(ST_FAILED);
                    throw new SaslException("Server authenticity cannot be verified");
                }
                this.setNegotiationState(ST_COMPLETE);
                return null;
            }
        } catch (ArrayIndexOutOfBoundsException | IllegalArgumentException | NoSuchElementException | InvalidKeyException e) {
            cause = e; // reported below as an invalid server message
        } finally {
            messageDigest.reset();
            mac.reset();
        }
        this.setNegotiationState(ST_FAILED);
        throw new SaslException("Invalid server message", cause);
    }

    /**
     * Builds the channel-binding input for the {@code c=} attribute.
     * Per RFC 5802 this is the exact GS2 header already sent in the client-first message,
     * followed by the binding data when channel binding is in use.
     */
    private ByteStringBuilder channelBindingInput() {
        assert this.bindingData != null || !this.plus;
        final ByteStringBuilder cbind = new ByteStringBuilder();
        cbind.append(Arrays.copyOfRange(this.clientFirstMessage, 0, this.bareStart));
        if (this.bindingData != null) {
            debugHex("[C] Binding data: %s%n", this.bindingData);
            if (this.plus) {
                cbind.append(this.bindingData);
            }
        }
        return cbind;
    }

    /**
     * Computes Hi(password, salt, i) and immediately clears every copy of the password held
     * by this client.
     */
    private byte[] computeSaltedPassword(byte[] salt, int iterationCount) throws SaslException, InvalidKeyException {
        final PasswordCallback callback = this.passwordCallback;
        final char[] password = callback == null ? null : callback.getPassword();
        if (password == null) {
            throw new SaslException("No password provided");
        }
        try {
            return ScramUtil.calculateHi(this.mac, password, salt, 0, salt.length, iterationCount);
        } finally {
            Arrays.fill(password, '\0');
            callback.clearPassword();
            this.passwordCallback = null;
        }
    }

    /** Consumes a two-byte {@code <key>=} prefix and reports whether it matched. */
    private static boolean consumeAttribute(ByteIterator bi, char key) {
        return bi.next() == key && bi.next() == '=';
    }

    /** Prints {@code bytes} as hex using {@code format}, only when {@link #DEBUG} is enabled. */
    private static void debugHex(String format, byte[] bytes) {
        if (DEBUG) {
            System.out.printf(format, ByteIterator.ofBytes(bytes).hexEncode().drainToString());
        }
    }
}

