使用逻辑回归模型识别 TensorFlow 自带的 MNIST 手写数字数据集。
代码展示:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import math
def logistic_regression(
        data_dir=r'C:\Users\Administrator\Desktop\AI_project\tensorflow\MNIST_data',
        batch_size=128,
        n_epochs=50):
    """Train and evaluate a softmax (logistic) regression classifier on MNIST.

    Builds a single-layer model ``x @ w + b`` with a softmax cross-entropy
    loss, trains it with Adam, and after every epoch prints the mean training
    loss followed by the mean accuracy over the test split.

    Args:
        data_dir: Directory passed to ``input_data.read_data_sets``.
        batch_size: Number of examples per mini-batch. The placeholders are
            created with this fixed leading dimension.
        n_epochs: Number of passes over the training split.

    Returns:
        None. Progress is reported via ``print``; the graph is written to
        ``./logs`` for TensorBoard.
    """
    # Load the data with one-hot encoded labels.
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # Placeholders for one mini-batch of inputs and one-hot labels.
    x = tf.placeholder(tf.float32, [batch_size, 784], name="x_data")
    y = tf.placeholder(tf.int32, [batch_size, 10], name="y_data")

    # Model parameters: weights and bias.
    w = tf.Variable(tf.random_normal([784, 10], stddev=0.1))
    b = tf.Variable(tf.zeros([10]))

    # Loss and optimizer. `logits` is a [batch_size, 10] matrix.
    logits = tf.add(tf.matmul(x, w), b)
    # The cross-entropy op yields one loss value per example (a vector);
    # reduce_mean collapses it to a scalar.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
    train = tf.train.AdamOptimizer().minimize(loss)

    # Evaluation: fraction of rows whose arg-max class matches the label.
    probs = tf.nn.softmax(logits)
    correct = tf.equal(tf.argmax(input=probs, axis=1),
                       tf.argmax(input=y, axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    # Variable initializer.
    init = tf.global_variables_initializer()

    # Run the graph defined above inside a session.
    with tf.Session() as sess:
        sess.run(init)

        # Only the graph is logged, so the writer can be closed right away.
        writer = tf.summary.FileWriter("./logs", sess.graph)
        writer.close()

        n_batch_train = math.ceil(mnist.train.num_examples / batch_size)
        # Fix: size the evaluation loop from the TEST split; the original
        # used the training-set size here.
        # NOTE(review): because of the ceil, the last batch presumably wraps
        # into the next pass over the data so that it still holds
        # `batch_size` rows -- confirm against the DataSet.next_batch docs.
        n_batch_test = math.ceil(mnist.test.num_examples / batch_size)

        for i in range(n_epochs):
            # Train for one epoch.
            loss_total = 0
            for _ in range(n_batch_train):
                x_input, y_input = mnist.train.next_batch(batch_size)
                _, batch_loss = sess.run(
                    [train, loss], feed_dict={x: x_input, y: y_input})
                loss_total += batch_loss
            print("Interation:{},loss:{}".format(i, loss_total / n_batch_train), end=" ")

            # Evaluate on the test split.
            accuracy_total = 0
            for _ in range(n_batch_test):
                x_input, y_input = mnist.test.next_batch(batch_size)
                accuracy_total += sess.run(
                    accuracy, feed_dict={x: x_input, y: y_input})
            print("accuary:{}".format(accuracy_total / n_batch_test))
结果展示:
Interation:0,loss:0.7622147449920343 accuary:0.8922420058139535
Interation:1,loss:0.3823678414142409 accuary:0.9089571220930233
Interation:2,loss:0.3328973266967507 accuary:0.9157885174418605
Interation:3,loss:0.3105759471308353 accuary:0.9191860465116279
Interation:4,loss:0.2960549441534419 accuary:0.9207485465116279
Interation:5,loss:0.2865448264535083 accuary:0.9211300872093023
Interation:6,loss:0.2805248944045499 accuary:0.9229651162790697
Interation:7,loss:0.2745572795008504 accuary:0.924781976744186
Interation:8,loss:0.2710500391243502 accuary:0.9243459302325582
Interation:9,loss:0.26706690582078557 accuary:0.9253997093023256
Interation:10,loss:0.2629715468647868 accuary:0.9269440406976744
Interation:11,loss:0.2621379436968371 accuary:0.9255450581395349
Interation:12,loss:0.2595898272512957 accuary:0.9261264534883721
Interation:13,loss:0.2563329931435197 accuary:0.9253088662790697
Interation:14,loss:0.2554247479279374 accuary:0.9272529069767442
Interation:15,loss:0.25390707252677097 accuary:0.9283793604651163
Interation:16,loss:0.2524607227985249 accuary:0.9253997093023256
Interation:17,loss:0.2505904925423999 accuary:0.9253815406976744
Interation:18,loss:0.25018050428046734 accuary:0.9276526162790698
Interation:19,loss:0.24771193742405537 accuary:0.9277979651162791
Interation:20,loss:0.24811712648979453 accuary:0.9273074127906977
Interation:21,loss:0.2459677146444487 accuary:0.9260174418604651
Interation:22,loss:0.24644057648473008 accuary:0.9283793604651163
Interation:23,loss:0.24489293940538584 accuary:0.9258539244186047
Interation:24,loss:0.24324410925077838 accuary:0.9274890988372093
Interation:25,loss:0.24299994336311206 accuary:0.9282885174418605
Interation:26,loss:0.24238698946875195 accuary:0.9269077034883721
Interation:27,loss:0.24184117951365405 accuary:0.9276162790697674
Interation:28,loss:0.2400227582038835 accuary:0.9278343023255814
Interation:29,loss:0.2400645112748756 accuary:0.9271438953488372
Interation:30,loss:0.24123970490208893 accuary:0.9277071220930233
Interation:31,loss:0.23869256044543066 accuary:0.9289244186046511
Interation:32,loss:0.23661559423388437 accuary:0.9276344476744186
Interation:33,loss:0.23885162564557652 accuary:0.9273982558139535
Interation:34,loss:0.23682482469567032 accuary:0.9284883720930233
Interation:35,loss:0.2369533678300159 accuary:0.9261809593023256
Interation:36,loss:0.23730118084092472 accuary:0.9273800872093023
Interation:37,loss:0.23646007918341216 accuary:0.9275617732558139
Interation:38,loss:0.23514422782978348 accuary:0.9273255813953488
Interation:39,loss:0.2359093708181104 accuary:0.9276707848837209
Interation:40,loss:0.23345869906073394 accuary:0.9270893895348837
Interation:41,loss:0.2352673845062422 accuary:0.9272710755813953
Interation:42,loss:0.23326533116226972 accuary:0.9262718023255814
Interation:43,loss:0.2358175446474275 accuary:0.9289607558139535
Interation:44,loss:0.23324581149012544 accuary:0.9281431686046512
Interation:45,loss:0.23273995503090147 accuary:0.9288154069767441
Interation:46,loss:0.23142784215336623 accuary:0.9275981104651163
Interation:47,loss:0.23274108893996062 accuary:0.9285428779069768
Interation:48,loss:0.23295457486149876 accuary:0.9285428779069768
Interation:49,loss:0.23111646177117215 accuary:0.9279069767441861
总结:
主要还是两个问题:一是数据要分批(batch)处理,防止内存爆掉(我亲身经历了好几次,内存占用到99%,然后死机!);二是优化方法(优化器)的选取。
注意:axis的使用,损失函数返回的形式(维数),关注矩阵的维数。