# 数据集地址:
# 链接:https://pan.baidu.com/s/1ZXVb7M5p0JtS1edYRCJG2w
# 提取码:1xhz
# 代码如下:
#coding=utf-8
import os
#图像读取库
from PIL import Image
#矩阵运算库
import numpy as np
import tensorflow as tf
# Dataset directories (training images and test images).
train_dir = r"D:/Modeling_2/train/"
test_dir = r"D:/Modeling_2/test/"
# Mode switch: True runs the training branch, False runs the test branch.
train = True
# Path prefix the model checkpoint is saved to / restored from.
model_path = "D:/Modeling_2/models/image_model"
def read_data(train_dir):
    """Load every image in a directory together with its integer label.

    File names are expected to look like ``<label>_<anything>``: the integer
    in front of the first underscore is used as the class label.

    Args:
        train_dir: Directory containing the image files.

    Returns:
        A tuple ``(file_paths, datas, labels)`` where ``file_paths`` is the
        list of image paths, ``datas`` is a NumPy array of the pixel values
        scaled to ``[0, 1]`` and ``labels`` is a NumPy array of the integer
        labels, all in the same (sorted-by-file-name) order.

    Raises:
        ValueError: If a file name does not start with an integer label.
    """
    file_paths = []
    datas = []
    labels = []
    # os.listdir returns entries in arbitrary order; sort so that repeated
    # runs see the samples in the same order.
    for file_name in sorted(os.listdir(train_dir)):
        file_path = os.path.join(train_dir, file_name)
        # Parse the label first so a stray file fails before any image I/O.
        label = int(file_name.split("_")[0])
        # Use the image as a context manager so the underlying file handle is
        # released as soon as the pixels are copied into the array.
        with Image.open(file_path) as image:
            # Force three channels: grayscale/RGBA/palette images would
            # otherwise yield arrays that cannot be stacked with RGB ones.
            data = np.array(image.convert("RGB")) / 255.0
        file_paths.append(file_path)
        datas.append(data)
        labels.append(label)
    # NOTE(review): stacking assumes every image has the same height and
    # width (the model in this file expects 32x32) -- confirm for new data.
    datas = np.array(datas)
    labels = np.array(labels)
    print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
    return file_paths, datas, labels
def weight_variable(shape):
    """Create a weight variable of the given shape.

    The initial values are drawn from a truncated normal distribution with a
    standard deviation of 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
    """Create a bias variable of the given shape, initialised to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))
def conv2d(x, W):
    """Convolve ``x`` with the kernel ``W`` using stride 1 and SAME padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding="SAME")
def max_2x2_pool(x):
    """Max-pool ``x`` with a 2x2 window, stride 2 and SAME padding."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding="SAME")
def _conv_relu_pool(inputs, in_channels, out_channels):
    """Apply a 5x5 SAME convolution, bias, ReLU and 2x2 max-pooling.

    The kernel variable is created before the bias variable, matching the
    creation order of the original hand-written layers.
    """
    kernel = weight_variable([5, 5, in_channels, out_channels])
    bias = bias_variable([out_channels])
    return max_2x2_pool(tf.nn.relu(conv2d(inputs, kernel) + bias))


def main():
    """Train the image classifier, or evaluate a saved one.

    When the module-level ``train`` flag is true, the model is trained on the
    images in ``train_dir`` and saved to ``model_path``. Otherwise the model
    is restored from ``model_path`` and its predictions for the images in
    ``test_dir`` are printed next to the real labels.
    """
    # Maps label ids to human-readable class names.
    label_name_dict = {
        0: "百合花",
        1: "白玉兰",
        2: "茉莉花",
        3: "栀子花"
    }

    # Training reads the training set; testing reads the held-out test set.
    # NOTE(review): assumes test file names follow the same "<label>_..."
    # convention as the training files -- confirm against the dataset.
    data_dir = train_dir if train else test_dir
    file_paths, datas, labels = read_data(data_dir)

    # Number of classes. In test mode the test set may not contain every
    # class, so the size comes from the label map instead of the data to keep
    # the output layer the same shape as the one stored in the checkpoint.
    num_classes = len(set(labels)) if train else len(label_name_dict)

    # Inputs: images, integer labels and the dropout keep probability.
    x = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y = tf.placeholder(tf.int32, [None])
    keep_prob = tf.placeholder(tf.float32)

    # Four conv/ReLU/pool stages with 5x5 kernels: 3 -> 32 -> 64 -> 128 -> 256
    # channels. Each pooling halves the spatial size: 32 -> 16 -> 8 -> 4 -> 2.
    features = x
    for in_channels, out_channels in ((3, 32), (32, 64), (64, 128), (128, 256)):
        features = _conv_relu_pool(features, in_channels, out_channels)

    # Flatten the 2x2x256 feature maps into one vector per image.
    flat_size = 2 * 2 * 256
    flat = tf.reshape(features, [-1, flat_size])

    # Dropout is applied to the flattened features (not to the logits) so the
    # keep_prob fed during training actually regularises the classifier.
    flat_dropout = tf.nn.dropout(flat, keep_prob)

    # Fully connected output layer producing one logit per class.
    W_fc1 = weight_variable([flat_size, num_classes])
    b_fc1 = bias_variable([num_classes])
    logits = tf.matmul(flat_dropout, W_fc1) + b_fc1

    # Index of the largest logit is the predicted label.
    prediction_label = tf.argmax(logits, 1)

    # Softmax cross entropy against the one-hot encoded labels.
    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(y, num_classes), logits=logits)
    mean_loss = tf.reduce_mean(losses)

    # Adam optimizer on the scalar mean loss (not the per-example vector).
    train_step = tf.train.AdamOptimizer(1e-4).minimize(mean_loss)

    # Accuracy: fraction of samples whose predicted label equals the real one.
    correct_prediction = tf.equal(prediction_label, tf.cast(y, tf.int64))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()

    with tf.Session() as sess:
        if train:
            print("训练模式!")
            # Make sure the checkpoint directory exists before spending time
            # on training, so the final save cannot fail on a missing folder.
            model_dir = os.path.dirname(model_path)
            if model_dir and not os.path.isdir(model_dir):
                os.makedirs(model_dir)

            sess.run(tf.global_variables_initializer())
            train_feed_dict = {x: datas, y: labels, keep_prob: 0.5}
            # Evaluation must keep every unit, i.e. dropout disabled.
            eval_feed_dict = {x: datas, y: labels, keep_prob: 1.0}
            for step in range(1001):
                _, mean_loss_val = sess.run([train_step, mean_loss], feed_dict=train_feed_dict)
                if step % 20 == 0:
                    print("step = {}\tmean loss = {}".format(step, mean_loss_val))
                    train_acc = sess.run(accuracy, feed_dict=eval_feed_dict)
                    print("准确率:", train_acc)
            saver.save(sess, model_path)
            print("训练结束,保存模型到{}".format(model_path))
        else:
            print("测试模式")
            saver.restore(sess, model_path)
            print("从{}载入模型".format(model_path))
            # keep_prob must be 1.0 at inference time; 0 would drop everything.
            test_feed_dict = {x: datas, y: labels, keep_prob: 1.0}
            prediction_val = sess.run(prediction_label, feed_dict=test_feed_dict)
            for file_path, real_label, predicted_label in zip(file_paths, labels, prediction_val):
                # Convert label ids to label names.
                real_label_name = label_name_dict[real_label]
                predicted_label_name = label_name_dict[predicted_label]
                print("{}\t{} => {}".format(file_path, real_label_name, predicted_label_name))


if __name__ == "__main__":
    main()