/*
 * Decompiled with CFR 0.152.
 */
package code.ponfee.commons.jce.sm;

import code.ponfee.commons.jce.ECParameters;
import code.ponfee.commons.jce.sm.SM2;
import code.ponfee.commons.jce.sm.SM3Digest;
import code.ponfee.commons.util.Bytes;
import code.ponfee.commons.util.SecureRandoms;
import java.io.ByteArrayOutputStream;
import java.io.Serializable;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.util.Arrays;
import org.bouncycastle.math.ec.ECPoint;

/**
 * SM2 key exchange (key agreement with key confirmation) between an initiator A and a
 * responder B.
 *
 * <p>Call sequence:
 * <ol>
 *   <li>A calls {@link #step1PartA()} and sends the returned entity (A's ephemeral point) to B.</li>
 *   <li>B calls {@link #step2PartB(TransportEntity)}; it derives the shared key and returns its
 *       own ephemeral point plus the confirmation hash S<sub>B</sub>.</li>
 *   <li>A calls {@link #step3PartA(TransportEntity)}; it derives the shared key, verifies
 *       S<sub>B</sub> and returns the confirmation hash S<sub>A</sub>, or {@code null} if the
 *       verification fails.</li>
 *   <li>B calls {@link #step4PartB(TransportEntity)} to verify S<sub>A</sub>.</li>
 * </ol>
 * The negotiated key is available from {@link #getKey()} after step 2 (B) or after a successful
 * step 3 (A).
 *
 * <p>Each instance keeps the state of one exchange in plain, unsynchronized fields.
 *
 * <p>NOTE(review): the peer's public key and Z value are taken from the received
 * {@link TransportEntity} and are not authenticated here; callers should check that they
 * belong to the expected peer.
 *
 * <p>NOTE(review): coordinates are serialized with {@link BigInteger#toByteArray()} (minimal
 * two's-complement) rather than a fixed-length field-element encoding, which probably makes the
 * derived key and confirmation hashes incompatible with other SM2 implementations. It is kept
 * unchanged so that existing peers using this class keep agreeing on the same values.
 *
 * <p>NOTE(review): this class is declared {@link Serializable} but holds BouncyCastle
 * {@link ECPoint} fields (and the private key); confirm whether serializing instances is
 * actually supported or desired.
 */
public class SM2KeyExchanger implements Serializable {
    private static final long serialVersionUID = 8553046425593791291L;

    /** Length in bytes of the derived shared key. */
    private static final int KEY_LENGTH = 16;

    /** Number of bytes produced by one SM3 invocation inside {@link #kdf(byte[], int)}. */
    private static final int SM3_DIGEST_LENGTH = 32;

    /** This party's ephemeral private scalar (r_A for A; r_B for B after step 2). */
    private BigInteger rA;
    /** This party's ephemeral public point {@code rA * G} (R_A for A; R_B for B after step 2). */
    private ECPoint RA;
    /** Shared point: V computed by B in step 2, or U computed by A in step 3. */
    private ECPoint V;
    /** Derived shared key, or {@code null} when no (confirmed) key is available. */
    private byte[] key;
    private final ECParameters ecParam;
    /** {@code 2^w} with {@code w = ceil(bitLength(n) / 2) - 1}; see {@link #reduceX(ECPoint)}. */
    private final BigInteger w;
    private final ECPoint publicKey;
    private final BigInteger privateKey;
    /** This party's Z value as computed by {@link SM2#calcZ}. */
    private final byte[] Z;

    /**
     * Creates an exchanger without a user identity on the {@code ECParameters.SM2_BEST} curve.
     *
     * @param publicKey  this party's static public key
     * @param privateKey this party's static private key
     */
    public SM2KeyExchanger(ECPoint publicKey, BigInteger privateKey) {
        this(null, publicKey, privateKey, ECParameters.SM2_BEST);
    }

    /**
     * Creates an exchanger on the {@code ECParameters.SM2_BEST} curve.
     *
     * @param ida        this party's user identity, passed to {@link SM2#calcZ}
     * @param publicKey  this party's static public key
     * @param privateKey this party's static private key
     */
    public SM2KeyExchanger(byte[] ida, ECPoint publicKey, BigInteger privateKey) {
        this(ida, publicKey, privateKey, ECParameters.SM2_BEST);
    }

