#声明:本文创作内容含代码均为个人创作所得,允许学习、传阅,不得用于商业用途#
#本文包含国密SM3从算法到硬件实现的全部#
#读者务必认真理解代码实现过程,而不知简单的复制粘贴#
一 背景:
SM3是一种杂凑(hash)算法的计算方法,是国家密码公开算法标准,与国际上SHA-256相对应。其主要适用于商用密码应用中的数字签名和验证、消息认证码的生成与验证以及随机数的生成。
二:算法说明:
建议先直接下载“SM3密码杂凑算法”一文阅读
1. 简述
对长度为L(L<2^64)比特的消息m,经过杂凑算法填充和迭代压缩,生成一个256比特的数据。
2. 填充
填充的目的是为了保证消息比特长度(也就是0和1的总个数)是512的倍数,此举是为了满足后面迭代要求。首先将比特“1”,添加到消息的尾部,然后再添加K个比特“0”,然后再添加64个比特串,比特串是消息长度L的二进制表示,比如消息长度是24个比特,那么添加的64个比特串的二进制表示为 00…011000(十进制是24)。其中数字K满足一下等式,L+1+K+64 = 0mod 512, 也就是512的整数倍。
上图是来源于国密pdf
3. 迭代
将上面填充后的消息按照每512比特进行分组,B0,B1,B2,… , B(n-1)
迭代方式:上一轮计算结果hash值和本轮的B(i)运算得到本轮的hash值, 用V来表示。伪代码如下:
V(0)是256比特的初始值,固定且已知,B(i)就是上面的填充分组后的消息,最终结果为V(n)
那么这个迭代压缩函数具体是怎样的形式呢?且往下看
3.1 消息扩展
第一步:
也就是W0~W15就是分组消息的前16个字,然后用这16个字迭代出W16~W67,然后再用W0~W67迭代出W`0~W`63
3.2: 压缩函数CV
初始化常量IV = 7380166f 4914b2b9 172442d7 da8a0600 a96f30bc 163138aa e38dee4d b0fb0e4e,
由于编辑器打字不太方便,所以我就用截图的方式贴在这里:
说明:因为计算一组数据需要经过64次迭代,所以上面的循环变量为0~63,为了增加复杂度,在不同的迭代时间使用不同的变量或函数;
迭代过程:
迭代过程先用常量赋值给ABCDEFGH,经过64轮迭代,每轮迭代会更新中间变量TT1,TT2,SS1,SS2;64轮迭代结束后,ABCDEFGH的值就是本轮V的值,该值用于下一次迭代;
算法总结:将输入的数据先进行填充得到N个512比特的数据,然后经过N轮迭代压缩运算,每一轮的具体是先将输入的512比特划分为16个字,用这16个字经过运算获得一些中间变量值,然后用这些中间变量值以及上一次轮的hash值经过64轮的逻辑运算,就可以得到最终的hash值;
三 算法代码
因为python库和IDE比较完善,同时python也是目前很火的一种编程语言,所以我们采用python来实现。该代码均在Pycharm中实现
python是基于字符串来处理数据的,用python来进行数据位域计算是很恶心的,大家一定要注意闭坑啊,太坑了python。
不熟悉python的小伙伴一定要自己写一遍,否则这里的坑以后很难规避的
对于软件输入的数据格式有两种,一种是字符,比如‘abc’,另外一种是16进制
python代码如下:
class SM3_PRE_PROCESS:
def str2byte(self, string):
hex_array = []
string_encode = string.encode()
for char in string:
hex_array.append(int(ord(char)))
return hex_array
def num2byte(self, string):
hex_array = []
string_len = len(string)
if (string_len % 2):
string = '0' + string
#string_encode = string.encode()
for i in range(int(string_len / 2)):
hex_array.append(int(string[i * 2:i * 2 + 2], 16))
return hex_array
def padding(self, string, num):
if (num == 0):
data = self.str2byte(string)
bit_len = len(string) * 8 # each num is 8 bit
else:
data = self.num2byte(string)
bit_len = len(data) * 8 # each num is 8 bit
#bit_len = len(string) * 8 # each num is 8 bit
string_0_len = 512 - (bit_len + 1 + 64) % 512
string_0 = '1' + ('0' * string_0_len) # internal bin
string_64_bit = '0' * (64 - (len(bin(bit_len)) - 2)) + bin(bit_len)[2:] # last 64bit
padding_bit = string_0 + string_64_bit
for i in range(int(len(padding_bit) / 8)): # message to 16 word(32bit)
data.append(int(padding_bit[8 * i:8 * (i + 1)], 2))
return data
class SM3_CALCU:
IV = '7380166f4914b2b9172442d7da8a0600a96f30bc163138aae38dee4db0fb0e4e'
V0 = []
for i in range(8):
V0.append(
(int(IV[8*i:8*(i+1)], 16)) & 0xFFFFFFFF) # transform each 8byte to int, so we can calculate with int
def rotate_left(self,a,k):
k = k % 32
return ((a << k) & 0xFFFFFFFF) | ((a & 0xFFFFFFFF) >> (32-k))
def T(self, i):
if (i <= 15):
return int('79cc4519', 16) & 0xFFFFFFFF
else:
return int('7a879d8a', 16) & 0xFFFFFFFF
def FF(self, X, Y, Z, j):
if 0 <= j <= 15:
return X ^ Y ^ Z
else:
return (X & Y) | (X & Z) | (Y & Z)
def GG(self, X, Y, Z, j):
if 0 <= j <= 15:
return X ^ Y ^ Z
else:
return (X & Y) | (~X & Z)
def P0(self, X):
return X ^ self.rotate_left(X,9) ^ self.rotate_left(X,17)
def P1(self, X):
return X ^ self.rotate_left(X,15) ^ self.rotate_left(X,23)
def info_expand(self, info):
''' 16 word total 16*32= 512 bit namely 64 byte '''
W = [0] * 68
W_ = [0] * 64
word_list = [0]*16
for i in range(16):
# tmp = (hex(info[4*i]))[2:] + hex(info[4*i+1])[2:] + hex(info[4*i+2])[2:] + hex(info[4*i+3])[2:] # big hole
tmp = ('{:02x}'.format(info[4*i])) + ('{:02x}'.format(info[4*i+1])) + ('{:02x}'.format(info[4*i+2])) + ('{:02x}'.format(info[4*i+3]))
word_list[i] = int(tmp,16) & 0xFFFFFFFF
W[0:16] = word_list
for j in range(16, 68):
xx = (self.P1(W[j - 16] ^ W[j - 9] ^ (self.rotate_left(W[j - 3], 15)))) & 0xFFFFFFFF
W[j] = (self.P1(W[j - 16] ^ W[j - 9] ^ self.rotate_left(W[j - 3] ,15)) ^ self.rotate_left(W[j - 13] , 7) ^ W[j - 6]) & 0xFFFFFFFF
for j in range(64):
W_[j] = W[j] ^ W[j + 4]
return (W, W_)
def iteration(self, msg, V):
(W, W_) = self.info_expand(msg)
A, B, C, D, E, F, G, H = V
for j in range(64):
SS1 = (self.rotate_left(self.rotate_left(A, 12) + E + self.rotate_left(self.T(j), j), 7)) & 0xFFFFFFFF
SS2 = SS1 ^ self.rotate_left(A, 12)
TT1 = (self.FF(A, B, C, j) + D + SS2 + W_[j]) & 0xFFFFFFFF
TT2 = (self.GG(E, F, G, j) + H + SS1 + W[j]) & 0xFFFFFFFF
D = C
C = self.rotate_left(B ,9)
B = A
A = TT1
H = G
G = self.rotate_left(F , 19)
F = E
E = self.P0(TT2)
return [A, B, C, D, E, F, G, H]
def loop_cal(self, x):
loop_len = int(len(x) / 64) # how much 512 bit
V = [[0]] * (loop_len + 1)
V[0] = self.V0
for i in range(loop_len):
[A, B, C, D, E, F, G, H] = self.iteration(x[64*i:64*(i+1)], V[i])
tmp = V[i]
V[i+1] = [A ^ tmp[0], B ^ tmp[1], C ^ tmp[2], D ^ tmp[3], E ^ tmp[4], F ^ tmp[5], G ^ tmp[6], H ^ tmp[7]]
return V[loop_len]
if __name__ == '__main__':
get_class = SM3_PRE_PROCESS()
# data = get_class.padding('abc', 0) # message is char
# data = get_class.padding('abcd' * 16, 0) # message is char
data = get_class.padding('61626364'*16, 1)
cal_class = SM3_CALCU()
data_hash = cal_class.loop_cal(data)
for i in range(len(data_hash)):
print(hex(data_hash[i]))
官网文档里面,给出了迭代过程的中间数据,其中有个错误,截图如下:
四 Verilog硬件实现
《最近时间忙,后续稍后补充》