tensorflow从与训练网络模型中fine-tune部分网络层参数

方法1:
1,开启tf的训练,有ckpt生成时停止,使用以下语句获得相关层变量的全称:

var_names=tf.contrib.framework.list_variables("/Users/kylefan/hobotrl/log/AutoExposureDPGVGG/model.ckpt-0")

2,手工制作一个ckpt文件:挨个对上一步中的变量赋值,然后tf.saver….保存下来这个新的ckpt,代替掉上一步的ckpt,并且修改checkpoint这个文件里的路径

    with tf.Session() as sess:
        i=0
        for var_name, _ in tf.contrib.framework.list_variables("/Users/kylefan/hobotrl/log/AutoExposureDPGVGG/model.ckpt-0"):
            # Load the variable
            i += 1
            # if i < 10:
            if var_name.startswith('learn/se'):
                if not (var_name.endswith('Adam') or var_name.endswith('Adam_1')):
                    value_npz = ckpt2npz_name(var_name) # translate var_name from npz to ckpt to get corresponding value
                    var = tf.Variable(weights[value_npz], name=var_name)

        # Save the variables
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        checkpoint_path = "/Users/kylefan/hobotrl/log/AutoExposureDPGVGG/from_npz/model.ckpt-0"
        # print checkpoint.model_checkpoint_path
        saver.save(sess, checkpoint_path)

3,指定restore_var_list,保证tf不至于因为提供的参数过少报错,并且打印出fine-tune的变量值以确保正确:

restore_var_list = []
        for var in tf.global_variables():
            if ('learn/se' in var.name) and ('Adam' not in var.name):
                restore_var_list.append(var)
        with agent.create_session(config=config, save_dir=args.logdir, restore_var_list=restore_var_list) as sess:
            all_vars = tf.global_variables()
            with open(args.logdir + "/weight_fine_tuned.txt", "w") as f:
                for var in all_vars:
                    f.write("{}\n".format(var.name))
                    var_value = sess.run(var)
                    f.write("{}\n\n".format(var_value))

方法2:
若有下载好的ckpt文件,从ckpt文件中直接fine-tune

the .ckpt file is the old version output of saver.save(sess), which is the equivalent of your .ckpt-data (see below)

the “checkpoint” file is only here to tell some TF functions which is the latest checkpoint file.

.ckpt-meta contains the metagraph, i.e. the structure of your computation graph, without the values of the variables (basically what you can see in tensorboard/graph).

.ckpt-data contains the values for all the variables, without the structure. To restore a model in python, you’ll usually use the meta and data files with (but you can also use the .pb file):

saver = tf.train.import_meta_graph(path_to_ckpt_meta)
saver.restore(sess, path_to_ckpt_data)
I don’t know exactly for .ckpt-index, I guess it’s some kind of index needed internally to map the two previous files correctly. Anyway it’s not really necessary usually, you can restore a model with only .ckpt-meta and .ckpt-data.

the .pb file can save your whole graph (meta + data). To load and use (but not train) a graph in c++ you’ll usually use it, created with freeze_graph, which creates the .pb file from the meta and data. Be careful, (at least in previous TF versions and for some people) the py function provided by freeze_graph did not work properly, so you’d have to use the script version. Tensorflow also provides a tf.train.Saver.to_proto() method, but I don’t know what it does exactly.

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值