    /**
     * Creates an exchanger without a user identity on the given curve.
     *
     * @param publicKey  this party's static public key
     * @param privateKey this party's static private key
     * @param ecParam    curve parameters shared by both parties
     */
    public SM2KeyExchanger(ECPoint publicKey, BigInteger privateKey, ECParameters ecParam) {
        this(null, publicKey, privateKey, ecParam);
    }

    /**
     * Creates an exchanger.
     *
     * @param ida        this party's user identity, passed to {@link SM2#calcZ}; may be
     *                   {@code null} (how {@code calcZ} treats {@code null} is defined there)
     * @param publicKey  this party's static public key
     * @param privateKey this party's static private key
     * @param ecParam    curve parameters shared by both parties
     */
    public SM2KeyExchanger(byte[] ida, ECPoint publicKey, BigInteger privateKey, ECParameters ecParam) {
        this.ecParam = ecParam;
        // ceil(bitLength / 2) in integer arithmetic is (bitLength + 1) / 2.
        this.w = BigInteger.ONE.shiftLeft((ecParam.n.bitLength() + 1) / 2 - 1);
        this.publicKey = publicKey;
        this.privateKey = privateKey;
        this.Z = SM2.calcZ(SM3Digest.getInstance(), ecParam, ida, publicKey);
    }

    /**
     * Step 1 (party A): generates A's ephemeral key pair.
     *
     * @return the message for B: A's ephemeral point, no confirmation hash, A's Z and public key
     */
    public TransportEntity step1PartA() {
        this.rA = randomScalar();
        this.RA = this.ecParam.pointG.multiply(this.rA).normalize();
        return new TransportEntity(this.RA.getEncoded(false), null, this.Z, this.publicKey);
    }

    /**
     * Step 2 (party B): generates B's ephemeral key pair, derives the shared key and computes the
     * confirmation hash S_B.
     *
     * @param entity1 the message produced by A's {@link #step1PartA()}
     * @return the message for A: B's ephemeral point, S_B, B's Z and public key
     * @throws IllegalStateException if the shared point is the point at infinity
     */
    public TransportEntity step2PartB(TransportEntity entity1) {
        // B's ephemeral pair is stored in the rA/RA fields so that step4PartB can reuse it.
        BigInteger rB = randomScalar();
        ECPoint RB = this.ecParam.pointG.multiply(rB).normalize();
        this.rA = rB;
        this.RA = RB;

        BigInteger tB = this.privateKey.add(reduceX(RB).multiply(rB)).mod(this.ecParam.n);
        ECPoint peerR = decode(entity1.R);
        ECPoint peerPublicKey = decode(entity1.K);
        ECPoint shared = sharedPoint(peerPublicKey, peerR, tB);
        this.V = shared;

        byte[] xV = shared.getXCoord().toBigInteger().toByteArray();
        byte[] yV = shared.getYCoord().toBigInteger().toByteArray();
        this.key = kdf(Bytes.concat(xV, yV, entity1.Z, this.Z), KEY_LENGTH);

        SM3Digest sm3 = SM3Digest.getInstance();
        byte[] inner = digest(sm3, xV, entity1.Z, this.Z, peerR, RB);
        byte[] sB = confirmation(sm3, (byte) 2, yV, inner);
        return new TransportEntity(RB.getEncoded(false), sB, this.Z, this.publicKey);
    }

