RSA加密小示例

1. 前言

网上有很多关于RSA的介绍,大神阮一峰 http://www.ruanyifeng.com/blog/2013/06/rsa_algorithm_part_one.html 都写了相关的博客。 为了引出 HTTPS, 编写一个小示例。
需要注意的点:
欧几里得算法:辗转相除求最大公约数。 如果最大公约数为1, 则两数互素:
    /**
     * 欧几里得定理
     * 辗转相除求最大公约数 可用于判断互质
     * @param a 数1
     * @param b 数2
     * @return  最大公约数
     */
    public static int max(int a, int b) {
        if (a > b) {
            int tmp = a;
            a = b;
            b = tmp;
        }
        int r;
        return ( (r = b%a) == 0) ? a : max(a, r);
    }
扩展的欧几里得算法:可以求最大公约数以及 乘法逆元
    /**
     * 使用 扩展的欧几里得算法 求乘法逆元
     * ax + ny = 1
     * y = -k
     * @return x : a 对 n 的乘法逆元
     */
    private static int extendGcd(int a, int n) {
        int x2=0, x3=n, y2=1, y3=a, q, t2, t3;
        while (true) {
            if (y3 == 0)
                return 0;
            if (y3 == 1) {
                return y2 < 0 ? y2+n : y2;
            }
            q = x3 / y3;
            t2 = x2 - q * y2;
            t3 = x3 - q * y3;
            x2 = y2;
            x3 = y3;
            y2 = t2;
            y3 = t3;
        }
    }

当然要说这个欧拉定理有什么用处~~~自然是用到的人才知道好,用不到的人只知道他可以简化幂运算。 
质数 p, q
 ψ(N) = ψ(p * q) = ψ(p) * ψ(q) = (p-1)*(q-1);
任何合数都可以分解为质数的乘积的形式。


2. 只适用于本次展示的代码片段

只适用于小于128且大于0的数的编码:
import java.math.BigInteger;

public class RSA {

    /**
     * 欧几里得定理
     * 辗转相除求最大公约数 可用于判断互质
     * @param a 数1
     * @param b 数2
     * @return  最大公约数
     */
    public static int max(int a, int b) {
        if (a > b) {
            int tmp = a;
            a = b;
            b = tmp;
        }
        int r;
        return ( (r = b%a) == 0) ? a : max(a, r);
    }

    // 选取两个大素数  p, q
    // 注意别取的太小 否则求余的时候会很不精确
    // 毕竟只是演示, 只对byte起作用, 乘积大于128即可.
    // 显然, 只支持正数
    private static int p = 11;
    private static int q = 13;
    // 得到 N = p * q
    private static BigInteger N = BigInteger.valueOf(p*q);
    // 则 ψ(N) = ψ(p * q) = ψ(p) * ψ(q) = (p-1)*(q-1)  (欧拉定理)
    private static int r = (p-1)*(q-1);
    // 取任意一与 ψ(N) 互质的小于 ψ(N)的数.
    private static int e = 97;
    // 得到这个数关于 ψ(N) 的乘法逆元
    // 此时的公钥对为(e, N), 私钥对为(d, N). 其余所有数据销毁
    private static int d = extendGcd(e, r);

    /**
     * 使用 扩展的欧几里得算法 求乘法逆元
     * ax + ny = 1
     * y = -k
     * @return x : a 对 n 的乘法逆元
     */
    private static int extendGcd(int a, int n) {
        int x2=0, x3=n, y2=1, y3=a, q, t2, t3;
        while (true) {
            if (y3 == 0)
                return 0;
            if (y3 == 1) {
                return y2 < 0 ? y2+n : y2;
            }
            q = x3 / y3;
            t2 = x2 - q * y2;
            t3 = x3 - q * y3;
            x2 = y2;
            x3 = y3;
            y2 = t2;
            y3 = t3;
        }
    }

    /**
     * 因为是N的取值是在整数范围内的, 此示例使用int型作为返回方便查看
     *
     * 比如 in=2; type=3;
     * return in^type % N = 2^3 % (97*101) = 8
     *
     * @param in        被编码的值
     * @param type      秘钥: 公钥/私钥
     * @return          编码后的值
     *
     */
    private static int code(int in, int type) {
        return BigInteger.valueOf(in).pow(type).mod(N).intValue();
    }

    /**
     * 只支持正数
     * @param res   被编码数组
     * @param type  编码类型 公/私钥匙
     * @return  编码后的数组
     */
    private static byte[] code(byte[] res, int type) {
        byte[] b = new byte[res.length];
        for (int i = 0; i < res.length; i ++) {
            b[i] = (byte) code(res[i], type);
        }
        return b;
    }

