cannot import name 'ReparseException' & extracting TF model parameters

TensorBoard ImportError: cannot import name 'ReparseException'


Environment:

  • Python 3.5.4
  • Tensorflow==1.4.1
  • html5lib==1.0.1

Error message:

ImportError: cannot import name 'ReparseException'

The import that fails is in:

~/anaconda3/envs/tf3/lib/python3.5/site-packages/bleach/sanitizer.py

Open that file and locate this import:

from html5lib.constants import (
    entities,
    ReparseException,
    namespaces,
    prefixes,
    tokenTypes,
)

The offending name is in this import; changing ReparseException to _ReparseException resolves the problem.
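The same edit can be scripted instead of made by hand. The following is only a minimal sketch: the helper names patch_source and patch_file are hypothetical, and it assumes the only change needed is the rename described above. It renames every standalone occurrence of the name in the file, so any later use stays consistent with the import, and it keeps a .bak copy of the original.

```python
import re

# Match the public name only. The lookbehind skips occurrences that already
# carry the leading underscore or that sit inside a longer identifier.
_OLD_NAME = re.compile(r"(?<!\w)ReparseException\b")


def patch_source(source):
    """Return `source` with ReparseException renamed to _ReparseException."""
    return _OLD_NAME.sub("_ReparseException", source)


def patch_file(path):
    """Patch the file at `path` in place; return True if it was changed.

    The original content is saved next to it as `path + ".bak"`.
    """
    with open(path, encoding="utf-8") as f:
        original = f.read()
    patched = patch_source(original)
    if patched == original:
        return False  # nothing to do (for example, already patched)
    with open(path + ".bak", "w", encoding="utf-8") as f:
        f.write(original)
    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)
    return True
```

With the path from this post, the call would be patch_file(os.path.expanduser("~/anaconda3/envs/tf3/lib/python3.5/site-packages/bleach/sanitizer.py")), after importing os. Because the change lives inside site-packages, reinstalling or upgrading bleach will overwrite it.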

Extracting TF model parameters


Models saved with the TensorFlow Saver come in two formats, V1 and V2, as the tf.train.Saver documentation explains:

  • write_version: controls what format to use when saving checkpoints. It also affects certain filepath matching logic. The V2 format is the recommended choice: it is much more optimized than V1 in terms of memory required and latency incurred during restore. Regardless of this flag, the Saver is able to restore from both V2 and V1 checkpoints.

In my own comparison, however, a model that took 500 MB in V1 needed 1.5 GB in V2.

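When actually extracting parameters, turning the model's variables into NumPy arrays that can be manipulated directly is rather troublesome. Most answers found by searching say to build the graph first and then call eval inside a running session. (Is there really no way to get the parameters straight from the model?)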
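One way to read the values without building a graph or opening a session is TensorFlow's checkpoint reader, tf.train.NewCheckpointReader, which is part of the TF 1.x API. The sketch below has not been run against the exact setup in this post, and the helper names open_reader and read_all_variables are hypothetical.

```python
def open_reader(checkpoint_prefix):
    """Open a checkpoint by its prefix (the path without a file extension)."""
    # Imported here so that read_all_variables below has no TensorFlow dependency.
    import tensorflow as tf
    return tf.train.NewCheckpointReader(checkpoint_prefix)


def read_all_variables(reader):
    """Return {variable name: array} for every tensor stored in the checkpoint."""
    names = sorted(reader.get_variable_to_shape_map())
    # get_tensor returns the stored value as a NumPy array.
    return {name: reader.get_tensor(name) for name in names}
```

A call would look like read_all_variables(open_reader("train_log/alexnet-dorefa/model-xxx")), where the prefix is a placeholder to be replaced with a real checkpoint path. The names stored in a checkpoint normally have no ":0" tensor suffix.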

For V2 models, the tensorpack package can extract the parameters quickly and write them out in npz format.

Installation

tensorpack
git must be installed first; then run on the command line:

 pip install -U git+https://github.com/ppwwyyxx/tensorpack.git   

Script

The script that extracts the model parameters is ./scripts/dump-model-params.py in the repository.

Example usage:

./scripts/dump-model-params.py --meta train_log/alexnet-dorefa/graphxxxx.meta train_log/alexnet-dorefa/model-xxx output.npz
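Once output.npz exists, it can be inspected with NumPy alone. This is a minimal sketch; the helper names load_npz and describe_params are hypothetical, and the file name is taken from the command above.

```python
def load_npz(path):
    """Load every array in an .npz archive into a plain dict."""
    # Imported here so that describe_params below has no NumPy dependency.
    import numpy as np
    with np.load(path) as data:
        return {name: data[name] for name in data.files}


def describe_params(params):
    """Return (name, shape) pairs, sorted by name, for a name -> array mapping."""
    return [(name, tuple(params[name].shape)) for name in sorted(params)]
```

For example, looping over describe_params(load_npz("output.npz")) and printing each pair lists every saved variable with its shape, which is a quick check that the dump contains what was expected.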

Source code

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import numpy as np
import six
import argparse
import os
import tensorflow as tf

from tensorpack.tfutils import varmanip
from tensorpack.tfutils.common import get_op_tensor_name

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Keep only TRAINABLE and MODEL variables in a checkpoint.')
    parser.add_argument('--meta', help='metagraph file', required=True)
    parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint')
    parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
    args = parser.parse_args()

    # this script does not need GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    tf.train.import_meta_graph(args.meta, clear_devices=True)

    # loading...
    if args.input.endswith('.npz'):
        dic = np.load(args.input)
    else:
        dic = varmanip.load_chkpt_vars(args.input)
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    # save variables that are GLOBAL, and either TRAINABLE or MODEL
    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
    globvarname = [k.name for k in tf.global_variables()]
    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])

    for name in var_to_dump:
        assert name in dic, "Variable {} not found in the model!".format(name)

    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
    varmanip.save_chkpt_vars(dic_to_dump, args.output)