首先定义你自己的TF模型,并加载训练好的模型文件(不加载也可以)
class MyModel1:
TFmodel = MyModel()
TFmodel.load_weights('./training_checkpoint_265.h5', by_name=True)
然后定义一个PyTorch模型
(注意,这里的Pytorch模型结构必须和TF模型结构完全一样)
class MyModel2(nn.Module):
PyTorchModel = MyModel2()
然后就可以愉快的加载参数啦
import tensorflow as tf
import deepdish as dd
import numpy as np
def tr(a):#将Tensorflow的张量转换成PyTorch的张量
v = tf.convert_to_tensor(a).numpy()
# tensorflow weights to pytorch weights
if len(v.shape) == 4:
return np.ascontiguousarray(v.transpose(3,2,0,1))
elif len(v.shape) == 2:
return np.ascontiguousarray(v.transpose())
return v
TF_weights = {TFmodel.trainable_variables[i].name: TFmodel.trainable_variables[i] for i in range(0 , len(TFmodel.trainable_variables))}
model_dict = PyTorchModel.state_dict()
#这里由于我两个模型的参数名不能一一对应,所以选择这种按下标来加载的方法
trans_weights = [tr(v) for (k, v) in TF_weights.items()]
i=0
for name,param in PyTorchModel.named_parameters():
arr = trans_weights [i]
model_dict[name] = torch.Tensor(arr)
i+=1
PyTorchModel.load_state_dict(model_dict)
#如果你两个模型的参数名能够一一对应,那么可以选择这种按名字来加载的方法
trans_weights {k: tr(v) for (k, v) in TF_weights.items()}
new_pre_dict = {}
for k,v in trans_weights .items():
new_pre_dict[k] = torch.Tensor(v)
#更新
model_dict.update(new_pre_dict)
#加载
PyTorchModel.load_state_dict(model_dict)
最后,给Pytorch的各个层起名字还是挺麻烦的,可以参照以下博客
TF1.X模型转Pytorch模型可以参见这个博客
https://blog.csdn.net/weixin_42699651/article/details/88932670