Table of Contents
A simple example
Another example: from "KerasLayer trainable=false seems to have no effect"
The solution
A simple example
In TF1 graph mode, passing trainable=False to a Keras layer does not keep its variables out of the TRAINABLE_VARIABLES collection:
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Conv2D

tf.disable_v2_behavior()  # run in graph mode so variable collections apply

input = tf.ones([5, 5, 5, 5])
with tf.variable_scope('z'):
    z = tf.Variable(tf.zeros(shape=[3, 1], dtype=tf.float32), name='z', trainable=False)
with tf.variable_scope('conv'):
    output1 = Conv2D(filters=3, kernel_size=3, trainable=False)(input)
    # output = tf.keras.layers.BatchNormalization(trainable=False)(input, training=True)
print(tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES))
print(tf.trainable_variables())
This prints:
[<tf.Variable 'conv/conv2d/kernel:0' shape=(3, 3, 5, 3) dtype=float32>, <tf.Variable 'conv/conv2d/bias:0' shape=(3,) dtype=float32>]
[<tf.Variable 'conv/conv2d/kernel:0' shape=(3, 3, 5, 3) dtype=float32>, <tf.Variable 'conv/conv2d/bias:0' shape=(3,) dtype=float32>]
Note that z, a plain tf.Variable created with trainable=False, is correctly absent from the collection, while the Conv2D kernel and bias still appear even though the layer was also created with trainable=False.
Another example: from "KerasLayer trainable=false seems to have no effect"
Setting trainable=False seems to have no effect.
Minimal example:
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

layer = tf.keras.layers.Dense(
    units=1,
    kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
    bias_initializer=tf.keras.initializers.Constant([1.0]),
    trainable=False)
y = layer(tf.constant([[1.0]]))
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for var in tf.global_variables():
        print("var: " + str(var) + " trainable: " + str(var.trainable))
results in:
var: <tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32> trainable: True
var: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32> trainable: True
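The mismatch lives between the two APIs: the Keras layer itself does honor trainable=False (its trainable_weights list is empty), but in TF1 graph mode the variables still land in the TRAINABLE_VARIABLES graph collection, which is what TF1 optimizers read by default. A minimal sketch to see both sides at once (assuming the same TF1-compat setup as above):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

layer = tf.keras.layers.Dense(units=1, trainable=False)
layer(tf.constant([[1.0]]))  # build the layer so variables exist

# Layer-level view: trainable=False is respected, no trainable weights.
print(layer.trainable_weights)                     # expected: []
print([v.name for v in layer.weights])             # kernel and bias

# Graph-collection view: the same variables are still listed as trainable,
# so optimizer.minimize(loss) without var_list would update them anyway.
print([v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
```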
The solution
The workaround is to create a variable_scope when building (or importing) the frozen part of the model, put the variables that actually need training under a different variable_scope, then fetch only those variables with tf.get_collection, and finally pass them to the TF optimizer through its var_list argument so only they get updated.
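A minimal sketch of that workaround, with two hypothetical scope names 'frozen' (the part to keep fixed) and 'head' (the part to train):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = tf.ones([1, 5, 5, 5])
y_true = tf.ones([1, 3, 3, 3])

with tf.variable_scope('frozen'):
    feat = tf.keras.layers.Conv2D(filters=3, kernel_size=3)(x)     # should stay fixed
with tf.variable_scope('head'):
    y_pred = tf.keras.layers.Conv2D(filters=3, kernel_size=1)(feat)  # should be trained

loss = tf.reduce_mean(tf.square(y_pred - y_true))

# Fetch only the variables under the 'head' scope...
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='head')
# ...and hand exactly those to the optimizer via var_list, so the
# 'frozen' variables are never touched regardless of trainable flags.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=train_vars)
```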
Seriously, it takes this much redundant boilerplate just to work around one flaw.