tf.keras与 TensorFlow混用,trainable=False设置无效

目录

 

一个简单的例子

另外一个例子:来自KerasLayer trainable=false seems to have no effect

解决的办法


一个简单的例子

import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Conv2D

input = tf.ones([5, 5, 5, 5])

with tf.variable_scope('z'):
    z = tf.Variable(tf.zeros(shape=[3,1],dtype=tf.float32),name='z',trainable=False)

with tf.variable_scope('conv'):
    output1 = Conv2D(filters=3,kernel_size=3,trainable=False)(input)
    # output = tf.keras.layers.BatchNormalization(trainable=False)(input,training=True)

print(tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES))
print(tf.trainable_variables())

打印:

[<tf.Variable 'conv/conv2d/kernel:0' shape=(3, 3, 5, 3) dtype=float32>, <tf.Variable 'conv/conv2d/bias:0' shape=(3,) dtype=float32>]
[<tf.Variable 'conv/conv2d/kernel:0' shape=(3, 3, 5, 3) dtype=float32>, <tf.Variable 'conv/conv2d/bias:0' shape=(3,) dtype=float32>]

另外一个例子:来自KerasLayer trainable=false seems to have no effect

Setting trainable=False seems to have no effect.

Minimal example:

layer = tf.keras.layers.Dense(
      units=1,
      kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
      bias_initializer=tf.keras.initializers.Constant([1.0]),
      trainable=False)
y = layer(tf.constant([[1.0]]))
with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  for var in tf.global_variables():
    print("var: " + str(var) + " trainable: " + str(var.trainable))

results in:

var: <tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32> trainable: True 
var: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32> trainable: True

 

解决的办法

解决的办法就是在导入模型的时候建立一个variable_scope,将需要训练的变量放在另一个variable_scope,然后通过tf.get_collection获取需要训练的变量,最后通过tf的优化器中var_list指定需要训练的变量。

真的是,还要这么多冗余的操作去规避这个缺点,fuck

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值