深度学习-tensorflow-代码

该博客介绍了如何在TensorFlow中加载预训练模型并固定部分网络参数进行微调。首先,通过`tf.contrib.framework.get_variables_to_restore()`获取变量,并筛选出'easyflow'部分的变量进行恢复。接着,定义AdamOptimizer进行训练,同时使用控制依赖来确保更新操作在梯度下降之前执行。此外,还展示了张量操作如切片、分割、拼接和堆叠等。
摘要由CSDN通过智能技术生成

加载预训练模型

https://cloud.tencent.com/developer/article/1197031

加载部分预训练模型

variables = tf.contrib.framework.get_variables_to_restore()
variables_to_restore = [v for v in variables if v.name.split('/')[0] == 'easyflow']
saver_res = tf.train.Saver(variables_to_restore)
saver_res.restore(sess, pre_train_model)

固定部分网络参数

variables = slim.get_variables_to_restore()
variables_to_retore = [v for v in variables if v.name.split('/')[0] == 'easyflow']
variables_to_train = [v.name for v in variables if v.name.split('/')[0] == 'netflow']
print(variables_to_retore)
### Defind optimizer
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
loss_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=variables_to_train)

with tf.control_dependencies(update_ops):
    Training_step2 = tf.compat.v1.train.AdamOptimizer(lr_ori).minimize(OptimizeLoss_2, var_list=loss_vars)
saver_res = tf.compat.v1.train.Saver(variables_to_retore)
saver_res.restore(sess, pre_train_model)
varvar = sess.graph.get_tensor_by_name('easyflow/c1/weights:0')

张量操作:

tensor张量的打印,

tf.slice(input,begin,size),

tf.split(input,num_or_size_split,axis=0,num=None)

tf.concat(input,axis)

tf.stack(input,axis=0)

tf.unstack(input,num=None,axis=0)

http://chenjingjiu.cn/index.php/2019/07/05/tensorflow-matrix-op/

tf.transpose()张量维度转置

https://blog.csdn.net/qq_40994943/article/details/85270159

tf.split()把一个张量划分成几个子张量
https://blog.csdn.net/qq_31150463/article/details/84137883

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值