文章目录
项目来源: https://tianchi.aliyun.com/competition/entrance/531795/information
task参考: github链接
本task主要是学习如何使用pytorch进行数据读取与扩充。
常见的对图像数据的读取我们可以采用pillow和opencv库来进行。
一、简单数据读取
pillow读取与保存:
im=Image.open("cat.jpg")
im.save("cat.jpg",jpg)
opencv读取:
img=cv2.imread('cat.jpg')
cv2.imwrite('cat.jpg',jpg)
二、基于pytorch的数据扩充
数据扩增的好处:
1.增加训练样本
2.缓解模型过拟合
3.提高模型的泛化能力
基于torchvision,常见的数据扩增的方法:
- 对图像中心进行剪裁:transforms.CenterCrop
- 对图像颜色的对比度,饱和度和零度进行变换:transforms.ColorJitter
- 对图像四个角和中心进行剪裁得到五分图像:transforms.FiveCrop
- 对图像进行灰度变换:transforms.Grayscale
- 对图像使用固定值进行像素填充:transforms.Pad
- 对图像进行随机仿射变换:transforms.RandomAffine
- 对图像进行随机区域剪裁:transforms.RandomCorp
- 对图像进行随机水平翻转:transforms.RndomHorizontalFlip
- 对图像进行随机旋转:transforms.RndomRotation
- 对图像进行随机垂直翻转:RandomVerticalFlip
常见的数据扩增库:
torchvision
imgaug
albumentations
三、 基于pytorch的 神经网络的基本框架
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
1.定义 dataloader
- 1)主要函数:
tv.datasets: 数据路径,是训练集还是测试集,是否需要下载,预处理
t.utils.data.Dataloader:导⼊数据,制定batch_size
tv.transform.Compose:数据预处理的操作集,如将数据转为Tensor格式,归一化等
主要需要加载训练集和测试集
- 2)一些技巧:
数据加载:如果是常⻅开源库,可以用torchvision载入;
加载完以后可视化看⼀下
2.定义网络(Net)
用一个名