1.问题描述
原文地址:基于CNN的食物分类问题
通过编写CNN进行图片分类,并分辨出食物的类别。
1.1 数据描述
将一系列图片按照[类别_编号]方式进行文件重命名。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
(img-9XZ8wQVG-1597155742646)(https://imgkr2.cn-bj.ufileos.com/57641e91-f43b-4196-86ed-5efbd98b91cf.png?UCloudPublicKey=TOKEN_8d8b72be-579a-4e83-bfd0-5f6ce1546f13&Signature=MvDnKR3CFwYUB23bGQb6JWGnwsk%253D&Expires=1597234953)]
·总的数据大小:训练集大小:9866;测试集大小:3430
2.代码实现
2.1运行过程中所用到的库
# Import需要的库
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import time
from tensorboardX import SummaryWriter
2.2通过CV2读入图片
以下用一个例子展示opencv的使用:
通过创建一个名字为Image的画布,并通过cv2读入名称’0_0.jpg’的文件并展示
# 一个展示cv2读取和显示图片的例子
img = cv2.imread(os.path.join(os.path.join('./food-11', "training"), '0_0.jpg'))
cv2.namedWindow("Image")
cv2.imshow("Image", img)
cv2.waitKey (0)
结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UIO6NMKT-1597155742647)(https://imgkr2.cn-bj.ufileos.com/43df7df7-86d5-4542-9839-591d371a0a51.png?UCloudPublicKey=TOKEN_8d8b72be-579a-4e83-bfd0-5f6ce1546f13&Signature=u5F68TtiMdDWH0UCx40jYjMPuz0%253D&Expires=1597235548)]
2.2.1 定义图片读入函数
def readfile(path, label):
# label 是一個 boolean variable,代表需不需要回傳 y 值
image_dir = sorted(os.listdir(path))
x = np.zeros((len(image_dir), 128, 128, 3), dtype=np.uint8)
y = np.zeros((len(image_dir)), dtype=np.uint8)
for i, file in enumerate(image_dir):
img = cv2.imread(os.path.join(path, file))
x[i, :, :] =( cv2.resize(img,(128, 128)))[: , : , : : -1]
if label:
y[i] = int(file.split("_")[0])
if label:
return x, y
else:
return x
代码说明:
1.为减少资源的使用,通过cv2把原图片以[128 * 128 * 3]读入。
2.该函数将input(图片)和ouput(lable)分离
3.注意cv2读入图片是按照[BGR]读入,通过img[: , : , : : -1]可以转换成[RGB]。
2.2.2 训练集和测试集的读入
workspace_dir = './food-11'
print("Reading data")
train_x, train_y = readfile(os.path.join(workspace_dir, "training"), True)
print("Size of training data = {}".format(len(train_x)))
val_x, val_y = readfile(os.path.join(workspace_dir, "validation"), True)
print("Size of test data = {}".format(len(val_x)))
结果:
Reading data
Size of training data = 9866
Size of test data = 3430
2.3 Dataset
2.3.1 ImgDataset类
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Ra