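报错(原因:作为 global_step 的 step 变量被定义成了 shape=(1) 的一维数组而不是标量,由它算出的学习率也就成了一维张量,而 ApplyGradientDescent 要求学习率是标量):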
Shape must be rank 0 but is rank 1 for 'GradientDescent/update_layer1/weights/ApplyGradientDescent' (op: 'ApplyGradientDescent') with input shapes: [784,500], [1], [784,500].
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
# Number of examples requested per call to mnist.train.next_batch.
BATCH_SIZE = 100
# Initial learning rate handed to tf.train.exponential_decay.
BASE_LEARNING_RATE = 0.001
# Decay passed to tf.train.ExponentialMovingAverage.
DECAY = 0.999
# Number of iterations of the training loop.
TRAIN_STEPS = 30000
def train(mnist):
    """Build the MNIST training graph and run the training loop.

    Args:
        mnist: Dataset object as returned by ``input_data.read_data_sets``.
            Only ``mnist.train.next_batch`` and ``mnist.train.num_examples``
            are used here.
    """
    x = tf.placeholder(
        tf.float32, [None, mnist_inference.INPUT_NODE], name='x')
    y_gt = tf.placeholder(
        tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y_gt')

    regularizer = tf.contrib.layers.l1_l2_regularizer(1.0, 1.0)
    pred = mnist_inference.inference(x, regularizer)

    # The global step MUST be a scalar (rank 0), hence shape=().
    # Writing shape=(1) creates a rank-1 tensor with one element; the
    # learning rate derived from it is then rank 1 as well, and the optimizer
    # fails with "Shape must be rank 0 but is rank 1 for
    # '.../ApplyGradientDescent'". shape=(0) is not a valid choice either.
    # tf.Variable(0, trainable=False) is an equivalent way to get a scalar.
    step = tf.get_variable(
        'step',
        shape=(),
        dtype=tf.int32,
        initializer=tf.constant_initializer(0),
        trainable=False,
    )

    # Maintain exponential moving averages of all trainable variables.
    ema = tf.train.ExponentialMovingAverage(DECAY, step, name='ema')
    avg_op = ema.apply(tf.trainable_variables())

    # y_gt is fed as one-hot rows, so argmax recovers the class index that the
    # sparse cross-entropy op expects.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=pred, labels=tf.argmax(y_gt, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # Total loss is the sum of everything in the LOSSES collection.
    # NOTE(review): presumably mnist_inference.inference adds the
    # regularization terms to this collection - confirm in that module.
    tf.add_to_collection(tf.GraphKeys.LOSSES, cross_entropy_mean)
    loss = tf.add_n(tf.get_collection(tf.GraphKeys.LOSSES))

    # Debug output of the graph pieces that make up the loss.
    print(loss)
    print(tf.trainable_variables())
    print(tf.get_collection(tf.GraphKeys.LOSSES))

    # Staircase decay: the rate is multiplied by 0.99 once per pass over the
    # training set (num_examples / BATCH_SIZE steps).
    learning_rate = tf.train.exponential_decay(
        BASE_LEARNING_RATE,
        step,
        mnist.train.num_examples / BATCH_SIZE,
        decay_rate=0.99,
        staircase=True,
    )
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate).minimize(loss, global_step=step)

    # Running train_op triggers both the gradient update and the moving
    # average update; tf.group could be used to combine them instead.
    with tf.control_dependencies([train_step, avg_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(TRAIN_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_gt: ys})
            # Checkpoint every 1000 iterations (same path each time).
            if i % 1000 == 0:
                saver.save(sess, '/home/lg/Desktop/learn/mnist/model/model.ckpt')
def main():
    """Load the MNIST data set with one-hot labels and start training."""
    data_dir = '/home/lg/Desktop/learn/MNIST_data'
    dataset = input_data.read_data_sets(data_dir, one_hot=True)
    train(dataset)


# Run training only when this file is executed as a script.
if __name__ == '__main__':
    main()
TensorFlow中标量和数组的区别:
import tensorflow as tf
def _zero_int_var(name, shape):
    """Create a non-trainable int32 variable of the given shape, filled with 0."""
    return tf.get_variable(
        name,
        shape=shape,
        dtype=tf.int32,
        initializer=tf.constant_initializer(0),
        trainable=False,
    )


# Same variable definition four times; only the shape argument differs.
a = _zero_int_var('a', (1))     # (1) is just the int 1, not a tuple
b = _zero_int_var('b', (1,))    # one-element tuple
c = _zero_int_var('c', ())      # empty tuple
d = _zero_int_var('d', (1, 2))  # two-element tuple

# Print each variable so the resulting shapes can be compared.
for var in (a, b, c, d):
    print(var)
输出:
<tf.Variable 'a:0' shape=(1,) dtype=int32_ref> # 一维数组,有一个元素(Python 中 (1) 就是整数 1,不是元组,TensorFlow 把它当作形状 (1,))
<tf.Variable 'b:0' shape=(1,) dtype=int32_ref> # 一维数组,有一个元素
<tf.Variable 'c:0' shape=() dtype=int32_ref> # 标量(零维,rank 0)
<tf.Variable 'd:0' shape=(1, 2) dtype=int32_ref> # 二维数组,1行2列