TensorFlow Pytorch 模型读取相关
Tensorflow
Tensorflow载入部分权重
基本步骤
- 为layer命名,加入name=‘layer_name’
- model.load_weights(’…h5’, by_name=True)
适用方法
- 网络中间的某些层不载入权重
model.layers[id].name = 'layer_name_2' # 重命名不想载入权重的层
model.load_weights('weights.h5', by_name = True)
- 删除了网络中的某些层(剪枝)后,想用之前训练的大网络的权重继续训练,可以直接调用load_weights方法
说明:在不同TF版本之间载入模型,也可以调用该方法
参考资料:
Tensorflow模型转Pytorch
- 将Tensorflow模型转换为.h5格式
- 在pytorch中构建一个参数命名一致的模型pyModel
- 调用如下代码
net = ...
import torch
import deepdish as dd
net = pyModel(..)
model_dict = net.state_dict()
# 先将参数值numpy转换为tensor形式
pretrained_dict = = dd.io.load('./model.h5')
new_pre_dict = {}
for k,v in pretrained_dict.items():
new_pre_dict[k] = torch.Tensor(v)
model_dict.update(new_pre_dict) # 更新
net.load_state_dict(model_dict) # 加载
参考资料:
H5文件
HDF5结构:File - Group - DataSet
说明:
- Group可以嵌套Group,提供的visit()和visititems()可以递归group。
- 可以用HDFView打开文件
参考资料: