Error:
KeyError: 'Transformer/encoderblock_0\MultiHeadDotProductAttention_1/query\kernel is not a file in the archive'
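Cause:
pjoin(ROOT, ATTENTION_Q, "kernel") joins the components into a single path using the wrong separator.
pjoin() (os.path.join under an alias, judging by its behavior) uses the path separator of the operating system: a backslash (\) on Windows and a forward slash (/) on Unix and Linux. The keys stored in the weight archive use forward slashes, so a key built on Windows matches no entry.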
On Windows, pjoin(ROOT, ATTENTION_Q, "kernel") returns
'Transformer/encoderblock_0\MultiHeadDotProductAttention_1/query\kernel'
The correct key for loading is
'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel'
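The difference can be reproduced on any platform with the standard-library modules ntpath and posixpath, which are the Windows and POSIX implementations behind os.path. This is only a sketch: the values of ROOT and ATTENTION_Q are assumptions inferred from the key in the error message.

```python
import ntpath      # Windows implementation of os.path
import posixpath   # POSIX implementation of os.path

# Assumed values, inferred from the key shown in the error message.
ROOT = "Transformer/encoderblock_0"
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"

# What os.path.join produces on Windows: components are glued with "\".
windows_key = ntpath.join(ROOT, ATTENTION_Q, "kernel")

# What os.path.join produces on Unix/Linux: components are glued with "/".
posix_key = posixpath.join(ROOT, ATTENTION_Q, "kernel")

print(windows_key)
print(posix_key)
```

Only the second key consists entirely of forward slashes, which is the form the archive expects.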
Solution:
In modeling.py, replace the following lines
query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
with:
query_weight = np2th(weights[ROOT + "/" + ATTENTION_Q + "/kernel"]).view(self.hidden_size, self.hidden_size).t()
key_weight = np2th(weights[ROOT + "/" + ATTENTION_K + "/kernel"]).view(self.hidden_size, self.hidden_size).t()
value_weight = np2th(weights[ROOT + "/" + ATTENTION_V + "/kernel"]).view(self.hidden_size, self.hidden_size).t()
out_weight = np2th(weights[ROOT + "/" + ATTENTION_OUT + "/kernel"]).view(self.hidden_size, self.hidden_size).t()
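An alternative to concatenating strings by hand is to build the keys with posixpath.join, which always uses "/" regardless of the operating system. The sketch below is not taken from the original code: weight_key is a hypothetical helper, the ROOT and ATTENTION_Q values are assumed from the error message, and a plain dict stands in for the loaded weight archive.

```python
import posixpath

def weight_key(*parts):
    """Hypothetical helper: join archive key parts with '/' on every OS."""
    return posixpath.join(*parts)

# Assumed values, inferred from the key shown in the error message.
ROOT = "Transformer/encoderblock_0"
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"

# Stand-in for the loaded weight archive (a plain dict for illustration).
weights = {
    "Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel":
        "query kernel placeholder",
}

key = weight_key(ROOT, ATTENTION_Q, "kernel")
query_kernel = weights[key]   # succeeds on Windows as well as on Linux
print(key)
```

If pjoin is indeed imported as an alias of os.path.join, importing posixpath.join under the same alias would fix every lookup at once, including any other keys built with pjoin elsewhere in the file.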