tensorfow模型转pytorch

参考链接:https://blog.csdn.net/weixin_42699651/article/details/88932670

一. tensorflow模型转pytorch模型

import tensorflow as tf
import deepdish as dd
import argparse
import os
import numpy as np

def tr(v):
    # tensorflow weights to pytorch weights
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3,2,0,1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v

def read_ckpt(ckpt):
    # https://github.com/tensorflow/tensorflow/issues/1823
    reader = tf.train.NewCheckpointReader(ckpt)
    weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()}
    pyweights = {k: tr(v) for (k, v) in weights.items()}
    return pyweights
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
    parser.add_argument("infile", type=str,
                        help="Path to the ckpt.")  # ***model.ckpt-22177***
    parser.add_argument("outfile", type=str, nargs='?', default='',
                        help="Output file (inferred if missing).")
    args = parser.parse_args()
    if args.outfile == '':
        args.outfile = os.path.splitext(args.infile)[0] + '.h5'
    outdir = os.path.dirname(args.outfile)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)
    dd.io.save(args.outfile, weights)

1.运行上述代码后会得到model.h5模型,如下:
备注:保持tensorflow和pytorch使用的python版本一致

2.使用:在pytorch内加载改模型:
这里假设网络保存时参数命名一致

net = ...
import torch
import deepdish as dd
net = resnet50(..)
model_dict = net.state_dict()
#先将参数值numpy转换为tensor形式
pretrained_dict =  = dd.io.load('./model.h5')
new_pre_dict = {}
for k,v in pretrained_dict.items():
    new_pre_dict[k] = torch.Tensor(v)
#更新
model_dict.update(new_pre_dict)
#加载
net.load_state_dict(model_dict)

 

二. 安装过程中遇到的问题如下

1.无法安装deepdish包,原因是没装HDF5

 /usr/bin/ld: 找不到 -lhdf5
    collect2: error: ld returned 1 exit status
    * Using Python 3.5.2 (default, Nov 23 2017, 16:37:01)
    * USE_PKGCONFIG: True
    .. ERROR:: Could not find a local HDF5 installation.
       You may need to explicitly state where your local HDF5 headers and
       library can be found by setting the ``HDF5_DIR`` environment
       variable or by using the ``--hdf5`` command-line option.

    ----------------------------------------
Command "python setup.py egg_info" failed with error code 1 in /tmp/pip-install-sosb3hw4/tables/

HDF5安装步骤如下(参考链接)

ubuntu版本:16.04.2   64位

从HDF官网(https://support.hdfgroup.org/HDF5/)上下载hdf5-1.8.17.tar.gz 。。。。直接在界面搜索比较快

简要安装步骤如下:(详细步骤:hdf5-1.8.17/release_docs/INSTALL)

            $ gunzip < hdf5-X.Y.Z.tar.gz | tar xf -   #解压缩
            $ cd hdf5-X.Y.Z
            $ ./configure --prefix=/usr/local/hdf5  #安装路径
            $ make
            $ make check                # run test suite.
            $ make install
            $ make check-install        # verify installation.

注:

1)X.Y.Z是HDF版本

2)建议自己新建一个文件夹作为安装路径,默认情况下。。我并没有找到后续操作需要用到的文件

3)如果提示权限不够就在前面加个 sudo 吧

安装成功后,在安装目录/usr/local下出现hdf5文件夹,打开后

在/share/hdf5_examples/下是示例程序。打开c文件夹,下面我们来测试。该文件夹下有个名为run-c-ex.sh文件,执行该文件可以将c目录下所有.c文件执行。
运行命令  $ sudo ./run-c-ex.sh

即可得到所有文件的执行结果。

安装好HDF5以后,再到终端 pip install deepdish 即可成功

 

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>