HW3--CNN
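If your graphics card is from NVIDIA, I suggest searching for how to install the GPU build of PyTorch; if it is not, you can refer to this post. I don't use PyTorch much at the moment and have only worked through the basics.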
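A quick way to see whether the GPU build is actually usable is to ask PyTorch directly. The following is a minimal sketch; `pick_device` is a hypothetical helper name, and it falls back to the CPU when PyTorch is missing or no CUDA device is visible.

```python
def pick_device():
    """Return "cuda" if PyTorch can see a CUDA GPU, otherwise "cpu"."""
    try:
        import torch  # only needed for the availability check
    except ImportError:
        # PyTorch is not installed at all, so there is nothing to run on a GPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print("Using device:", pick_device())
```

If this prints `cpu` on a machine with an NVIDIA card, the installed PyTorch is most likely the CPU-only build.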
Assignment requirements
Write a CNN that classifies images of food.
The provided data consists entirely of food photos in 11 classes: Bread, Dairy product, Dessert, Egg, Fried food, Meat, Noodles/Pasta, Rice, Seafood, Soup, and Vegetable/Fruit.
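When inspecting predictions it helps to turn a numeric label back into a class name. The sketch below assumes the label ids 0 to 10 follow the order in which the classes are listed above; that ordering is an assumption, so check it against the dataset before relying on it. `FOOD11_CLASSES` and `class_name` are hypothetical names.

```python
# Class names in the order listed in the assignment description.
# Assumption: label id i corresponds to the i-th name in this list.
FOOD11_CLASSES = [
    "Bread", "Dairy product", "Dessert", "Egg", "Fried food", "Meat",
    "Noodles/Pasta", "Rice", "Seafood", "Soup", "Vegetable/Fruit",
]


def class_name(label_id):
    """Map an integer label id to its class name."""
    if not 0 <= label_id < len(FOOD11_CLASSES):
        raise ValueError(f"label id must be in 0..{len(FOOD11_CLASSES) - 1}")
    return FOOD11_CLASSES[label_id]


if __name__ == "__main__":
    for i in range(len(FOOD11_CLASSES)):
        print(i, class_name(i))
```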
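Code
This code was released publicly by Professor Hung-yi Lee; I have only added comments that reflect my own understanding.
Load the dataset (needed on Colab)
I don't have a local copy of the dataset. For one thing, the archive is over 900 MB and my download failed after a few minutes; for another, I'm lazy.
Fortunately several mirrors are provided. The one shown in the screenshot crashed when I ran it, and the download only succeeded after I switched to another.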
# Download the dataset
# You may choose where to download the data.
# Google Drive
!gdown --id '1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy' --output food-11.zip
# Dropbox
# !wget https://www.dropbox.com/s/m9q6273jl3djall/food-11.zip -O food-11.zip
# MEGA
# !sudo apt install megatools
# !megadl "https://mega.nz/#!zt1TTIhK!ZuMbg5ZjGWzWX1I6nEUbfjMZgCmAgeqJlwDkqdIryfg"
# Unzip the dataset.
# This may take some time.
!unzip -q food-11.zip
When the download and extraction succeed, the output looks something like this
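Beyond reading the notebook output, one way to confirm the archive unpacked properly is to count the images in each class folder. This is a minimal sketch that assumes the layout `food-11/training/labeled/<class folder>/*.jpg` implied by the paths used later in the code; `count_images` is a hypothetical helper, not part of the course code.

```python
from collections import Counter
from pathlib import Path


def count_images(root, extension=".jpg"):
    """Return {class_folder_name: number_of_image_files} under `root`."""
    counts = Counter()
    for path in Path(root).rglob("*"):
        # Compare suffixes case-insensitively so "x.JPG" is counted too.
        if path.is_file() and path.suffix.lower() == extension:
            counts[path.parent.name] += 1
    return dict(counts)


if __name__ == "__main__":
    split = Path("food-11/training/labeled")
    if split.is_dir():
        for name, n in sorted(count_images(split).items()):
            print(name, n)
    else:
        print("dataset not found at", split)
```

An empty result or a missing folder usually means the download was cut off and the archive should be fetched again from another mirror.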
import packages
# Import necessary packages.
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder
# This is for the progress bar.
from tqdm.auto import tqdm
# The progress bar shown while this runs at least offers some reassurance.
Process the data
# It is important to do data augmentation in training.
# However, not every augmentation is useful.
# Please think about what kind of augmentation is helpful for food recognition.
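'''
This part is data augmentation, but you have to choose and write the specific methods yourself.
torchvision.transforms is PyTorch's image preprocessing package. Compose is normally used to chain several steps together.
'''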
train_tfm = transforms.Compose([
# Resize the image into a fixed shape (height = width = 128)
transforms.Resize((128, 128)),
# You may add some transforms here.
# ToTensor() should be the last one of the transforms.
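# Optional augmentations (commented out here; uncomment to enable):
# transforms.RandomHorizontalFlip(),  # randomly flip the image horizontally
# transforms.RandomRotation(15),      # randomly rotate the image by up to 15 degrees
#
# Alternatively, flip with OpenCV inside a custom Dataset class
# (needs `import cv2` and `import random`). Note that this snippet also flips
# `label`, which only makes sense when the label is itself an image:
# def augment(self, image, flipCode):
#     # cv2.flip: flipCode 1 = horizontal, 0 = vertical, -1 = both
#     flip = cv2.flip(image, flipCode)
#     return flip
# # Pick a flip at random; 2 means leave the image unchanged.
# flipCode = random.choice([-1, 0, 1, 2])
# if flipCode != 2:
#     image = self.augment(image, flipCode)
#     label = self.augment(label, flipCode)
# return image, label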
transforms.ToTensor(),
])
# We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
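# ToTensor converts a PIL image (H x W x C, values 0-255) into a float tensor
# (C x H x W, values in [0.0, 1.0]). It does not normalize by mean and standard
# deviation; that is the job of transforms.Normalize.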
])
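To make the difference between `ToTensor` and `Normalize` concrete, here is a plain-Python imitation of what each one computes on a tiny image. This is only an illustrative sketch using nested lists, not torchvision's implementation; `to_tensor_like` and `normalize_like` are hypothetical names.

```python
def to_tensor_like(image_hwc):
    """Imitate ToTensor: [H][W][C] ints in 0..255 -> [C][H][W] floats in 0..1."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


def normalize_like(image_chw, mean, std):
    """Imitate Normalize: subtract the per-channel mean, divide by the per-channel std."""
    return [
        [[(value - mean[c]) / std[c] for value in row] for row in plane]
        for c, plane in enumerate(image_chw)
    ]


if __name__ == "__main__":
    # A 1 x 2 image with 3 channels (R, G, B).
    image = [[[0, 255, 51], [255, 0, 102]]]
    scaled = to_tensor_like(image)
    print(scaled)  # values now lie in [0, 1], with channels first
    print(normalize_like(scaled, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
```

Scaling to the range 0 to 1 and moving the channel axis first is all that the first step does; centering on a mean is a separate, optional step that the transforms above do not include.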
# Batch size for training, validation, and testing.
# A greater batch size usually gives a more stable gradient.
# But the GPU memory is limited, so please adjust it carefully.
batch_size = 128
# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder(