pytorch 记录(一)
pytorch加载数据集
利用官方通道,下载cifar10数据集
预处理
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])
对数据进行预处理,相关参数如下:
torchvision.transforms.Scale((227,227))
用来将不同大小的图片resize到统一尺寸。
torchvision.transforms.CenterCrop(size)
将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,在这种情况下,切出来的图片的形状是正方形。
torchvision.transforms.RandomCrop(size, padding=0)
切割中心点的位置随机选取。size可以是tuple也可以是Integer。
torchvision.transforms.RandomHorizontalFlip
随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
torchvision.transforms.RandomSizedCrop(size, interpolation=2)
先将给定的PIL.Image随机切,然后再resize成给定的size大小。
torchvision.transforms.ToTensor
是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 从0到255的值映射到0到1的范围内,并转化成Tensor格式。
torchvision.transforms.Pad(padding, fill=0)
将给定的PIL.Image的所有边用给定的pad value填充。 padding:要填充多少像素 fill:用什么值填充
torchvision.transforms.Normalize(mean,std)
是通过下面公式实现数据归一化
channel=(channel-mean)/std
假设transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) ,transforms.ToTensor()已经把数据处理成[0,1],那么(x-0.5)/0.5就是[-1.0, 1.0]
数据集读取
官方提供的数据集
train_set = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True, +transform=test_transform)
设置加载数据集,为dataset类,其中
root表示数据的加载的相对目录;
train,表示是否加载数据库的训练集,false的时候加载测试集;
download,表示是否自动下载数据集;
transform,表示是否需要对数据进行预处理,none为不进行预处理,transform由自己定义
个人的数据集
需要定义一个继承Dataset的类
首先定义init函数读取图像,其次两个父类函数必须重新加载:
len返回数据集的大小 ;getitem是一个返回函数,返回元组的长度可以是任意长,可以是数据集的索引或者图像。
参考网上诸多代码,由于不同的数据集格式init方法不同,最基本的设计的如下:
(待完善)
class MyDataset(data.Dataset): #子类
def __init__(self,path_file):
self.data_path =path_file
def __getitem__(self, item):
ele = self.data_path[item]
return ele
def __len__(self):
return len(self.data_path)
如果不同类别的数据集在不同的文件夹,可以使用ImageFolder进行加载:
def my_loader(path, batch_size, num_workers, pin_memory=True):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return data.DataLoader(
datasets.ImageFolder(path,
transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory)
数据加载
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)
使用DataLoader这个类来更加快捷的对数据进行操作
dataset表示数据包;
batch_size表示每个batch的大小,默认为1;
shuffle表示是否进行shuffle操作,是否打乱数据,默认为False;
num_workers表示加载数据的时候使用几个子进程,默认为0;
sampler表示定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False
drop_last:True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃

本文详细介绍了如何在PyTorch中加载并预处理CIFAR10数据集,包括利用各种图像转换方法如Resize、RandomCrop、RandomHorizontalFlip等进行数据增强,以及如何创建自定义数据集类。
&spm=1001.2101.3001.5002&articleId=86320947&d=1&t=3&u=41cd7cf466354ddf94537f4af07df7e7)
249

被折叠的 条评论
为什么被折叠?



