Task2 数据读取与数据扩增

数据增广:
主要用来增加训练数据量,丰富数据多样性,从而达到降低训练模型的过拟合性,增强模型的范化能力。pytorch中的torchvision.transforms模块中有方便易用的图形变换功能,可以用来进行图形数据的增广,现将代码和运行效果列出:

import torch
from matplotlib import pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

img_path = '../input/train/train/000104.png'
img = Image.open(img_path).convert('RGB')
resize = transforms.Compose([transforms.Resize((130,300))])
img_r = resize(img)
plt.figure(figsize=(18,10))
plt.subplot(4,3,1)
plt.title('origin')
plt.imshow(img_r)
plt.axis('off')

tran_VF = transforms.Compose([transforms.RandomVerticalFlip(1.0)])
img_VF = tran_VF(img_r)
plt.subplot(4,3,2)
plt.title("RandomVerticalFlip")
plt.imshow(img_VF)
plt.axis('off')


tran_Rota = transforms.Compose([transforms.RandomRotation(30)])
img_Rota = tran_Rota(img_r)
plt.subplot(4,3,3)
plt.title("RandomRotation")
plt.imshow(img_Rota)
plt.axis('off')


tran_ReCrop = transforms.Compose([transforms.RandomResizedCrop(65)])
img_ReCrop = tran_ReCrop(img_r)
plt.subplot(4,3,4)
plt.title("RandomResizedCrop")
plt.imshow(img_ReCrop)
plt.axis('off')

tran_RaHorizon = transforms.Compose([transforms.RandomHorizontalFlip(1.0)])
img_RaHorizon = tran_RaHorizon(img_r)
plt.subplot(4,3,5)
plt.title("RandomHorizontalFlip")
plt.imshow(img_RaHorizon)
plt.axis('off')

tran_RaGray = transforms.Compose([transforms.RandomGrayscale(1.0)])
img_RaGray = tran_RaGray(img_r)
plt.subplot(4,3,6)
plt.title("RandomGrayscale")
plt.imshow(img_RaGray)
plt.axis('off')


tran_RaCrop = transforms.Compose([transforms.RandomCrop((64,130))])
img_RaCrop = tran_RaCrop(img_r)
plt.subplot(4,3,7)
plt.title("RandomCrop")
plt.imshow(img_RaCrop)
plt.axis('off')

tran_RaAffi = transforms.Compose([transforms.RandomAffine(degrees=0,translate=(0.0,0.0),shear=20)])
img_RaAffi = tran_RaAffi(img_r)
plt.subplot(4,3,8)
plt.title("RandomAffine")
plt.imshow(img_RaAffi)
plt.axis('off')


tran_Grayscale = transforms.Compose([transforms.Grayscale(3)])
img_Grayscale = tran_Grayscale(img_r)
plt.subplot(4,3,9)
plt.title("Grayscale")
plt.imshow(img_Grayscale)
plt.axis('off')

tran_CenCrop = transforms.Compose([transforms.CenterCrop((60,120))])
img_CenCrop = tran_CenCrop(img_r)
plt.subplot(4,3,10)
plt.title("CenterCrop")
plt.imshow(img_CenCrop)
plt.axis('off')

tran_ColoJit = transforms.Compose([transforms.ColorJitter(0.9,0.3,0.3,0.3)])
img_ColoJit = tran_ColoJit(img_r)
plt.subplot(4,3,11)
plt.title("ColorJitter")
plt.imshow(img_ColoJit)
plt.axis('off')

在这里插入图片描述数据读取:
pytorch中进行数据读取,主要涉及到torch.utils.data.Dataset和torch.utils.data.DataLoader两个类,Dataset是个抽象类,实现Dataset的子类,即可以通过关键词或索引值来获取数据样本,Dataset类中的__getitem__()方法必须重载,__len()__方法是否重载随意。
DataLoader结合Dataset,便可对数据集进行迭代访问。

展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 游动-白 设计师: 上身试试
应支付0元
点击重新获取
扫码支付

支付成功即可阅读