    /**
     * Step 3 (party A): derives the shared key, verifies B's confirmation hash S_B and computes
     * A's confirmation hash S_A.
     *
     * <p>The key is only made available through {@link #getKey()} when S_B verifies; on failure
     * any previously held key is cleared.
     *
     * @param entity2 the message produced by B's {@link #step2PartB(TransportEntity)}
     * @return the message for B (A's ephemeral point, S_A, A's Z and public key), or {@code null}
     *         if S_B does not verify
     * @throws IllegalStateException if {@link #step1PartA()} has not been called, or if the shared
     *                               point is the point at infinity
     */
    public TransportEntity step3PartA(TransportEntity entity2) {
        if (this.rA == null || this.RA == null) {
            throw new IllegalStateException("step1PartA must be called before step3PartA");
        }
        BigInteger tA = this.privateKey.add(reduceX(this.RA).multiply(this.rA)).mod(this.ecParam.n);
        ECPoint peerR = decode(entity2.R);
        ECPoint peerPublicKey = decode(entity2.K);
        ECPoint shared = sharedPoint(peerPublicKey, peerR, tA);
        this.V = shared;

        byte[] xU = shared.getXCoord().toBigInteger().toByteArray();
        byte[] yU = shared.getYCoord().toBigInteger().toByteArray();

        SM3Digest sm3 = SM3Digest.getInstance();
        byte[] inner = digest(sm3, xU, this.Z, entity2.Z, this.RA, peerR);
        byte[] expectedSB = confirmation(sm3, (byte) 2, yU, inner);
        if (!constantTimeEquals(entity2.S, expectedSB)) {
            // Never expose an unconfirmed (or stale) key.
            this.key = null;
            return null;
        }
        this.key = kdf(Bytes.concat(xU, yU, this.Z, entity2.Z), KEY_LENGTH);

        byte[] sA = confirmation(sm3, (byte) 3, yU, inner);
        return new TransportEntity(this.RA.getEncoded(false), sA, this.Z, this.publicKey);
    }

    /**
     * Step 4 (party B): verifies A's confirmation hash S_A.
     *
     * @param entity3 the message produced by A's {@link #step3PartA(TransportEntity)}
     * @return {@code true} if S_A matches the value B expects
     * @throws IllegalStateException if {@link #step2PartB(TransportEntity)} has not been called
     */
    public boolean step4PartB(TransportEntity entity3) {
        if (this.V == null || this.RA == null) {
            throw new IllegalStateException("step2PartB must be called before step4PartB");
        }
        byte[] xV = this.V.getXCoord().toBigInteger().toByteArray();
        byte[] yV = this.V.getYCoord().toBigInteger().toByteArray();
        ECPoint peerR = decode(entity3.R);

        SM3Digest sm3 = SM3Digest.getInstance();
        byte[] inner = digest(sm3, xV, entity3.Z, this.Z, peerR, this.RA);
        return constantTimeEquals(entity3.S, confirmation(sm3, (byte) 3, yV, inner));
    }

    /**
     * Returns a copy of the derived shared key.
     *
     * @return the key, or {@code null} if no key has been derived (or A's verification failed)
     */
    public byte[] getKey() {
        return copyOf(this.key);
    }

    /**
     * Draws an ephemeral private scalar.
     * NOTE(review): the exact range of {@code SecureRandoms.random(n)} is not visible here; zero is
     * rejected defensively because the SM2 key exchange requires a scalar in [1, n-1].
     */
    private BigInteger randomScalar() {
        BigInteger r;
        do {
            r = SecureRandoms.random(this.ecParam.n);
        } while (r.signum() == 0);
        return r;
    }

    /** Decodes an encoded curve point and normalizes it to affine coordinates. */
    private ECPoint decode(byte[] encoded) {
        return this.ecParam.curve.decodePoint(encoded).normalize();
    }

    /** Computes x̄ = 2^w + (x & (2^w - 1)) from the x coordinate of the given point. */
    private BigInteger reduceX(ECPoint point) {
        BigInteger x = point.getXCoord().toBigInteger();
        return this.w.add(x.and(this.w.subtract(BigInteger.ONE)));
    }

    /**
     * Computes the shared point [h * t](P + [x̄]R) from the peer's static key P and ephemeral
     * point R.
     *
     * @throws IllegalStateException if the result is the point at infinity
     */
    private ECPoint sharedPoint(ECPoint peerPublicKey, ECPoint peerR, BigInteger t) {
        ECPoint base = peerPublicKey.add(peerR.multiply(reduceX(peerR)).normalize()).normalize();
        ECPoint shared = base.multiply(this.ecParam.bcSpec.getH().multiply(t)).normalize();
        if (shared.isInfinity()) {
            throw new IllegalStateException("SM2 key exchange failed: shared point is at infinity");
        }
        return shared;
    }

