全连接网络
最近在编写tensorflow的相关代码,实现了一下全连接网络实现mnist的代码,具体代码如下
相关函数说明
tf.argmax(y,1)
#返回数组y中最大值的索引,如果y是多维数组,只在当前数组中比较,按行比较
tf.argmax(y,0)
如果是0,比较当前列,也就是多个数组的相同位置,按列比较
code
import tensorflow as tf
import os
import time
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
#设置相关的显卡,可以是单独一个设置,也可以是多个设备
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
tf_config = tf.ConfigProto()
#动态申请显存
tf_config.gpu_options.allow_growth = True
mnist = input_data.read_data_sets("mnist",one_hot=True)
#适当的增加batch_size,会提高模型训练的时间
batch_size = 128
n_batch = mnist.train.num_examples // batch_size
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32, [None,10])
keep_prob = tf.placeholder(tf.float32)
w1 = tf.Variable(tf.truncated_normal([784,1024],stddev=0.1))
b1 = tf.Variable(tf.random_normal([1024]))
L1 = tf.nn.tanh(tf.matmul(x,w1) + b1)
L1_drop = tf.nn.dropout(L1,keep_prob)
w2 = tf.Variable(tf.truncated_normal([1024,1024],stddev=0.1))
b2 = tf.Variable(tf.random_normal([1024]))
L2 = tf.nn.relu(tf.matmul(L1_drop,w2) + b2)
L2_drop = tf.nn.dropout(L2,keep_prob)
w3 = tf.Variable(tf.truncated_normal([1024,10],stddev=0.1))
b3 = tf.Variable(tf.random_normal([10]))
prediction = tf.matmul(L2_drop,w3) + b3
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(loss)
init = tf.global_variables_initializer()
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
start = time.time()
with tf.Session(config = tf_config) as sess:
sess.run(init)
for epoch in range(50):
epoch_start = time.time()
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.75})
epoch_end = time.time()
epoch_time = epoch_end - epoch_start
test_acc = sess.run(accuracy, feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels,keep_prob:1.0})
print("Iter:" + str(epoch) + ", Test Acc:" + str(test_acc) + ",Train Acc:" + str(train_acc) + ",Time:"+ str(epoch_time))
end = time.time()
print("Total Time:"+str(end - start) + "s")
print("success")
Trick
增加batch_size,会提高训练的速度,减少训练的时间
stable epoch指的是acc稳定到0.98所需要的epoch数量
batch_size | time | epoch | acc | stable epoch | mem |
---|---|---|---|---|---|
32 | 290s | 50 | 0.981 | 37 | 1132m |
64 | 154s | 50 | 0.983 | 28 | 1132m |
128 | 96s | 50 | 0.982 | 23 | 1132m |
256 | 67s | 50 | 0.985 | 10 | 1132m |
设置GPU设备的时候是有优先级的
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
和
os.environ["CUDA_VISIBLE_DEVICES"] = "1,0"
这两个 是不一样的,在前面的会优先使用