训练模型的时候,维数一定要匹配,同时要了解你自己的数据的格式,和读取的类型,一个one_hot编码用的函数和非one_hot用的函数完全不一样,这也是我当时一直出现问题的原因。
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 11:32:40 2018
@author: huangxudong
"""
import dr_alexnet
import tensorflow as tf
import read_data2

# --- Hyper-parameters ---
learning_rate = 0.01
train_iters = 2000    # stop once step * batch_size reaches this many samples
batch_size = 5
capacity = 256        # input-queue capacity for the batching pipeline
display_step = 10     # report loss/accuracy roughly every `display_step` steps

# --- Input pipeline ---
# get_files splits the dataset into train/validation file lists (20% validation);
# get_batch builds queue-based batch tensors of 512x512 images.
tra_list, tra_labels, val_list, val_labels = read_data2.get_files(
    '/home/bigvision/Desktop/DR_model', 0.2)
tra_list_batch, tra_label_batch = read_data2.get_batch(
    tra_list, tra_labels, 512, 512, batch_size, capacity)
val_list_batch, val_label_batch = read_data2.get_batch(
    val_list, val_labels, 512, 512, batch_size, capacity)

# --- Network parameters ---
n_class = 6        # number of label classes
dropout = 0.75     # dropout keep probability used during training
skip = []          # layers to skip when loading pretrained weights

# --- Placeholders ---
# Each sample is a flattened 512*512*3 RGB image = 786432 floats.
x = tf.placeholder(tf.float32, [None, 786432])
# Labels are integer class ids (NOT one-hot) — hence the sparse loss and
# in_top_k evaluation below.
y = tf.placeholder(tf.int32, [None])
keep_prob = tf.placeholder(tf.float32)  # dropout keep probability

# --- Model, loss and optimizer ---
pred = dr_alexnet.alexNet(x, dropout, n_class, skip)
# sparse_* variant consumes integer labels directly; the non-sparse variant
# would require one-hot encoded labels instead.
cost = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=pred.fc3))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# --- Evaluation ops ---
# in_top_k takes integer labels; with one-hot labels tf.argmax would be used.
correct_pred = tf.nn.in_top_k(pred.fc3, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# --- Training loop ---
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        step = 1
        while step * batch_size < train_iters:
            # BUGFIX: fetch images and labels in ONE sess.run call.  The
            # original used two separate .eval() calls, each of which dequeues
            # a fresh batch, so images and labels came from DIFFERENT batches
            # and were silently mismatched.
            batch_x, batch_y = sess.run([tra_list_batch, tra_label_batch])
            batch_x = batch_x.reshape((batch_size, 786432))
            batch_y = batch_y.T
            sess.run(optimizer,
                     feed_dict={x: batch_x, y: batch_y, keep_prob: dropout})
            # NOTE(review): `== 2` reports every display_step steps but offset
            # by 2 (steps 2, 12, 22, ...); kept as-is to preserve behavior.
            if step % display_step == 2:
                # Evaluate without dropout (keep_prob = 1.0).
                loss, acc = sess.run(
                    [cost, accuracy],
                    feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
                print("Iter" + str(step * batch_size) + ",Minibatch Loss=" + "{:.6f}".format(loss) + ", Training Acc" + "{:.5f}".format(acc))
            step += 1
        print("Optimization Finished!")
    finally:
        # Always shut down the queue-runner threads, even if training raises,
        # so the process can exit cleanly.
        coord.request_stop()
        coord.join(threads)
feed_dict字典读取数据的时候不能是tensor类型,必须是list或numpy等类型,所以在送入batch数据的时候加入了.eval(session=sess)(现已改为一次性的sess.run),当初这块也是磨了很久。希望以后不再犯错
本人新人,对大家有帮助的话就点赞哦