参考链接:https://blog.csdn.net/weixin_42699651/article/details/88932670
一. tensorflow模型转pytorch模型
import tensorflow as tf
import deepdish as dd
import argparse
import os
import numpy as np
def tr(v):
# tensorflow weights to pytorch weights
if v.ndim == 4:
return np.ascontiguousarray(v.transpose(3,2,0,1))
elif v.ndim == 2:
return np.ascontiguousarray(v.transpose())
return v
def read_ckpt(ckpt):
# https://github.com/tensorflow/tensorflow/issues/1823
reader = tf.train.NewCheckpointReader(ckpt)
weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()}
pyweights = {k: tr(v) for (k, v) in weights.items()}
return pyweights
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
parser.add_argument("infile", type=str,
help="Path to the ckpt.") # ***model.ckpt-22177***
parser.add_argument("outfile", type=str, nargs='?', default='',
help="Output file (inferred if missing).")
args = parser.parse_args()
if args.outfile == '':
args.outfile = os.path.splitext(args.infile)[0] + '.h5'
outdir = os.path.dirname(args.outfile)
if not os.path.exists(outdir):
os.makedirs(outdir)
weights = read_ckpt(args.infile)
dd.io.save(args.outfile, weights)
1.运行上述代码后会得到model.h5模型,如下:
备注:保持tensorflow和pytorch使用的python版本一致
2.使用:在pytorch内加载改模型:
这里假设网络保存时参数命名一致
net = ...
import torch
import deepdish as dd
net = resnet50(..)
model_dict = net.state_dict()
#先将参数值numpy转换为tensor形式
pretrained_dict = = dd.io.load('./model.h5')
new_pre_dict = {}
for k,v in pretrained_dict.items():
new_pre_dict[k] = torch.Tensor(v)
#更新
model_dict.update(new_pre_dict)
#加载
net.load_state_dict(model_dict)
二. 安装过程中遇到的问题如下
1.无法安装deepdish包,原因是没装HDF5
/usr/bin/ld: 找不到 -lhdf5
collect2: error: ld returned 1 exit status
* Using Python 3.5.2 (default, Nov 23 2017, 16:37:01)
* USE_PKGCONFIG: True
.. ERROR:: Could not find a local HDF5 installation.
You may need to explicitly state where your local HDF5 headers and
library can be found by setting the ``HDF5_DIR`` environment
variable or by using the ``--hdf5`` command-line option.
----------------------------------------
Command "python setup.py egg_info" failed with error code 1 in /tmp/pip-install-sosb3hw4/tables/
HDF5安装步骤如下(参考链接)
ubuntu版本:16.04.2 64位
从HDF官网(https://support.hdfgroup.org/HDF5/)上下载hdf5-1.8.17.tar.gz 。。。。直接在界面搜索比较快
简要安装步骤如下:(详细步骤:hdf5-1.8.17/release_docs/INSTALL)
$ gunzip < hdf5-X.Y.Z.tar.gz | tar xf - #解压缩
$ cd hdf5-X.Y.Z
$ ./configure --prefix=/usr/local/hdf5 #安装路径
$ make
$ make check # run test suite.
$ make install
$ make check-install # verify installation.
注:
1)X.Y.Z是HDF版本
2)建议自己新建一个文件夹作为安装路径,默认情况下。。我并没有找到后续操作需要用到的文件
3)如果提示权限不够就在前面加个 sudo 吧
安装成功后,在安装目录/usr/local下出现hdf5文件夹,打开后
在/share/hdf5_examples/下是示例程序。打开c文件夹,下面我们来测试。该文件夹下有个名为run-c-ex.sh文件,执行该文件可以将c目录下所有.c文件执行。
运行命令 $ sudo ./run-c-ex.sh
即可得到所有文件的执行结果。
安装好HDF5以后,再到终端 pip install deepdish 即可成功