最近女朋友问了我一个bug,关于tensorflow的,我打算看看顺便捡起一些tensorflow的知识,毕竟毕业论文用的是pytorch。但是搞了很久也不知道为什么会有这个bug,最后她解决了,跟我说了一下怎么解决的,但是我们都不清楚为什么这样就可以解决问题,所以打算在这里记录一下。
Bug
// 这是一个class里的代码
self.x = tf.placeholder(tf.float32, [None, self.walk_length], name='walk_var')
self.optimizer, self.loss, self.H = self._build_training_graph()
self.sess = tf.Session()
init = tf.global_variables_initializer()
self.sess.run(init)
// 然后在build_graph里面调用了另一个class的函数
def _build_training_graph(self):
# 1d-CNN
H = self.model.cnn_forward(self.x, drop_prob=self.drop_prob, reuse=False)
// model.py里的cnn_forward如下
def cnn_forward(self, x, drop_prob, reuse=tf.AUTO_REUSE):
print('x shape:', np.shape(x), 'x type:', type(x))
with tf.variable_scope('cnn', reuse=reuse):
data = tf.convert_to_tensor(x)
data = tf.reshape(data, [-1, x.shape[1], 1])
shape = tf.shape(data)
stddev = 1.0 / tf.sqrt(tf.