作为大三学期的机器视觉这门课的课程设计,之前给自己挖下的坑,含着泪也要给补上
环境要求:anaconda tensorflow-1.12.0 opencv-4.1.2
这次的任务是使用cnn神经网络对火腿肠的外包装进行缺陷检测,这里需要将火腿肠外包装情况使用cnn给分成三类,第一类是包装正常,第二类是包装缝合处破损,第三类是变形,本来应该还有一类是封口铁丝缺失的检测的,但拍照出来resize后连肉眼都难以区分,就没有加进去,这三类图片分别如下
首先是样本数据的采集,我使用手机摄像头在固定的高度下对火腿肠进行30帧每秒的录像,然后从视频里提取了每类大约两千多张的样本图片,后将其resize成200*200像素大小,代码如下:
import numpy as np
import cv2 as cv
import os
cap = cv.VideoCapture("./video/malformation3.mp4")
frame_num=0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
print("Can't receive frame (stream end?). Exiting ...")
break
if cv.waitKey(1) == ord('q'):
break
cropped=frame[100:600,200:700]
resize_pic = cv.resize(cropped, (200, 200))
cv.imshow('resize',resize_pic)
cv.imwrite('./test/malformation'+str(frame_num)+'_test.png',resize_pic)
frame_num=frame_num+1
cap.release()
得到训练数据后,根据其文件名称,给每张图打上标签,为后面做准备
import tensorflow as tf
import numpy as np
import os
import re
# np.set_printoptions(threshold=100000)
'''
功能:获取训练数据地址,存储标签
参数:imagedir
imagedir: 训练数据地址
返回:imagelist, labellist
imagelist:图片位置列表
labellist:数据标签列表
'''
def get_file(imagedir):
images = []
labels = []
for root, dirs, files in os.walk("./train"):
for filename in files:
images.append(os.path.join(root, filename)) # 图片所在目录list
for prefolder in images:
letter = prefolder.split('\\')[-1]
#print(letter)
if re.match('normal', letter): # 匹配图片名称
labels = np.append(labels, [0])
elif re.match('notSuture', letter):
labels = np.append(labels, [1])
elif re.match('malformation', letter):
labels = np.append(labels, [2])
temp = np.array([images, labels]) # 将图片地址和标记存入矩阵中
temp = temp.transpose() # 转置
np.random.shuffle(temp) # 打乱元素
np.random.shuffle(temp) # 打乱元素
np.random.shuffle(temp) # 打乱元素
imagelist = list(temp[:, 0]) # 第一列的所有元素
labellist = list(temp[:, 1])
labellist = [int(float(i)) for i in labellist] # 将标记转化为整形
# print(labellist)
return imagelist, labellist
'''
功能:获取训练数据地址,存储标签
参数:image_list, label_list, img_width, img_height, batch_size, capacity, channel
image_list: 图片位置列表
label_list:数据标签列表
img_width:训练图片size
img_height:训练图片size
batch_size:训练batch size
capacity:线程队列里面包含的数据数量
channel:输入数据通道数
返回:image_batch, label_batch
image_batch:图片batch
label_batch:数据标签batch
'''
def get_batch(image_list, label_list, img_width, img_height