|
- package main
-
- import (
- "crypto/aes"
- "crypto/cipher"
- "crypto/rand"
- "encoding/base64"
- "fmt"
- )
-
- //Encode convert a xml sequence into encrypted message
- func Encode(s string) string {
-
- d := aesEncryptMsg(random16Byte(), []byte(s), APIConfig.Appid, getAesEncryptKey())
- r := base64.StdEncoding.EncodeToString(d)
- //fmt.Println(r)
- return r
- }
-
- //Decode Decode encrypt string to xml context
- func Decode(s string) string {
-
- r, _ := base64.StdEncoding.DecodeString(s)
-
- _, raw, err := aesDecryptMsg([]byte(r), APIConfig.Appid, getAesEncryptKey())
- if err == nil {
- return string(raw)
- }
- return ""
- }
-
// encodeNetworkBytesOrder writes the low 32 bits of n into orderBytes in
// network (big-endian) byte order: the most significant byte goes to index 0.
//
// It panics if orderBytes is not exactly 4 bytes long.
func encodeNetworkBytesOrder(orderBytes []byte, n int) {
	if len(orderBytes) != 4 {
		panic("the length of orderBytes must be equal to 4")
	}
	for i := range orderBytes {
		// Index 0 takes bits 24-31, index 3 takes bits 0-7.
		shift := uint(8 * (3 - i))
		orderBytes[i] = byte(n >> shift)
	}
}
-
// decodeNetworkBytesOrder reads an integer from 4 bytes in network
// (big-endian) byte order: orderBytes[0] is the most significant byte.
//
// It panics if orderBytes is not exactly 4 bytes long.
func decodeNetworkBytesOrder(orderBytes []byte) (n int) {
	if len(orderBytes) != 4 {
		panic("the length of orderBytes must be equal to 4")
	}
	// Fold the bytes in from most to least significant.
	for _, b := range orderBytes {
		n = n<<8 | int(b)
	}
	return n
}
-
// random16Byte returns 16 bytes read from crypto/rand.
//
// It panics if the random source reports an error: the function has no error
// return, and silently handing back a zeroed or partially filled buffer would
// make the value predictable.
func random16Byte() []byte {
	token := make([]byte, 16)
	if _, err := rand.Read(token); err != nil {
		panic("random16Byte: reading from crypto/rand: " + err.Error())
	}
	return token
}
-
- //AESEncryptMsg given an xml message and 16 bytes random string
- //encryptedMsg = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + AppId]
- func aesEncryptMsg(random, rawXMLMsg []byte, AppId string, AESKey [32]byte) (encryptedMsg []byte) {
- const BLOCK_SIZE = 32 // PKCS#7
-
- buf := make([]byte, 20+len(rawXMLMsg)+len(AppId)+BLOCK_SIZE)
- plain := buf[:20]
- pad := buf[len(buf)-BLOCK_SIZE:]
-
- // 拼接
- copy(plain, random)
- encodeNetworkBytesOrder(plain[16:20], len(rawXMLMsg))
- plain = append(plain, rawXMLMsg...)
- plain = append(plain, AppId...)
-
- // PKCS#7 补位
- amountToPad := BLOCK_SIZE - len(plain)%BLOCK_SIZE
- for i := 0; i < amountToPad; i++ {
- pad[i] = byte(amountToPad)
- }
- plain = buf[:len(plain)+amountToPad]
-
- // 加密
- block, err := aes.NewCipher(AESKey[:])
- if err != nil {
- panic(err)
- }
- mode := cipher.NewCBCEncrypter(block, AESKey[:16])
- mode.CryptBlocks(plain, plain)
-
- encryptedMsg = plain
- return
- }
-
- //AESDecryptMsg given a string decode it into three parts
- // encryptedMsg = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + AppId]
- func aesDecryptMsg(encryptedMsg []byte, AppId string, AESKey [32]byte) (random, rawXMLMsg []byte, err error) {
- const BLOCK_SIZE = 32 // PKCS#7
-
- if len(encryptedMsg) < BLOCK_SIZE {
- err = fmt.Errorf("the length of encryptedMsg too short: %d", len(encryptedMsg))
- return
- }
- if len(encryptedMsg)%BLOCK_SIZE != 0 {
- err = fmt.Errorf("encryptedMsg is not a multiple of the block size, the length is %d", len(encryptedMsg))
- return
- }
-
- plain := make([]byte, len(encryptedMsg)) // len(plain) >= BLOCK_SIZE
-
- // 解密
- block, err := aes.NewCipher(AESKey[:])
- if err != nil {
- panic(err)
- }
- mode := cipher.NewCBCDecrypter(block, AESKey[:16])
- mode.CryptBlocks(plain, encryptedMsg)
-
- // PKCS#7 去除补位
- amountToPad := int(plain[len(plain)-1])
- if amountToPad < 1 || amountToPad > BLOCK_SIZE {
- err = fmt.Errorf("the amount to pad is invalid: %d", amountToPad)
- return
- }
- plain = plain[:len(plain)-amountToPad]
-
- // 反拼装
- // len(plain) == 16+4+len(rawXMLMsg)+len(AppId)
- // len(AppId) > 0
- if len(plain) <= 20 {
- err = fmt.Errorf("plain msg too short, the length is %d", len(plain))
- return
- }
- msgLen := decodeNetworkBytesOrder(plain[16:20])
- if msgLen < 0 {
- err = fmt.Errorf("invalid msg length: %d", msgLen)
- return
- }
- msgEnd := 20 + msgLen
- if len(plain) <= msgEnd {
- err = fmt.Errorf("msg length too large: %d", msgLen)
- return
- }
-
- AppIdHave := string(plain[msgEnd:])
- if AppIdHave != AppId { // crypto/subtle.ConstantTimeCompare ???
- err = fmt.Errorf("AppId mismatch, have: %s, want: %s", AppIdHave, AppId)
- return
- }
-
- random = plain[:16:16]
- rawXMLMsg = plain[20:msgEnd]
- return
- }
|