    public static byte[] rsaEncode(byte[] res) {
        return code(res, e);
    }

    public static byte[] rsaDecode(byte[] res) {
        return code(res, d);
    }

    public static void main(String[] args) {

        int needs = 100;
        int needsEncode = code(needs, e);
        System.out.println("needs="+needs+", encode="+needsEncode);
        int needsDecode = code(needsEncode, d);
        System.out.println("needs="+needs+", decode="+needsDecode);

        String code = "hello";
        byte[] encode = rsaEncode(code.getBytes());
        System.out.println(new String(encode));       // 不知所云
        byte[] decode = rsaDecode(encode);
        System.out.println(new String(decode));       // 还原

    }
}

如果希望将其扩大到整个Integer范围内, 可以稍作修改实现。 注意 byte 的强转

3. JAVA中比较通用的RSA写法

import java.io.ByteArrayOutputStream;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;

import javax.crypto.Cipher;

/**
 * RSA:
 * 罗纳德·李维斯特(Ron [R]ivest)、阿迪·萨莫尔(Adi [S]hamir)和伦纳德·阿德曼(Leonard [A]dleman)
 * <p/>
 * 字符串格式的密钥在未在特殊说明情况下都为BASE64编码格式<br/>
 * 由于非对称加密速度极其缓慢,一般文件不使用它来加密而是使用对称加密,<br/>
 * 非对称加密算法可以用来对对称加密的密钥加密,这样保证密钥的安全也就保证了数据的安全
 * <p/>
 * 部分摘录
 *
 * @see http://blog.csdn.net/keda8997110/article/details/16823361
 */
public class RSAUtils {

    /**
     * 加密算法RSA
     */
    public static final String KEY_ALGORITHM = "RSA";

    /**
     * RSA最大加密明文大小
     */
    private static final int MAX_ENCRYPT_BLOCK = 117;

    /**
     * RSA最大解密密文大小
     */
    private static final int MAX_DECRYPT_BLOCK = 128;

    private static RSAPublicKey publicKey;
    private static RSAPrivateKey privateKey;


    static {
        KeyPairGenerator keyPairGen = null;
        try {
            keyPairGen = KeyPairGenerator.getInstance(KEY_ALGORITHM);
        } catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
        }
        keyPairGen.initialize(1024);
        KeyPair keyPair = keyPairGen.generateKeyPair();
        publicKey = (RSAPublicKey) keyPair.getPublic();
        privateKey = (RSAPrivateKey) keyPair.getPrivate();
    }

    /**
     * 私钥解密
     *
     * @param encryptedData 已加密数据
     * @return
     * @throws Exception
     */
    public static byte[] decode(byte[] encryptedData) throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance(KEY_ALGORITHM);
        Key privateK = privateKey;
        Cipher cipher = Cipher.getInstance(keyFactory.getAlgorithm());
        cipher.init(Cipher.DECRYPT_MODE, privateK);
        int inputLen = encryptedData.length;
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int offSet = 0;
        byte[] cache;
        // 对数据分段解密
        while (inputLen - offSet > 0) {
            if (inputLen - offSet > MAX_DECRYPT_BLOCK) {
                cache = cipher.doFinal(encryptedData, offSet, MAX_DECRYPT_BLOCK);
            } else {
                cache = cipher.doFinal(encryptedData, offSet, inputLen - offSet);
            }
            out.write(cache, 0, cache.length);
            offSet += MAX_DECRYPT_BLOCK;
        }
        byte[] decryptedData = out.toByteArray();
        out.close();
        return decryptedData;
    }

    /**
     * 公钥加密
     *
     * @param data 源数据
     * @return
     * @throws Exception
     */
    public static byte[] encode(byte[] data)
            throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance(KEY_ALGORITHM);
        Key publicK = publicKey;
        // 对数据加密
        Cipher cipher = Cipher.getInstance(keyFactory.getAlgorithm());
        cipher.init(Cipher.ENCRYPT_MODE, publicK);
        int inputLen = data.length;
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int offSet = 0;
        byte[] cache;
        // 对数据分段加密
        while (inputLen - offSet > 0) {
            if (inputLen - offSet > MAX_ENCRYPT_BLOCK) {
                cache = cipher.doFinal(data, offSet, MAX_ENCRYPT_BLOCK);
            } else {
                cache = cipher.doFinal(data, offSet, inputLen - offSet);
            }
            out.write(cache, 0, cache.length);
            offSet += MAX_ENCRYPT_BLOCK;
        }
        byte[] encryptedData = out.toByteArray();
        out.close();
        return encryptedData;
    }

    public static void main(String[] args) throws Exception {
        String source = "china中国";
        byte[] encodedData = RSAUtils.encode(source.getBytes());

        System.out.println("encode:\t" + new String(encodedData));      // 不知所云
        byte[] decodedData = RSAUtils.decode(encodedData);
        System.out.println("decode: \t" + new String(decodedData));      // 成功解码
    }

}
这里面也可以微微看出java对字符编码的一些小问题。 有的时候前面的 "encode:\t"是显示不出来的
造成这样的原因是: jdk将byte转化为char[] 是委托给  Charsets.jar里面的各种charset来完成的。

