制作数据集
把数据封装成数据集,先自定义一个继承Dataset的class,该class必须重写__getitem__和__len__方法。
from torch.utils.data import Dataset, DataLoader
class Data(Dataset):
    """Map-style dataset backed by a plain-text annotation file.

    Every non-blank line of the file is expected to contain at least two
    whitespace-separated fields: the sample ``x`` first and its label ``y``
    second. Both are returned as strings unless a transform converts them,
    so loading/resizing/normalising an image belongs in the callables passed
    as ``transform`` / ``target_transform``.

    Args:
        data_path: Path of the annotation file to read.
        transform: Optional callable applied to ``x`` before it is returned.
        target_transform: Optional callable applied to ``y`` before it is
            returned.
    """

    def __init__(self, data_path, transform=None, target_transform=None):
        super().__init__()
        self.data_path = data_path
        # Keep the callables so __getitem__ can apply them per sample.
        self.transform = transform
        self.target_transform = target_transform
        with open(self.data_path, 'r') as f:
            # Blank lines carry no record; dropping them keeps every index
            # in range(len(self)) parsable.
            self.lines = [line for line in f if line.strip()]

    def __getitem__(self, index: int):
        """Return the ``(x, y)`` pair stored on line ``index``.

        Args:
            index: Position of the record among the non-blank lines.

        Returns:
            A tuple ``(x, y)`` after the optional transforms were applied.

        Raises:
            ValueError: If the line has fewer than two fields.
        """
        fields = self.lines[index].strip().split()
        if len(fields) < 2:
            raise ValueError(
                f"line {index} of {self.data_path!r} must contain a sample "
                f"and a label, got {fields!r}"
            )
        x, y = fields[0], fields[1]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y

    def __len__(self):
        """Return the number of records (non-blank lines) in the file."""
        return len(self.lines)
# Build the dataset from the annotation file; no preprocessing is requested.
# (The keyword is `transform`; the misspelt `transfrom` raised TypeError.)
dataset = Data('path/to/file', transform=None)
实例化得到dataset数据集之后,使用DataLoader将数据集按batch组织起来:
# Serve the dataset in shuffled mini-batches of two samples; a trailing
# incomplete batch is discarded.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
)
图像增强
在制作dataset时,可以使用transform参数对数据进行一些预处理:把训练集和验证集各自的一系列图像操作用transforms.Compose组合起来并存入一个dict,然后把其中对应的一项(例如data_transforms["train"])作为transform参数传给自定义Dataset类即可。
# Preprocessing pipelines keyed by split name ("train" / "valid").
data_transforms = dict(
    # Training pipeline: resize, random augmentation, then tensor conversion
    # and per-channel normalisation.
    train=transforms.Compose([
        transforms.Resize(size=224),
        transforms.RandomRotation(degrees=45),    # random rotation
        transforms.RandomHorizontalFlip(p=0.5),   # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),     # random vertical flip
        transforms.RandomGrayscale(p=0.025),      # randomly convert to grayscale
        transforms.ToTensor(),                    # convert to a tensor
        # Image mean and std; their length must match the number of
        # image channels.
        transforms.Normalize(
            mean=[0.484, 0.456, 0.406],
            std=[0.229, 0.224, 0.255],
        ),
    ]),
    # Validation pipeline: tensor conversion and normalisation only.
    valid=transforms.Compose([
        transforms.ToTensor(),
        # Mean and std of the validation set.
        transforms.Normalize(
            mean=[0.433, 0.436, 0.436],
            std=[0.223, 0.244, 0.225],
        ),
    ]),
)
计算图像mean和std的函数:
def mean_std(dataset, data_loader):
    """Compute per-channel mean and std of the images served by a loader.

    For every image the per-channel mean and (unbiased) std over its pixels
    are computed; these per-image statistics are then averaged over all
    images the loader yields. The result therefore does not depend on the
    batch size, and samples skipped by the loader (e.g. ``drop_last=True``)
    do not skew the average. Note that the returned std is the average of
    per-image stds, not the pooled std over all pixels of the dataset.

    Args:
        dataset: The underlying dataset; only its length is used, for the
            informational message printed at the start.
        data_loader: Iterable yielding ``(X, _)`` pairs where ``X`` is a
            floating-point batch indexed as (batch, channel, height, width).

    Returns:
        A tuple ``(mean, std)`` of lists with one value per channel.

    Raises:
        ValueError: If the loader yields no images.
    """
    print('Number of images are', len(dataset))
    mean_sum = None
    std_sum = None
    num_images = 0
    for X, _ in data_loader:
        # (b, c, h, w) -> (b, c, h*w) so statistics can be taken per image
        # and per channel in one call.
        pixels = X.detach().flatten(2)
        batch_mean = pixels.mean(dim=2).sum(dim=0).cpu()
        batch_std = pixels.std(dim=2).sum(dim=0).cpu()
        if mean_sum is None:
            # The channel count comes from the data instead of assuming 3.
            mean_sum = torch.zeros(X.shape[1])
            std_sum = torch.zeros(X.shape[1])
        mean_sum += batch_mean
        std_sum += batch_std
        num_images += X.shape[0]
    if num_images == 0:
        raise ValueError("data_loader yielded no images")
    # Divide by the number of images actually seen, not len(dataset): the
    # sums above hold one contribution per image.
    mean_sum.div_(num_images)
    std_sum.div_(num_images)
    return list(mean_sum.numpy()), list(std_sum.numpy())