在训练GAN网络时,提示以下报错:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 7500 values, but the requested shape requires a multiple of 27
网络部分代码如下:
def G(self):
with tf.name_scope("Gen") as sc:
output1 = self.fully_con(self.y, 25, sc + "_1")
output2 = self.fully_con(output1, 100, sc + "_2")
output3 = self.fully_con(output2, 500, sc + "_3")
output4 = self.fully_con(output3, 100, sc + "_4")
output5 = self.fully_con(output4, self.shape_2[-1] * 25, sc + "_5")
return tf.reshape(output5, [-1, PATCH_SIZE, PATCH_SIZE, self.shape_2[-1]])
def A(self):
with tf.name_scope("App") as sc:
output1 = self.fully_con(self.x, 25, sc + "_1")
output2 = self.fully_con(output1, 100, sc + "_2")
output3 = self.fully_con(output2, 500, sc + "_3")
output4 = self.fully_con(output3, 100, sc + "_4")
output5 = self.fully_con(output4, self.shape_2[-1] * 25, sc + "_5")
return tf.reshape(output5, [-1, PATCH_SIZE, PATCH_SIZE, self.shape_2[-1]])
def D(self):
with tf.name_scope("Dis") as sc:
self.d = tf.concat([self.x, self.G], 0)
output1 = self.fully_con(self.d, 25, sc + "_1")
output2 = self.fully_con(output1, 100, sc + "_2")
output3 = self.fully_con(output2, 200, sc + "_3")
output4 = self.fully_con(output3, 50, sc + "_4")
output5 = self.fully_con(output4, 1, sc + "_5", tf.nn.sigmoid)
self.p1, self.p2 = tf.split(output5, 2, 0)
分析是由于输入reshape的Tensor维度与所需维度不一致而导致的,因此检查reshape的输入与输出维度,逐步往前推进,找到问题根源。
我这里是因为之前定义了PATCH_SIZE=3
,而G、A、D的输入维度均为25导致的,应该改为PATCH_SIZE*PATCH_SIZE
改正以后就没有报错了,至此,解决了以上问题。