基于CNN的图像分类

这篇博客详细介绍了如何使用CNN进行图像分类,特别是在食物类别识别上的应用。作者首先描述了数据集的情况,然后逐步讲解了从图片读取、数据预处理、定义Dataset类,直至构建CNN和全连接网络模型的过程。通过网络可视化展示了模型结构,并分享了训练后的准确率。最后,作者对模型的优化方案进行了探讨,并展望了图像分类在人物-环境分割和人脸识别等领域的潜在应用。
摘要由CSDN通过智能技术生成

1.问题描述


原文地址:基于CNN的食物分类问题


通过编写CNN进行图片分类,并分辨出食物的类别。

1.1 数据描述

将一系列图片按照[类别_编号]方式进行文件重命名。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
(img-9XZ8wQVG-1597155742646)(https://imgkr2.cn-bj.ufileos.com/57641e91-f43b-4196-86ed-5efbd98b91cf.png?UCloudPublicKey=TOKEN_8d8b72be-579a-4e83-bfd0-5f6ce1546f13&Signature=MvDnKR3CFwYUB23bGQb6JWGnwsk%253D&Expires=1597234953)]
·总的数据大小:训练集大小:9866;测试集大小:3430

2.代码实现

2.1运行过程中所用到的库

# Import需要的库
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import time

from tensorboardX import SummaryWriter

2.2通过CV2读入图片

以下用一个例子展示opencv的使用:

通过创建一个名字为Image的画布,并通过cv2读入名称’0_0.jpg’的文件并展示

# 一个展示cv2读取和显示图片的例子
img = cv2.imread(os.path.join(os.path.join('./food-11', "training"), '0_0.jpg'))
cv2.namedWindow("Image")
cv2.imshow("Image", img)
cv2.waitKey (0)

结果:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UIO6NMKT-1597155742647)(https://imgkr2.cn-bj.ufileos.com/43df7df7-86d5-4542-9839-591d371a0a51.png?UCloudPublicKey=TOKEN_8d8b72be-579a-4e83-bfd0-5f6ce1546f13&Signature=u5F68TtiMdDWH0UCx40jYjMPuz0%253D&Expires=1597235548)]

2.2.1 定义图片读入函数

def readfile(path, label):
    # label 是一個 boolean variable,代表需不需要回傳 y 值
    image_dir = sorted(os.listdir(path))
    x = np.zeros((len(image_dir), 128, 128, 3), dtype=np.uint8)
    y = np.zeros((len(image_dir)), dtype=np.uint8)
    for i, file in enumerate(image_dir):
        img = cv2.imread(os.path.join(path, file))
        x[i, :, :] =( cv2.resize(img,(128, 128)))[: , : , : : -1]
        if label:
          y[i] = int(file.split("_")[0])
    if label:
      return x, y
    else:
      return x

代码说明:

1.为减少资源的使用,通过cv2把原图片以[128 * 128 * 3]读入。

2.该函数将input(图片)和ouput(lable)分离

3.注意cv2读入图片是按照[BGR]读入,通过img[: , : , : : -1]可以转换成[RGB]。

2.2.2 训练集和测试集的读入

workspace_dir = './food-11'
print("Reading data")
train_x, train_y = readfile(os.path.join(workspace_dir, "training"), True)
print("Size of training data = {}".format(len(train_x)))
val_x, val_y = readfile(os.path.join(workspace_dir, "validation"), True)
print("Size of test data = {}".format(len(val_x)))

结果:

Reading data
Size of training data = 9866
Size of test data = 3430

2.3 Dataset

2.3.1 ImgDataset类

train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Ra
  • 4
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

不想当韭菜啊!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值