将TF2.0 Tensorflow2.X 模型参数 转换成PyTorch模型的参数

首先定义你自己的TF模型,并加载训练好的模型文件(不加载也可以)

class MyModel1:

TFmodel = MyModel()
TFmodel.load_weights('./training_checkpoint_265.h5', by_name=True)

 

然后定义一个PyTorch模型

(注意,这里的Pytorch模型结构必须和TF模型结构完全一样)

class MyModel2(nn.Module):

PyTorchModel = MyModel2()

然后就可以愉快的加载参数啦

import tensorflow as tf
import deepdish as dd
import numpy as np



def tr(a):#将Tensorflow的张量转换成PyTorch的张量

    v = tf.convert_to_tensor(a).numpy()
    # tensorflow weights to pytorch weights
    if len(v.shape) == 4:
        return np.ascontiguousarray(v.transpose(3,2,0,1))
    elif len(v.shape) == 2:
        return np.ascontiguousarray(v.transpose())
    return v

TF_weights = {TFmodel.trainable_variables[i].name: TFmodel.trainable_variables[i] for i in range(0 , len(TFmodel.trainable_variables))}

model_dict = PyTorchModel.state_dict()

#这里由于我两个模型的参数名不能一一对应,所以选择这种按下标来加载的方法
trans_weights = [tr(v) for (k, v) in TF_weights.items()]
i=0
for name,param in PyTorchModel.named_parameters():
    arr = trans_weights [i]
    model_dict[name] = torch.Tensor(arr)
    i+=1
PyTorchModel.load_state_dict(model_dict)




#如果你两个模型的参数名能够一一对应,那么可以选择这种按名字来加载的方法
trans_weights {k: tr(v) for (k, v) in TF_weights.items()}

new_pre_dict = {}
for k,v in trans_weights .items():
    new_pre_dict[k] = torch.Tensor(v)
#更新
model_dict.update(new_pre_dict)
#加载
PyTorchModel.load_state_dict(model_dict)


最后,给Pytorch的各个层起名字还是挺麻烦的,可以参照以下博客

https://stackoverflow.com/questions/66152766/how-to-assign-a-name-for-a-pytorch-layer/66162559#66162559

 

TF1.X模型转Pytorch模型可以参见这个博客

https://blog.csdn.net/weixin_42699651/article/details/88932670

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值