001package top.cenze.utils.crypt.sm.sm2;
002
003import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
004import org.bouncycastle.crypto.digests.SM3Digest;
005import org.bouncycastle.crypto.generators.ECKeyPairGenerator;
006import org.bouncycastle.crypto.params.ECDomainParameters;
007import org.bouncycastle.crypto.params.ECKeyGenerationParameters;
008import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
009import org.bouncycastle.crypto.params.ECPublicKeyParameters;
010import org.bouncycastle.math.ec.ECCurve;
011import org.bouncycastle.math.ec.ECFieldElement;
012import org.bouncycastle.math.ec.ECFieldElement.Fp;
013import org.bouncycastle.math.ec.ECPoint;
014import top.cenze.utils.ConvertUtil;
015
016import java.math.BigInteger;
017import java.security.SecureRandom;
018
019
020public class SM2Factory {
021        /*-----------------------国密算法相关参数begin-----------
022         * ------------------*/
023        //A 第一系数
024        private static final BigInteger a  = new BigInteger("fffffffeffffffffffffffffffffffffffffffff00000000fffffffffffffffc",16);
025        //B 第二系数
026        private static final BigInteger b  = new BigInteger("28e9fa9e9d9f5e344d5a9e4bcf6509a7f39789f515ab8f92ddbcbd414d940e93",16);
027        //曲线X系数
028        private static final BigInteger gx = new BigInteger("32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7",16);
029        //曲线Y系数
030        private static final BigInteger gy = new BigInteger("bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0",16);
031        //生产者顺序系数
032        private static final BigInteger n  = new BigInteger("fffffffeffffffffffffffffffffffff7203df6b21c6052b53bbf40939d54123",16);
033        //素数
034        private static final BigInteger p  = new BigInteger("fffffffeffffffffffffffffffffffffffffffff00000000ffffffffffffffff",16);
035        //因子系数 1
036        private static final int h  = 1;
037        /*-----------------------国密算法相关参数end-----------------------------*/
038        //一些必要类
039        public final ECFieldElement ecc_gx_fieldelement;
040        public final ECFieldElement ecc_gy_fieldelement;
041        public final ECCurve ecc_curve;
042        public final ECPoint ecc_point_g;
043        public final ECDomainParameters ecc_bc_spec;
044        public final ECKeyPairGenerator ecc_key_pair_generator;
045        /**
046         * 初始化方法
047         * @return
048         */
049        public static SM2Factory getInstance(){
050                return new SM2Factory();
051        }
052        public SM2Factory() {
053
054                this.ecc_gx_fieldelement = new Fp(this.p,this.gx);
055                this.ecc_gy_fieldelement = new Fp(this.p, this.gy);
056
057                this.ecc_curve = new ECCurve.Fp(this.p, this.a, this.b);
058
059                this.ecc_point_g = new ECPoint.Fp(this.ecc_curve, this.ecc_gx_fieldelement,this.ecc_gy_fieldelement);
060                this.ecc_bc_spec = new ECDomainParameters(this.ecc_curve, this.ecc_point_g, this.n);
061
062                ECKeyGenerationParameters ecc_ecgenparam;
063                ecc_ecgenparam = new ECKeyGenerationParameters(this.ecc_bc_spec, new SecureRandom());
064
065                this.ecc_key_pair_generator = new ECKeyPairGenerator();
066                this.ecc_key_pair_generator.init(ecc_ecgenparam);
067        }
068        /**
069         * 根据私钥、曲线参数计算Z
070         * @param userId
071         * @param userKey
072         * @return
073         */
074        public  byte[] sm2GetZ(byte[] userId, ECPoint userKey){
075                SM3Digest sm3 = new SM3Digest();
076
077                int len = userId.length * 8;
078                sm3.update((byte) (len >> 8 & 0xFF));
079                sm3.update((byte) (len & 0xFF));
080                sm3.update(userId, 0, userId.length);
081
082                byte[] p = ConvertUtil.byteConvert32Bytes(this.a);
083                sm3.update(p, 0, p.length);
084
085                p = ConvertUtil.byteConvert32Bytes(this.b);
086                sm3.update(p, 0, p.length);
087
088                p = ConvertUtil.byteConvert32Bytes(this.gx);
089                sm3.update(p, 0, p.length);
090
091                p = ConvertUtil.byteConvert32Bytes(this.gy);
092                sm3.update(p, 0, p.length);
093
094                p = ConvertUtil.byteConvert32Bytes(userKey.normalize().getXCoord().toBigInteger());
095                sm3.update(p, 0, p.length);
096
097                p = ConvertUtil.byteConvert32Bytes(userKey.normalize().getYCoord().toBigInteger());
098                sm3.update(p, 0, p.length);
099
100                byte[] md = new byte[sm3.getDigestSize()];
101                sm3.doFinal(md, 0);
102                return md;
103        }
104        /**
105         * 签名相关值计算
106         * @param md
107         * @param userD
108         * @param userKey
109         * @param sm2Result
110         */
111        public void sm2Sign(byte[] md, BigInteger userD, ECPoint userKey, SM2Result sm2Result) {
112                BigInteger e = new BigInteger(1, md);
113                BigInteger k = null;
114                ECPoint kp = null;
115                BigInteger r = null;
116                BigInteger s = null;
117                do {
118                        do {
119                                // 正式环境
120                                AsymmetricCipherKeyPair keypair = ecc_key_pair_generator.generateKeyPair();
121                                ECPrivateKeyParameters ecpriv = (ECPrivateKeyParameters) keypair.getPrivate();
122                                ECPublicKeyParameters ecpub = (ECPublicKeyParameters) keypair.getPublic();
123                                k = ecpriv.getD();
124                                kp = ecpub.getQ();
125                                //System.out.println("BigInteger:" + k + "\nECPoint:" + kp);
126
127                                //System.out.println("计算曲线点X1: "+ kp.getXCoord().toBigInteger().toString(16));
128                                //System.out.println("计算曲线点Y1: "+ kp.getYCoord().toBigInteger().toString(16));
129                                //System.out.println("");
130                                // r
131                                r = e.add(kp.getXCoord().toBigInteger());
132                                r = r.mod(this.n);
133                        } while (r.equals(BigInteger.ZERO) || r.add(k).equals(this.n)||r.toString(16).length()!=64);
134
135                        // (1 + dA)~-1
136                        BigInteger da_1 = userD.add(BigInteger.ONE);
137                        da_1 = da_1.modInverse(this.n);
138                        // s
139                        s = r.multiply(userD);
140                        s = k.subtract(s).mod(this.n);
141                        s = da_1.multiply(s).mod(this.n);
142                } while (s.equals(BigInteger.ZERO)||(s.toString(16).length()!=64));
143
144                sm2Result.r = r;
145                sm2Result.s = s;
146        }
147        /**
148         * 验签
149         * @param md sm3摘要
150         * @param userKey 根据公钥decode一个ecpoint对象
151         * @param r 没有特殊含义
152         * @param s 没有特殊含义
153         * @param sm2Result 接收参数的对象
154         */
155        public void sm2Verify(byte md[], ECPoint userKey, BigInteger r,
156                          BigInteger s, SM2Result sm2Result) {
157                sm2Result.R = null;
158                BigInteger e = new BigInteger(1, md);
159                BigInteger t = r.add(s).mod(this.n);
160                if (t.equals(BigInteger.ZERO)) {
161                        return;
162                } else {
163                        ECPoint x1y1 = ecc_point_g.multiply(sm2Result.s);
164                        //System.out.println("计算曲线点X0: "+ x1y1.normalize().getXCoord().toBigInteger().toString(16));
165                        //System.out.println("计算曲线点Y0: "+ x1y1.normalize().getYCoord().toBigInteger().toString(16));
166                        //System.out.println("");
167
168                        x1y1 = x1y1.add(userKey.multiply(t));
169                        //System.out.println("计算曲线点X1: "+ x1y1.normalize().getXCoord().toBigInteger().toString(16));
170                        //System.out.println("计算曲线点Y1: "+ x1y1.normalize().getYCoord().toBigInteger().toString(16));
171                        //System.out.println("");
172                        sm2Result.R = e.add(x1y1.normalize().getXCoord().toBigInteger()).mod(this.n);
173                        //System.out.println("R: " + sm2Result.R.toString(16));
174                        return;
175                }
176        }
177
178}