    /** Computes the confirmation hash SM3(tag || y || inner), starting from a reset digest. */
    private static byte[] confirmation(SM3Digest sm3, byte tag, byte[] y, byte[] inner) {
        sm3.reset();
        sm3.update(tag);
        sm3.update(y);
        sm3.update(inner);
        return sm3.doFinal();
    }

    /**
     * Computes the inner confirmation hash SM3(x || z1 || z2 || a.x || a.y || b.x || b.y).
     * The digest is reset before use.
     */
    private static byte[] digest(SM3Digest sm3, byte[] x, byte[] z1, byte[] z2, ECPoint a, ECPoint b) {
        sm3.reset();
        sm3.update(x);
        sm3.update(z1);
        sm3.update(z2);
        sm3.update(a.getXCoord().toBigInteger().toByteArray());
        sm3.update(a.getYCoord().toBigInteger().toByteArray());
        sm3.update(b.getXCoord().toBigInteger().toByteArray());
        sm3.update(b.getYCoord().toBigInteger().toByteArray());
        return sm3.doFinal();
    }

    /**
     * Key derivation: concatenates SM3(input || ct) for ct = 1, 2, ... and truncates the result to
     * {@code klen} bytes. The counter is encoded with {@code Bytes.toBytes(int)}.
     *
     * @param input shared secret material
     * @param klen  output length in bytes
     * @return {@code klen} derived bytes
     * @throws IllegalArgumentException if {@code klen} is not positive
     */
    private static byte[] kdf(byte[] input, int klen) {
        if (klen <= 0) {
            throw new IllegalArgumentException("klen must be positive: " + klen);
        }
        int blocks = (klen + SM3_DIGEST_LENGTH - 1) / SM3_DIGEST_LENGTH;
        SM3Digest sm3 = SM3Digest.getInstance();
        ByteArrayOutputStream out = new ByteArrayOutputStream(klen);
        for (int ct = 1; ct <= blocks; ++ct) {
            sm3.update(input);
            sm3.update(Bytes.toBytes(ct));
            byte[] block = sm3.doFinal();
            // The last block is truncated so the output is exactly klen bytes.
            out.write(block, 0, Math.min(block.length, klen - out.size()));
        }
        return out.toByteArray();
    }

    /** Compares two hashes in constant time; {@code null} never matches. */
    private static boolean constantTimeEquals(byte[] a, byte[] b) {
        return a != null && b != null && MessageDigest.isEqual(a, b);
    }

    /** Returns a clone of {@code bytes}, or {@code null} if it is {@code null}. */
    private static byte[] copyOf(byte[] bytes) {
        return bytes == null ? null : bytes.clone();
    }

    /** A message exchanged between the two parties of the key exchange. */
    public static class TransportEntity implements Serializable {
        private static final long serialVersionUID = 3657694935421411649L;
        /** Sender's ephemeral point, uncompressed encoding. */
        private final byte[] R;
        /** Sender's confirmation hash; {@code null} in the step-1 message. */
        private final byte[] S;
        /** Sender's Z value. */
        private final byte[] Z;
        /** Sender's static public key, uncompressed encoding. */
        private final byte[] K;

        TransportEntity(byte[] r, byte[] s, byte[] z, ECPoint pKey) {
            this(r, s, z, pKey.getEncoded(false));
        }

        TransportEntity(byte[] r, byte[] s, byte[] z, byte[] publicKey) {
            this.R = r;
            this.S = s;
            this.Z = z;
            this.K = publicKey;
        }

        /** Returns a copy of the sender's encoded ephemeral point. */
        public byte[] getR() {
            return copyOf(this.R);
        }

        /** Returns a copy of the sender's confirmation hash, or {@code null} if absent. */
        public byte[] getS() {
            return copyOf(this.S);
        }

        /** Returns a copy of the sender's Z value (copied so callers cannot alter the exchanger's own Z). */
        public byte[] getZ() {
            return copyOf(this.Z);
        }

        /** Returns a copy of the sender's encoded static public key. */
        public byte[] getK() {
            return copyOf(this.K);
        }
    }
}

