本笔记主要学习的是《深度学习框架PyTorch:入门与实践》
本节的笔记是学习 文件组织架构模型定义
数据处理和加载
训练模型(Train&Validate)
训练过程的可视化
测试(Test/Inference)
另外程序还应该满足以下几个要求:模型需具有高度可配置性,便于修改参数、修改模型,反复实验
代码应具有良好的组织结构,使人一目了然
代码应具有良好的说明,使其他人能够理解
首先来看程序文件的组织结构:
├── checkpoints/
├── data/
│ ├── __init__.py
│ ├── dataset.py
│ └── get_data.sh
├── models/
│ ├── __init__.py
│ ├── AlexNet.py
│ ├── BasicModule.py
│ └── ResNet34.py
└── utils/
│ ├── __init__.py
│ └── visualize.py
├── config.py
├── main.py
├── requirements.txt
├── README.mdcheckpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
data/:数据相关操作,包括数据预处理、dataset实现等
models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
requirements.txt:程序依赖的第三方库
README.md:提供程序的必要说明
加载数据
Dataset提供数据集的封装,再使用Dataloader实现数据并行加载。写一个class继承Dataset。对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。将文件读取等费时操作放在__getitem__函数中,利用多进程加速。避免一次性将所有图片都读进内存,不仅费时也会占用较大内存,而且不易进行数据增强等操作。
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""目标:获取所有图片地址,并根据训练、验证、测试划分数据"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
# 划分训练、验证集,验证:训练 = 3:7
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7*imgs_num)]
else :
self.imgs = imgs[int(0.7*imgs_num):]
if transforms is None:
# 数据转换操作,测试验证和训练的数据转换有所区别
normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
# 测试集和验证集
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
# 训练集
else :
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""返回一张图片的数据对于测试集,没有label,返回图片id,如1000.jpg返回1000"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)