利用CNN对图片进行简单的6分类,数据集为从网上爬取的6种车型图片,这里先进行一系列数据预处理,再搭建CNN卷积网络进行训练。
数据集部分展示
代码展示
#encoding = utf-8
"""
@author:syj
@file:img_分类.py
@time:2019/09/27 14:05:47
"""
#导库
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
# Pin TensorFlow to GPU 3 (CUDA_VISIBLE_DEVICES hides all other devices)
import os
os.environ["CUDA_VISIBLE_DEVICES"]= "3"
# Fixed graph-level random seed for reproducible TF ops
tf.set_random_seed(77)
# Dataset root: one sub-folder per vehicle class
cat_dir = r'C:\Users\Administrator\Desktop\car_datas\车1'
#数据读取转换为矩阵 标签分类
# File-name prefix (first two characters) -> class id. Any other prefix
# (e.g. '跑车') falls through to class 5, matching the original if/elif chain.
_LABEL_BY_PREFIX = {'卡车': 0, '房车': 1, '摩托': 2, '自行': 3, '越野': 4}

def load_data(name_class):
    """Load all images under ``cat_dir`` and build the label array.

    Args:
        name_class: list of class sub-directory names under ``cat_dir``.

    Returns:
        (images, labels, num): ``images`` has shape (num, 64, 64, 3) with
        pixel values scaled to [-1, 1]; ``labels`` has shape (num,) with
        int class ids 0-5; ``num`` is the total sample count.

    NOTE(review): assumes every file is a readable 3-channel image; a
    4-channel PNG would break the (64, 64, 3) stacking — confirm dataset.
    """
    images_data = []  # accumulated image arrays
    labels_data = []  # aligned class ids
    for class_dir in name_class:
        folder = cat_dir + '/' + class_dir
        for file_name in os.listdir(folder):  # e.g. '房车_0.jpg'
            # plt.imread handles non-ASCII (Chinese) paths that cv2.imread cannot.
            img = plt.imread(folder + '/' + file_name)
            img = cv2.resize(img, (64, 64))  # unify size to 64x64x3
            img = img / 127.5 - 1            # scale uint8 [0,255] -> [-1, 1]
            images_data.append(img)
            # Classify by the 2-character file-name prefix.
            labels_data.append(_LABEL_BY_PREFIX.get(file_name[:2], 5))
    # len(images_data) replaces the original hand-maintained counter.
    return np.array(images_data), np.array(labels_data), len(images_data)
# Class sub-directory names double as class names, e.g. ['房车', '跑车', ...]
name_class = os.listdir(cat_dir)
print(name_class)
num_class = len(name_class)  # number of classes (6 expected)
# Shuffle samples and labels together with one shared permutation
def shuffle_data(imgage_data, labels_data, num=None):
    """Shuffle ``imgage_data`` and ``labels_data`` in unison.

    Args:
        imgage_data: array of samples; first axis indexes samples.
        labels_data: label array aligned with ``imgage_data``.
        num: sample count. Defaults to ``len(imgage_data)`` so callers no
            longer need to pass it (backward compatible with the old
            3-argument call).

    Returns:
        (shuffled_images, shuffled_labels) with pairing preserved.
    """
    if num is None:
        num = len(imgage_data)
    p = np.random.permutation(num)  # one permutation applied to both arrays
    return imgage_data[p], labels_data[p]
# Build, shuffle, and split the dataset
imgage_data,labels_data,num = load_data(name_class)
imgage_data,labels_data = shuffle_data(imgage_data,labels_data,num)
print(imgage_data.shape)
print(labels_data.shape)
# 80/20 train/test split; fixed random_state for reproducibility
train_x,test_x,train_y,test_y = train_test_split(imgage_data,labels_data,test_size=0.2,random_state=7)
# Placeholders: image batch, integer labels, dropout keep probability.
x = tf.placeholder(tf.float32, [None, 64, 64, 3])
y = tf.placeholder(tf.int64, [None])
keep_prob = tf.placeholder(tf.float32)  # fully-connected dropout keep rate

# Per-image augmentation pipeline.
# NOTE(review): tf.split(..., num_or_size_splits=100) hard-wires the feed
# batch size to exactly 100 samples — any other size fails at runtime.
# NOTE(review): the random flip/brightness/contrast ops also run during
# evaluation feeds; augmentation is normally disabled at eval time — confirm.
result_x_image_arr = []
for single in tf.split(x, num_or_size_splits=100, axis=0):
    img = tf.reshape(single, [64, 64, 3])
    img = tf.image.random_flip_left_right(img)                 # random horizontal flip
    img = tf.image.random_brightness(img, max_delta=63)        # brightness jitter
    img = tf.image.random_contrast(img, lower=0.2, upper=1.8)  # contrast jitter
    img = tf.image.per_image_standardization(img)              # per-image whitening
    result_x_image_arr.append(tf.reshape(img, [1, 64, 64, 3]))
result_x_images = tf.concat(result_x_image_arr, axis=0)
# Convolutional feature extractor: 3 x (conv -> batch norm -> 2x2 max pool).
# (The original comment here said "fully connected"; these are conv layers.)
# NOTE(review): batch_normalization's training flag is not set (defaults to
# False), so these layers use their never-updated moving statistics — confirm
# this is intended.
conv1 = tf.layers.conv2d(result_x_images,32,(3,3),padding='same',activation=tf.nn.relu)
conv1 = tf.layers.batch_normalization(conv1,momentum=0.7)
pooling1 = tf.layers.max_pooling2d(conv1,(2,2),(2,2))  # 64x64 -> 32x32
conv2 = tf.layers.conv2d(pooling1,64,(3,3),padding='same',activation=tf.nn.relu)
conv2 = tf.layers.batch_normalization(conv2,momentum=0.7)
pooling2 = tf.layers.max_pooling2d(conv2,(2,2),(2,2))  # 32x32 -> 16x16
conv3 = tf.layers.conv2d(pooling2,128,(3,3),padding='same',activation=tf.nn.relu)
conv3 = tf.layers.batch_normalization(conv3,momentum=0.7)
pooling3 = tf.layers.max_pooling2d(conv3,(2,2),(2,2))  # 16x16 -> 8x8
flatten = tf.layers.flatten(pooling3)  # -> (batch, 8*8*128)
# Classifier head: dense -> dropout -> 6-way logits
fc = tf.layers.dense(flatten,625,activation=tf.nn.tanh)
fc = tf.nn.dropout(fc,keep_prob=keep_prob)  # feed keep_prob=1 to disable at eval
a5 = tf.layers.dense(fc,6)  # raw logits for the 6 classes
# Per-sample cross-entropy loss. NOTE(review): this is an unreduced vector;
# minimize() implicitly sums its gradients, and the training loop prints
# np.mean(c). tf.reduce_mean(...) would be the conventional scalar form.
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=a5)
# Adam optimizer with a small fixed learning rate
optimizer = tf.train.AdamOptimizer(0.00005).minimize(cost)
# Accuracy: fraction of argmax predictions matching the int64 labels
pre = tf.argmax(a5,1)
accuracy = tf.reduce_mean(tf.cast(tf.equal(pre,y),tf.float32))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Mini-batch training: 3000 iterations of 100-sample batches.
step = 0
BATCH = 100  # must be exactly 100: the graph splits the feed into 100 images
for i in range(1, 3001):
    # Bug fix: the original reset `step` only AFTER it passed the end of the
    # data, so the final slice could yield a short batch (< 100 samples),
    # which breaks the graph's fixed tf.split(..., 100). Wrap around before
    # slicing whenever a full batch is no longer available.
    if step + BATCH > train_x.shape[0]:
        step = 0
    c, a, _ = sess.run(
        [cost, accuracy, optimizer],
        feed_dict={x: train_x[step:step + BATCH],
                   y: train_y[step:step + BATCH],
                   keep_prob: 0.7})
    step += BATCH
    if i % 500 == 0:
        print(i, np.mean(c), a)  # c is the per-sample loss vector

# Evaluate on 5 full test batches (500 samples); keep_prob=1 disables dropout.
step1 = 0
all_acc = []
for _ in range(5):
    a1 = sess.run(accuracy,
                  feed_dict={x: test_x[step1:step1 + 100],
                             y: test_y[step1:step1 + 100],
                             keep_prob: 1})
    step1 += 100
    all_acc.append(a1)
print(np.mean(all_acc))
效果展示
['房车', '自行车图片', '跑车', '越野车', '摩托车', '卡车']
(2752, 64, 64, 3)
(2752,)
300 0.8921075 0.67
600 0.6069706 0.81
900 0.30461997 0.92
1200 0.3142417 0.93
1500 0.16324146 0.98
1800 0.08101442 0.99
2100 0.08600599 0.99
2400 0.040265616 1.0
2700 0.035595465 1.0
3000 0.016259683 1.0
0.764
刚入手不久,代码的精度还在调优中,后期会持续更新。