实现思路
整体只有加密、解密两个操作,因此抽象出一个接口,结合门面(Facade)模式和策略(Strategy)模式对外开放加解密接口。由于密码比对场景的使用频率较高,再额外开放一个比较接口,该接口至少需要两个参数:明文密码和密文密码。详见代码:
// ICrypto is the strategy interface implemented by every encryption algorithm.
// The trailing variadic arg lets callers pass algorithm-specific values such as a key.
type ICrypto interface {
	// Encrypt encrypts *dest and returns the ciphertext.
	Encrypt(dest *string, arg ...interface{}) (*string, error)
	// Decrypt decrypts *dest and returns the plaintext.
	Decrypt(dest *string, arg ...interface{}) (*string, error)
	// Compare reports whether the plaintext deStr corresponds to the ciphertext enStr.
	Compare(deStr, enStr *string, arg ...interface{}) bool
}
`arg` 为可变参数:若第一个元素是 `cryptoType` 类型(`RSA` 或 `AES_CBC`),门面函数用它选择加密方式(默认 `AES_CBC`);其余参数原样传给具体实现,供调用方传入自定义密钥。
创建一个结构体,用于存储公钥私钥等信息:
// commCrypto holds the key material shared by all crypto strategies.
type commCrypto struct {
	// RSA key pair used by the RSA strategy.
	rsaPri *rsa.PrivateKey
	rsaPub *rsa.PublicKey
	// rsaPubStr is the encoded public key returned by GetRSAPublicKey;
	// aesKey is the key used by the AES strategy.
	rsaPubStr, aesKey *string
}
Talk is cheap, show me the code.
crypto.go
完整代码
package crypto
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"io/ioutil"
	"log"
	"sync"

	"config"
)
func init() {
Setup()
}
func Setup() {
if nil != myCrypto {
return
}
locker.Lock()
defer locker.Unlock()
if nil != myCrypto {
return
}
log.Println("load crypto start------------------------")
myCrypto = &commCrypto{}
var pri, pub []byte
var err error
loadRsa:
if config.CommonConfig.Rsa != nil &&
len(config.CommonConfig.Rsa.PrivateKeyFilePath) > 0 &&
len(config.CommonConfig.Rsa.PublicKeyFilePath) > 0 {
log.Printf("load rsa private file: %s \n", config.CommonConfig.Rsa.PrivateKeyFilePath)
// 读取文件
pri, err = ioutil.ReadFile(config.CommonConfig.Rsa.PrivateKeyFilePath)
if nil != err {
log.Printf("load rsa private file err: %s\n", err.Error())
config.CommonConfig.Rsa = nil
goto loadRsa
}
log.Printf("load rsa public file: %s\n", config.CommonConfig.Rsa.PublicKeyFilePath)
// 读取文件
pri, err = ioutil.ReadFile(config.CommonConfig.Rsa.PublicKeyFilePath)
if err != nil {
log.Printf("load rsa public file err: %s\n", err.Error())
config.CommonConfig.Rsa = nil
goto loadRsa
}
} else {
log.Printf("use default rsa config")
pri = []byte(defaultPrivateKeyStr)
pub = []byte(defaultPublicKeyStr)
}
if privateKey, err := readRSAPrivateKey(pri); err == nil {
myCrypto.rsaPri = privateKey
} else {
log.Printf("read rsa private err: %s\n", err.Error())
if config.CommonConfig.Rsa != nil &&
len(config.CommonConfig.Rsa.PrivateKeyFilePath) > 0 &&
len(config.CommonConfig.Rsa.PublicKeyFilePath) > 0 {
config.CommonConfig.Rsa = nil
goto loadRsa
}
goto loadAes
}
if publicKey, publicKeyStr, err := readRSAPublicKey(pub); err == nil {
myCrypto.rsaPub = publicKey
myCrypto.rsaPubStr = publicKeyStr
} else {
log.Printf("read rsa public err: %s\n", err.Error())
if config.CommonConfig.Rsa != nil &&
len(config.CommonConfig.Rsa.PrivateKeyFilePath) > 0 &&
len(config.CommonConfig.Rsa.PublicKeyFilePath) > 0 {
config.CommonConfig.Rsa = nil
goto loadRsa
}
goto loadAes
}
loadAes:
if config.CommonConfig.Aes == nil || len(config.CommonConfig.Aes.Key) <= 0 {
log.Printf("use default aes config")
myCrypto.aesKey = &defaultAesKeyStr
} else {
log.Printf("load aes config")
keyLenMap := map[int]struct{}{16: {}, 24: {}, 32: {}}
if _, ok := keyLenMap[len(config.CommonConfig.Aes.Key)]; !ok {
log.Printf("key长度必须是 16、24、32 其中一个")
config.CommonConfig.Aes = nil
goto loadAes
}
myCrypto.aesKey = &config.CommonConfig.Aes.Key
}
log.Printf("load crypto end------------------------")
}
// cryptoType selects which ICrypto strategy the facade functions dispatch to.
type cryptoType int

var (
	// myCrypto holds the key material loaded by Setup; nil until Setup has run.
	myCrypto *commCrypto
	// locker serialises Setup so the keys are loaded only once.
	locker sync.Mutex
	// cryptoMap maps each cryptoType to its strategy; the strategies register
	// themselves from their init functions.
	cryptoMap = map[cryptoType]ICrypto{}
)

// ICrypto is the strategy interface implemented by every encryption algorithm.
// The trailing variadic arg carries algorithm-specific values such as a key.
type ICrypto interface {
	// Encrypt encrypts *dest and returns the ciphertext.
	Encrypt(dest *string, arg ...interface{}) (*string, error)
	// Decrypt decrypts *dest and returns the plaintext.
	Decrypt(dest *string, arg ...interface{}) (*string, error)
	// Compare reports whether the plaintext deStr corresponds to the ciphertext enStr.
	Compare(deStr, enStr *string, arg ...interface{}) bool
}

// commCrypto holds the key material shared by all strategies.
type commCrypto struct {
	// RSA key pair used by the RSA strategy.
	rsaPri *rsa.PrivateKey
	rsaPub *rsa.PublicKey
	// rsaPubStr is the encoded public key returned by GetRSAPublicKey;
	// aesKey is the key used by the AES strategy.
	rsaPubStr, aesKey *string
}
// Decrypt decrypts *dest with the strategy selected by arg and returns the plaintext.
//
// If arg[0] is a cryptoType (RSA or AES_CBC), it selects the strategy and is
// not forwarded; otherwise defAes (AES_CBC) is used. The remaining arguments are
// passed to the strategy unchanged, e.g. a custom key. An error is returned for
// a nil dest or a cryptoType with no registered strategy.
func Decrypt(dest *string, arg ...interface{}) (*string, error) {
	if dest == nil {
		return nil, errors.New("crypto: dest is nil")
	}
	t, rest := resolveCryptoType(arg)
	c, err := lookupCrypto(t)
	if err != nil {
		return nil, err
	}
	return c.Decrypt(dest, rest...)
}

// Encrypt encrypts *dest with the strategy selected by arg and returns the
// ciphertext. Strategy selection and error cases are the same as for Decrypt.
func Encrypt(dest *string, arg ...interface{}) (*string, error) {
	if dest == nil {
		return nil, errors.New("crypto: dest is nil")
	}
	t, rest := resolveCryptoType(arg)
	c, err := lookupCrypto(t)
	if err != nil {
		return nil, err
	}
	return c.Encrypt(dest, rest...)
}

// Compare reports whether the plaintext deStr corresponds to the ciphertext
// enStr, using the strategy selected by arg (AES_CBC by default).
// It returns false for nil inputs, for an unregistered cryptoType, and when
// both arguments are the same pointer.
func Compare(deStr, enStr *string, arg ...interface{}) bool {
	if nil == deStr || nil == enStr {
		return false
	}
	// The same pointer passed as both plaintext and ciphertext is treated as a mismatch.
	if deStr == enStr {
		return false
	}
	t, rest := resolveCryptoType(arg)
	c, err := lookupCrypto(t)
	if err != nil {
		return false
	}
	return c.Compare(deStr, enStr, rest...)
}

// resolveCryptoType picks the strategy from the facade arguments. A leading
// cryptoType is consumed; otherwise defAes is used and arg is returned unchanged.
func resolveCryptoType(arg []interface{}) (cryptoType, []interface{}) {
	if len(arg) > 0 {
		if t, ok := arg[0].(cryptoType); ok {
			return t, arg[1:]
		}
	}
	return defAes, arg
}

// lookupCrypto returns the strategy registered for t. It returns an error
// instead of a nil interface, so callers never invoke a method on nil.
func lookupCrypto(t cryptoType) (ICrypto, error) {
	c, ok := cryptoMap[t]
	if !ok || c == nil {
		return nil, errors.New("crypto: unsupported crypto type")
	}
	return c, nil
}
// readRSAPrivateKey 读取私钥(包含PKCS1和PKCS8)
func readRSAPrivateKey(pri []byte) (*rsa.PrivateKey, error) {
var err error
// pem解码
pemBlock, _ := pem.Decode(pri)
var pkixPrivateKey interface{}
if pemBlock.Type == "RSA PRIVATE KEY" {
pkixPrivateKey, err = x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
} else if pemBlock.Type == "PRIVATE KEY" {
pkixPrivateKey, err = x509.ParsePKCS8PrivateKey(pemBlock.Bytes)
}
if err != nil {
return nil, err
}
return pkixPrivateKey.(*rsa.PrivateKey), err
}
// readRSAPublicKey 读取公钥(包含PKCS1和PKCS8)
func readRSAPublicKey(pub []byte) (*rsa.PublicKey, *string, error) {
var err error
// 使用pem解码
pemBlock, _ := pem.Decode(pub)
var pkixPublicKey interface{}
if pemBlock.Type == "RSA PUBLIC KEY" {
// -----BEGIN RSA PUBLIC KEY-----
pkixPublicKey, err = x509.ParsePKCS1PublicKey(pemBlock.Bytes)
} else if pemBlock.Type == "PUBLIC KEY" {
// -----BEGIN PUBLIC KEY-----
pkixPublicKey, err = x509.ParsePKIXPublicKey(pemBlock.Bytes)
}
if err != nil {
return nil, nil, err
}
publicKeyPem := pem.EncodeToMemory(pemBlock)
publicKeyStr := base64.StdEncoding.EncodeToString(publicKeyPem)
publicKey := pkixPublicKey.(*rsa.PublicKey)
return publicKey, &publicKeyStr, nil
}
// GetRSAPublicKey returns the public key string stored by Setup: the base64
// encoding of the PEM-encoded RSA public key. It runs Setup first if no key
// material has been loaded yet. The result is nil if the public key could not
// be loaded.
func GetRSAPublicKey() *string {
	if myCrypto != nil {
		return myCrypto.rsaPubStr
	}
	Setup()
	return myCrypto.rsaPubStr
}
// Crypto types accepted as the optional first argument of Encrypt, Decrypt and Compare.
const (
	// RSA selects the RSA strategy (PKCS#1 v1.5, base64-encoded ciphertext).
	RSA cryptoType = 0
	// AES_CBC selects the AES strategy (CBC mode, PKCS#7 padding, base64-encoded ciphertext).
	AES_CBC cryptoType = 1
	// defAes is the strategy used when no crypto type is given.
	defAes = AES_CBC
)
var (
// defaultPrivateKeyStr is the built-in demo RSA private key in PKCS#8 PEM form
// ("-----BEGIN PRIVATE KEY-----"). Setup falls back to it when no RSA key files
// are configured or they cannot be loaded.
// Generated with:
//   PKCS#1: openssl genrsa -out private_ssl.pem 1024
//   PKCS#8: openssl pkcs8 -topk8 -inform pem -in private_ssl.pem -outform PEM -nocrypt
// NOTE(review): this key ships in source code and is only 1024 bits long;
// it must never protect real data.
defaultPrivateKeyStr = `-----BEGIN PRIVATE KEY-----
MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBALarNy0rEC1qwy8P
9KuH6Ir3Ny4iyzOb+mWjzD6jaSOyzHLeKYqw4o0F4qrz78Yd2vZOjwj2mbN8xG2e
tnJKAykYTAMi6RysXDlT8O3BLRKJctLE7jmKvsc1gbLY2/0rx2yhAW77eiOJ+oZ+
R4Is7oouvwvGgEmBAy+dpZqrsJ0bAgMBAAECgYBIUMX6Orcf08lLo1xwX9Ce2znc
KOgbGV/qxwq/rX+dI1avDuaRQm+d9ruChCnjW8RoiDc2DDJTDUzSPOfrnQNcoS9a
3o54x86bHhzpeHyuIS5b6m1xDt6yUN4rzKv1MUT6vD85jJ4/sWrhUMH6q5rDA0Td
/KZrKBzdjvSgcZY4iQJBAN/4gOcZbFi7q8ZWqO02vtY9RLQAqD/dfTK7T2u3TMoK
Ly/XBfV46XGYIQuf8cWoSArQvIhF9LiM67/pNO88UgcCQQDQyqmTCE1tPPc922U1
JqPDMF+xdgOMsg3G5aP6XNLnfrPnQS9zbp6dK00rL0zv0ekAV3rKgNzTCgl0gxsC
1EdNAkEArTfGifVRpHbQ7T6Mu5nBknQkNIrllS856wiO7iH/06p4wCkuxKDU+zPL
KvByzonN3f9+MG7aS/lBQ1WbyQL/9QJAW47zq8FxSpz4gsyp5hPqrlaRMB3jUphm
CDl9nfWEmvVp6Ngh+cmhjqSFc5GLeIMhXP//nbvCml0FZm1zs9ro5QJBAI0eCzS4
CK0hBSiGRPHdfV7A8mPBjUip1MxI6cYP6JJkB40t3SNzRHWMG/ulmko36XkZxYpm
QCjSLUYx14GYr/U=
-----END PRIVATE KEY-----`
// defaultPublicKeyStr is the PKIX public key ("-----BEGIN PUBLIC KEY-----")
// derived from defaultPrivateKeyStr; Setup uses the two together.
// Generated with: openssl rsa -in private_ssl.pem -pubout -out public_ssl.pem
defaultPublicKeyStr = `-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC2qzctKxAtasMvD/Srh+iK9zcu
Isszm/plo8w+o2kjssxy3imKsOKNBeKq8+/GHdr2To8I9pmzfMRtnrZySgMpGEwD
IukcrFw5U/DtwS0SiXLSxO45ir7HNYGy2Nv9K8dsoQFu+3ojifqGfkeCLO6KLr8L
xoBJgQMvnaWaq7CdGwIDAQAB
-----END PUBLIC KEY-----`
// defaultAesKeyStr is the built-in 16-byte (AES-128) key. Setup uses it when no
// valid AES key is configured.
// NOTE(review): this is a public demo key; do not use it in production.
defaultAesKeyStr = `k+J6AyS/+adJU1/J`
)
`config` 包用于读取对外开放的配置文件,配置文件格式如下:
common:
  rsa:
    publicKeyFilePath: ""
    privateKeyFilePath: ""
  aes:
    key: ""
RSA
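RSA 私钥有 PKCS1 和 PKCS8 两种格式。下方用 OpenSSL 1.1.x 的 `genrsa` 生成的私钥为 PKCS1 格式(`-----BEGIN RSA PRIVATE KEY-----`;OpenSSL 3.x 默认输出 PKCS8)。`openssl rsa -pubout` 导出的公钥为 PKIX(X.509 SubjectPublicKeyInfo)格式(`-----BEGIN PUBLIC KEY-----`)。
如有需要,可按下文命令将 PKCS1 私钥转换为 PKCS8 格式,代码对两种私钥格式都能解析。
此密钥仅供演示,切勿用于生产环境;生产环境建议至少使用 2048 位密钥。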
~/root# openssl genrsa -out private_ssl.pem 1024
Generating RSA private key, 1024 bit long modulus (2 primes)
...............+++++
.................................+++++
e is 65537 (0x010001)
~/root# ls
README.md inscode.nix private_ssl.pem
~/Markdown# cat private_ssl.pem
-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQC2qzctKxAtasMvD/Srh+iK9zcuIsszm/plo8w+o2kjssxy3imK
sOKNBeKq8+/GHdr2To8I9pmzfMRtnrZySgMpGEwDIukcrFw5U/DtwS0SiXLSxO45
ir7HNYGy2Nv9K8dsoQFu+3ojifqGfkeCLO6KLr8LxoBJgQMvnaWaq7CdGwIDAQAB
AoGASFDF+jq3H9PJS6NccF/Qnts53CjoGxlf6scKv61/nSNWrw7mkUJvnfa7goQp
41vEaIg3NgwyUw1M0jzn650DXKEvWt6OeMfOmx4c6Xh8riEuW+ptcQ7eslDeK8yr
9TFE+rw/OYyeP7Fq4VDB+quawwNE3fymaygc3Y70oHGWOIkCQQDf+IDnGWxYu6vG
VqjtNr7WPUS0AKg/3X0yu09rt0zKCi8v1wX1eOlxmCELn/HFqEgK0LyIRfS4jOu/
6TTvPFIHAkEA0MqpkwhNbTz3PdtlNSajwzBfsXYDjLINxuWj+lzS536z50Evc26e
nStNKy9M79HpAFd6yoDc0woJdIMbAtRHTQJBAK03xon1UaR20O0+jLuZwZJ0JDSK
5ZUvOesIju4h/9OqeMApLsSg1Pszyyrwcs6Jzd3/fjBu2kv5QUNVm8kC//UCQFuO
86vBcUqc+ILMqeYT6q5WkTAd41KYZgg5fZ31hJr1aejYIfnJoY6khXORi3iDIVz/
/527wppdBWZtc7Pa6OUCQQCNHgs0uAitIQUohkTx3X1ewPJjwY1IqdTMSOnGD+iS
ZAeNLd0jc0R1jBv7pZpKN+l5GcWKZkAo0i1GMdeBmK/1
-----END RSA PRIVATE KEY-----
~/Markdown# openssl rsa -in private_ssl.pem -pubout -out public_ssl.pem
writing RSA key
~/Markdown# cat public_ssl.pem
-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC2qzctKxAtasMvD/Srh+iK9zcu
Isszm/plo8w+o2kjssxy3imKsOKNBeKq8+/GHdr2To8I9pmzfMRtnrZySgMpGEwD
IukcrFw5U/DtwS0SiXLSxO45ir7HNYGy2Nv9K8dsoQFu+3ojifqGfkeCLO6KLr8L
xoBJgQMvnaWaq7CdGwIDAQAB
-----END PUBLIC KEY-----
~/root# openssl pkcs8 -topk8 -inform pem -in private_ssl.pem -outform PEM -nocrypt
-----BEGIN PRIVATE KEY-----
MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBALarNy0rEC1qwy8P
9KuH6Ir3Ny4iyzOb+mWjzD6jaSOyzHLeKYqw4o0F4qrz78Yd2vZOjwj2mbN8xG2e
tnJKAykYTAMi6RysXDlT8O3BLRKJctLE7jmKvsc1gbLY2/0rx2yhAW77eiOJ+oZ+
R4Is7oouvwvGgEmBAy+dpZqrsJ0bAgMBAAECgYBIUMX6Orcf08lLo1xwX9Ce2znc
KOgbGV/qxwq/rX+dI1avDuaRQm+d9ruChCnjW8RoiDc2DDJTDUzSPOfrnQNcoS9a
3o54x86bHhzpeHyuIS5b6m1xDt6yUN4rzKv1MUT6vD85jJ4/sWrhUMH6q5rDA0Td
/KZrKBzdjvSgcZY4iQJBAN/4gOcZbFi7q8ZWqO02vtY9RLQAqD/dfTK7T2u3TMoK
Ly/XBfV46XGYIQuf8cWoSArQvIhF9LiM67/pNO88UgcCQQDQyqmTCE1tPPc922U1
JqPDMF+xdgOMsg3G5aP6XNLnfrPnQS9zbp6dK00rL0zv0ekAV3rKgNzTCgl0gxsC
1EdNAkEArTfGifVRpHbQ7T6Mu5nBknQkNIrllS856wiO7iH/06p4wCkuxKDU+zPL
KvByzonN3f9+MG7aS/lBQ1WbyQL/9QJAW47zq8FxSpz4gsyp5hPqrlaRMB3jUphm
CDl9nfWEmvVp6Ngh+cmhjqSFc5GLeIMhXP//nbvCml0FZm1zs9ro5QJBAI0eCzS4
CK0hBSiGRPHdfV7A8mPBjUip1MxI6cYP6JJkB40t3SNzRHWMG/ulmko36XkZxYpm
QCjSLUYx14GYr/U=
-----END PRIVATE KEY-----
~/root#
src_crypto.go
代码如下:
package crypto
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"errors"
log "logger"
)
// rsaCrypto is the ICrypto strategy for RSA (PKCS#1 v1.5, base64-encoded
// ciphertext). It has no state: keys come from the call arguments or myCrypto.
type rsaCrypto struct {
}

// rsaSingleton is the shared instance registered under RSA. It is a real
// instance rather than the nil *rsaCrypto that new(*rsaCrypto) used to yield,
// so methods stay safe if fields are ever added.
var rsaSingleton = &rsaCrypto{}

// init registers the RSA strategy so the package-level facade can dispatch to it.
func init() {
	log.Infof("init RSA crypto")
	cryptoMap[RSA] = rsaSingleton
}
// Decrypt base64-decodes *dest, decrypts it with RSA PKCS#1 v1.5 and returns the plaintext.
//
// arg may carry a *rsa.PrivateKey as its first element. A first element of any
// other type is rejected with "arg invalid". Without a key, the private key
// loaded by Setup is used. Errors are returned for a nil dest, malformed base64,
// a missing private key, or a failed decryption.
//
// NOTE(review): PKCS#1 v1.5 decryption of attacker-supplied ciphertext can act
// as a padding oracle if error details reach the caller. RSA-OAEP is safer but
// would change the format clients must produce.
func (r *rsaCrypto) Decrypt(dest *string, arg ...interface{}) (*string, error) {
	if dest == nil {
		return nil, errors.New("rsa decrypt: dest is nil")
	}
	var pri *rsa.PrivateKey
	if len(arg) > 0 {
		key, ok := arg[0].(*rsa.PrivateKey)
		if !ok {
			return nil, errors.New("arg invalid")
		}
		pri = key
	}
	if pri == nil {
		pri = myCrypto.rsaPri
	}
	if pri == nil {
		return nil, errors.New("rsa decrypt: private key not loaded")
	}
	// The decode error must be checked; it used to be silently overwritten.
	cipherBytes, err := base64.StdEncoding.DecodeString(*dest)
	if err != nil {
		return nil, err
	}
	plain, err := rsa.DecryptPKCS1v15(rand.Reader, pri, cipherBytes)
	if err != nil {
		return nil, err
	}
	deStr := string(plain)
	return &deStr, nil
}
// Encrypt encrypts *dest with RSA PKCS#1 v1.5 and returns the base64-encoded ciphertext.
//
// arg may carry a *rsa.PublicKey as its first element. Unlike Decrypt, a first
// element of another type is ignored and the public key loaded by Setup is used.
// Errors are returned for a nil dest, a missing public key, or when
// rsa.EncryptPKCS1v15 fails (e.g. the message is too long for the key).
func (r *rsaCrypto) Encrypt(dest *string, arg ...interface{}) (*string, error) {
	if dest == nil {
		return nil, errors.New("rsa encrypt: dest is nil")
	}
	var pub *rsa.PublicKey
	if len(arg) > 0 {
		if key, ok := arg[0].(*rsa.PublicKey); ok {
			pub = key
		}
	}
	if pub == nil {
		pub = myCrypto.rsaPub
	}
	if pub == nil {
		return nil, errors.New("rsa encrypt: public key not loaded")
	}
	encryptPKCS1v15, err := rsa.EncryptPKCS1v15(rand.Reader, pub, []byte(*dest))
	if err != nil {
		return nil, err
	}
	encryptString := base64.StdEncoding.EncodeToString(encryptPKCS1v15)
	return &encryptString, nil
}
// Compare decrypts the ciphertext enStr with Decrypt (same arg handling) and
// reports whether the result equals the plaintext deStr. Nil inputs and
// decryption failures yield false.
func (r *rsaCrypto) Compare(deStr, enStr *string, arg ...interface{}) bool {
	if deStr == nil || enStr == nil {
		return false
	}
	decrypt, err := r.Decrypt(enStr, arg...)
	if err != nil {
		log.Info("compare rsa err: ", err.Error())
		return false
	}
	return *decrypt == *deStr
}
AES
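AES 常见的工作模式有 ECB、CBC、CFB、OFB、CTR 五种,以下实现采用 CBC 模式(PKCS7 填充)。注意:示例直接使用密钥的前 16 字节作为 IV,因此同一密钥下相同明文总是得到相同密文(`Compare` 正是依赖这一点)。这会泄露"明文是否相同"的信息,生产环境应谨慎使用。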
aes_crypto.go
代码如下:
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"errors"
log "logger"
)
// aesCBCCrypto is the ICrypto strategy for AES in CBC mode with PKCS#7 padding
// and base64-encoded ciphertext. It has no state: keys come from the call
// arguments or myCrypto.
type aesCBCCrypto struct {
}

// aesCBCSingleton is the shared instance registered under AES_CBC. It is a real
// instance rather than the nil *aesCBCCrypto that new(*aesCBCCrypto) used to
// yield, so methods stay safe if fields are ever added.
var aesCBCSingleton = &aesCBCCrypto{}

// init registers the AES-CBC strategy so the package-level facade can dispatch to it.
func init() {
	log.Infof("init AES crypto")
	cryptoMap[AES_CBC] = aesCBCSingleton
}
// Decrypt base64-decodes *dest, decrypts it with AES-CBC, strips the PKCS#7
// padding and returns the plaintext.
//
// arg may carry a *string key as its first element; otherwise the key loaded by
// Setup is used. The first block (16 bytes) of the key doubles as the IV,
// matching Encrypt. Errors are returned for:
//   - a nil dest or missing key,
//   - malformed base64,
//   - an invalid key size,
//   - ciphertext that is empty or not a multiple of the block size,
//   - invalid padding (e.g. wrong key or tampered data).
//
// NOTE(review): a fixed key-derived IV makes encryption deterministic
// (aesCBCCrypto.Compare relies on this) but leaks plaintext equality.
func (r *aesCBCCrypto) Decrypt(dest *string, arg ...interface{}) (*string, error) {
	if dest == nil {
		return nil, errors.New("aes decrypt: dest is nil")
	}
	var key *string
	if len(arg) > 0 {
		if k, ok := arg[0].(*string); ok {
			key = k
		}
	}
	if key == nil {
		key = myCrypto.aesKey
		if key == nil || len(*key) <= 0 {
			return nil, errors.New("aesKey not found")
		}
	}
	cipherText, err := base64.StdEncoding.DecodeString(*dest)
	if err != nil {
		return nil, err
	}
	keyByte := []byte(*key)
	block, err := aes.NewCipher(keyByte)
	if err != nil {
		return nil, err
	}
	blockSize := block.BlockSize()
	// CryptBlocks panics on partial blocks, and unpadding needs at least one byte.
	if len(cipherText) == 0 || len(cipherText)%blockSize != 0 {
		return nil, errors.New("aes decrypt: ciphertext is not a positive multiple of the block size")
	}
	// The output must be as long as the input: CryptBlocks panics if dst is shorter than src.
	plain := make([]byte, len(cipherText))
	cipher.NewCBCDecrypter(block, keyByte[:blockSize]).CryptBlocks(plain, cipherText)
	// Validate PKCS#7 padding before stripping it, so bad input yields an error, not a panic.
	padLen := int(plain[len(plain)-1])
	if padLen == 0 || padLen > blockSize {
		return nil, errors.New("aes decrypt: invalid padding")
	}
	for _, b := range plain[len(plain)-padLen:] {
		if int(b) != padLen {
			return nil, errors.New("aes decrypt: invalid padding")
		}
	}
	deStr := string(doPKCS7UNPadding(plain))
	return &deStr, nil
}
// Encrypt pads *dest with PKCS#7, encrypts it with AES-CBC and returns the
// base64-encoded ciphertext.
//
// arg may carry a *string key as its first element. A first element of another
// type is ignored and the key loaded by Setup is used. The first block (16
// bytes) of the key is used as the IV, so equal plaintexts under the same key
// always give equal ciphertexts. aesCBCCrypto.Compare relies on this.
// Errors are returned for a nil dest, a missing key, or an invalid key size.
func (r *aesCBCCrypto) Encrypt(dest *string, arg ...interface{}) (*string, error) {
	if dest == nil {
		return nil, errors.New("aes encrypt: dest is nil")
	}
	var key *string
	if len(arg) > 0 {
		if k, ok := arg[0].(*string); ok {
			key = k
		}
	}
	if key == nil {
		key = myCrypto.aesKey
		if key == nil || len(*key) <= 0 {
			return nil, errors.New("aesKey not found")
		}
	}
	keyByte := []byte(*key)
	block, err := aes.NewCipher(keyByte)
	if err != nil {
		return nil, err
	}
	blockSize := block.BlockSize()
	padded := doPKCS7Padding([]byte(*dest), blockSize)
	encrypted := make([]byte, len(padded))
	cipher.NewCBCEncrypter(block, keyByte[:blockSize]).CryptBlocks(encrypted, padded)
	enStr := base64.StdEncoding.EncodeToString(encrypted)
	return &enStr, nil
}
// Compare encrypts the plaintext deStr with Encrypt (same arg handling) and
// reports whether the result equals the ciphertext enStr. This only works
// because Encrypt is deterministic. Nil inputs and encryption failures yield false.
func (r *aesCBCCrypto) Compare(deStr, enStr *string, arg ...interface{}) bool {
	if deStr == nil || enStr == nil {
		return false
	}
	encrypt, err := r.Encrypt(deStr, arg...)
	if err != nil {
		// Errorf needs a verb for its argument; the message also named the wrong algorithm.
		log.Errorf("compare aes err: %s", err.Error())
		return false
	}
	return *encrypt == *enStr
}
// doPKCS7Padding appends PKCS#7 padding to originByte so its length becomes a
// multiple of blockSize. Between 1 and blockSize bytes are always added, each
// holding the pad length; an already-aligned input gains one full block.
// The result may share originByte's backing array.
func doPKCS7Padding(originByte []byte, blockSize int) []byte {
	padLen := blockSize - len(originByte)%blockSize
	return append(originByte, bytes.Repeat([]byte{byte(padLen)}, padLen)...)
}
// doPKCS7UNPadding removes PKCS#7 padding from originDataByte and returns the
// unpadded prefix, which shares its backing array with the input.
// It returns nil instead of panicking when the input is empty or the padding is
// malformed: the pad length is 0, exceeds the data length, or the pad bytes
// are not all equal to it.
func doPKCS7UNPadding(originDataByte []byte) []byte {
	length := len(originDataByte)
	if length == 0 {
		return nil
	}
	unpadding := int(originDataByte[length-1])
	if unpadding == 0 || unpadding > length {
		return nil
	}
	for _, b := range originDataByte[length-unpadding:] {
		if int(b) != unpadding {
			return nil
		}
	}
	return originDataByte[:length-unpadding]
}
测试用例
crypto_test.go
package crypto
import "testing"
// TestRSA round-trips "test" through RSA encryption and decryption via the facade.
func TestRSA(t *testing.T) {
	plain := "test"
	cipherText, err := Encrypt(&plain, RSA)
	if err != nil {
		t.Fatalf("Unexpected error: %s", err.Error())
	}
	t.Logf("encrypt: %s", *cipherText)
	roundTrip, err := Decrypt(cipherText, RSA)
	if err != nil {
		t.Fatalf("Unexpected error: %s", err.Error())
	}
	if *roundTrip == plain {
		t.Logf("dncrypt: %s", *roundTrip)
		return
	}
	t.Errorf("Expected %s, but got %s", plain, *roundTrip)
}
// TestAES round-trips several plaintexts through the default AES-CBC strategy.
// The empty string, an exact block and a multi-block input cover padding edge
// cases and ciphertexts longer than a single block.
func TestAES(t *testing.T) {
	inputs := []string{"test", "", "0123456789abcdef", "a plaintext that spans more than two AES blocks"}
	for _, dest := range inputs {
		dest := dest
		encrypt, err := Encrypt(&dest)
		if err != nil {
			t.Errorf("Unexpected error: %s", err.Error())
			continue
		}
		t.Logf("encrypt: %s", *encrypt)
		decrypt, err := Decrypt(encrypt)
		if err != nil {
			t.Errorf("Unexpected error: %s", err.Error())
			continue
		}
		if *decrypt != dest {
			t.Errorf("Expected %s, but got %s", dest, *decrypt)
		} else {
			t.Logf("dncrypt: %s", *decrypt)
		}
	}
}
// TestCompare checks Compare across strategies. A ciphertext must only match
// its plaintext when compared with the same strategy that produced it.
func TestCompare(t *testing.T) {
	deStr := "test"
	aesEncrypted, err := Encrypt(&deStr)
	if err != nil {
		t.Errorf("Unexpected error: %s", err.Error())
		return
	}
	rsaEncrypted, err := Encrypt(&deStr, RSA)
	if err != nil {
		t.Errorf("Unexpected error: %s", err.Error())
		return
	}
	cases := []struct {
		name  string
		enStr *string
		args  []interface{}
		want  bool
	}{
		{"AES Encrypt AES compare", aesEncrypted, nil, true},
		{"AES Encrypt RSA compare", aesEncrypted, []interface{}{RSA}, false},
		{"RSA Encrypt AES compare", rsaEncrypted, nil, false},
		{"RSA Encrypt RSA compare", rsaEncrypted, []interface{}{RSA}, true},
	}
	for _, c := range cases {
		got := Compare(&deStr, c.enStr, c.args...)
		t.Logf("%s: %v", c.name, got)
		if got != c.want {
			t.Errorf("%s: expected %v, but got %v", c.name, c.want, got)
		}
	}
}