UTF_8.java:
        public int decode(byte[] sa, int sp, int len, char[] da) {
            final int sl = sp + len;
            int dp = 0;
            int dlASCII = Math.min(len, da.length);
            ByteBuffer bb = null;  // only necessary if malformed

            // ASCII only optimized loop
            while (dp < dlASCII && sa[sp] >= 0)
                da[dp++] = (char) sa[sp++];

            while (sp < sl) {
                int b1 = sa[sp++];
                if (b1 >= 0) {
                    // 1 byte, 7 bits: 0xxxxxxx
                    da[dp++] = (char) b1;
                } else if ((b1 >> 5) == -2) {
                    // 2 bytes, 11 bits: 110xxxxx 10xxxxxx
                    if (sp < sl) {
                        int b2 = sa[sp++];
                        if (isMalformed2(b1, b2)) {
                            if (malformedInputAction() != CodingErrorAction.REPLACE)
                                return -1;
                            da[dp++] = replacement().charAt(0);
                            sp--;            // malformedN(bb, 2) always returns 1
                        } else {
                            da[dp++] = (char) (((b1 << 6) ^ b2)^
                                           (((byte) 0xC0 << 6) ^
                                            ((byte) 0x80 << 0)));
                        }
                        continue;
                    }
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    return dp;
                } else if ((b1 >> 4) == -2) {
                    // 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx
                    if (sp + 1 < sl) {
                        int b2 = sa[sp++];
                        int b3 = sa[sp++];
                        if (isMalformed3(b1, b2, b3)) {
                            if (malformedInputAction() != CodingErrorAction.REPLACE)
                                return -1;
                            da[dp++] = replacement().charAt(0);
                            sp -=3;
                            bb = getByteBuffer(bb, sa, sp);
                            sp += malformedN(bb, 3).length();
                        } else {
                            da[dp++] = (char)((b1 << 12) ^
                                              (b2 <<  6) ^
                                              (b3 ^
                                              (((byte) 0xE0 << 12) ^
                                              ((byte) 0x80 <<  6) ^
                                              ((byte) 0x80 <<  0))));
                        }
                        continue;
                    }
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    return dp;
                } else if ((b1 >> 3) == -2) {
                    // 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
                    if (sp + 2 < sl) {
                        int b2 = sa[sp++];
                        int b3 = sa[sp++];
                        int b4 = sa[sp++];
                        int uc = ((b1 << 18) ^
                                  (b2 << 12) ^
                                  (b3 <<  6) ^
                                  (b4 ^
                                   (((byte) 0xF0 << 18) ^
                                   ((byte) 0x80 << 12) ^
                                   ((byte) 0x80 <<  6) ^
                                   ((byte) 0x80 <<  0))));
                        if (isMalformed4(b2, b3, b4) ||
                            // shortest form check
                            !Character.isSupplementaryCodePoint(uc)) {
                            if (malformedInputAction() != CodingErrorAction.REPLACE)
                                return -1;
                            da[dp++] = replacement().charAt(0);
                            sp -= 4;
                            bb = getByteBuffer(bb, sa, sp);
                            sp += malformedN(bb, 4).length();
                        } else {
                            da[dp++] = Character.highSurrogate(uc);
                            da[dp++] = Character.lowSurrogate(uc);
                        }
                        continue;
                    }
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    return dp;
                } else {
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    sp--;
                    bb = getByteBuffer(bb, sa, sp);
                    CoderResult cr = malformedN(bb, 1);
                    if (!cr.isError()) {
                        // leading byte for 5 or 6-byte, but don't have enough
                        // bytes in buffer to check. Consumed rest as malformed.
                        return dp;
                    }
                    sp +=  cr.length();
                }
            }
            return dp;
        }
    }
其中期作用的地方就在将byte转为char那个while循环里面。
而char的编码,就必须遵从于Unicode的标准了。
因此看起来就会觉得怪怪的

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值