tensorflow 交替训练

tf.compat.v1.Session

定义Session:

sess = tf.compat.v1.Session()

tf.compat.v1.get_collection

  1. 获取构建的静态图网络中的所有变量list:
variables =tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
for i in variables:
	print(i)
  1. 获取指定scope中的所有变量list:
train_var = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='output')
for i in train_var:
	print(i)

tf.compat.v1.name_scope

(1)在某个tf.compat.v1.name_scope()指定的区域中定义的所有对象及各种操作,他们的“name”属性上会增加该命名区的区域名,用以区别对象属于哪个区域;
(2)将不同的对象及操作放在由tf.name_scope()指定的区域中,便于在tensorboard中展示清晰的逻辑关系图,这点在复杂关系图中特别重要。

with tf.compat.v1.name_scope("generator"):
	sample = tf.compat.v1.Variable(tf.random.normal([self.n, self.m]), name='sample')
	......

tf.compat.v1.variable_scope

Tensorflow提供了Variable_Scope机制来共享变量。

with tf.compat.v1.variable_scope("discriminate", reuse = tf.compat.v1.AUTO_REUSE):
	wo = tf.compat.v1.get_variable(shape = [dims[i], dims[i + 1]],name='w_discriminate')
	......

tf.compat.v1.train.Optimizer

Tensorflow的Optimizer类。

opt = tf.compat.v1.train.RMSPropOptimizer(lr, decay)

opt.minimize

训练目标,通常为最小化loss。

global_step = tf.compat.v1.Variable(0, trainable=False,name = 'global_step')
train_step = opt.minimize(net.loss,global_step=global_step,var_list=train_vars)

var_list,顾名思义,即训练时,要训练的变量列表list,这可以通过tf.compat.v1.get_collection获得,并且因为list的特性,list = list1 + list2

tf.compat.v1.global_variables_initializer

全局变量初始化,Tensorflow只能使用初始化后的变量进行训练。

sess.run(tf.compat.v1.global_variables_initializer())

sess.run(train_step, feed_dict={…})

梯度更新。

sess.run(train_step, feed_dict={...})

示例代码

''' Session '''
sess = tf.compat.v1.Session()

''' Define Net '''
net = Net()

''' Set up val_list '''
first_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='first')
second_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='second')
all_vars = first_vars + second_vars

''' Set up optimizer '''
opt = tf.compat.v1.train.RMSPropOptimizer(lr, decay)

''' Minimize '''
global_first = tf.compat.v1.Variable(0, trainable=False,name = 'global_first')
train_first = opt.minimize(net.loss,global_step=global_first,var_list=first_vars)

global_second = tf.compat.v1.Variable(0, trainable=False,name = 'global_second')
train_second = opt.minimize(net.loss,global_step=global_second,var_list=second_vars)

global_all = tf.compat.v1.Variable(0, trainable=False,name = 'global_all')
train_all = opt.minimize(net.loss,global_step=global_all,var_list=all_vars)

''' Initialize TensorFlow variables '''
sess.run(tf.compat.v1.global_variables_initializer())

''' Optimize '''
sess.run(train_first, feed_dict={...})
sess.run(train_second, feed_dict={...})
sess.run(train_all, feed_dict={...})
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值