数据预处理

制作数据集

把数据封装成数据集,先自定义一个继承Dataset的class,该class必须重写__getitem__和__len__方法。

from torch.utils.data import Dataset, DataLoader
class Data(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        with open(self.data_path, 'r') as f:
            self.lines = f.readlines()

    # 重写
    def __getitem__(self, index: int):
        """
        根据索引取出数据
        :param index: 根据索引取出数据
        :return:
        """
        # 分割出x, y
        x = self.lines[index].strip().split()
        y = self.lines[index].strip().split()
		
		if self.transform is not None:
       	        trans_img = transforms.Compose([
	                transforms.Resize(size=(self.img_w, self.img_h)),
	                transforms.ToTensor(),  # 0-1
	                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -1---1
	            ])
	            img = trans_img(img)
	 			
	            trans_label = transforms.Compose([
	                transforms.Resize(size=(self.img_w, self.img_h)),
	                transforms.ToTensor(),
	            ])
            	label = trans_label(label)
        return x, y
    
    # 重写
    def __len__(self):
        # 返回数据长度
        len(self.lines)

dataset = Data('path/to/file', transfrom=None)

实例化得到dataset数据集之后使用DataLoader将数据集设置为batch

dataloader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)

图像增强

在制作dataset时,可以使用transform参数对数据进行一些预处理,把一系列图像的操作封装成一个dict然后作为参数传给自定义Dataset类即可。

data_transforms = {
    "train":
        transforms.Compose([
            transforms.Resize(224),
            transforms.RandomRotation(45),  # 随机旋转
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平旋转
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomGrayscale(p=0.025),  # 随机转换成灰度率
            transforms.ToTensor(),  # 转换成tensor
            transforms.Normalize(mean=[0.484, 0.456, 0.406], std=[0.229, 0.224, 0.255]),  # 图像的mean和std  长度要与图片channel一致
        ]),
    "valid": transforms.Compose([
        transforms.ToTensor(),
        # 验证集的mean和std
        transforms.Normalize(mean=[0.433, 0.436, 0.436], std=[0.223, 0.244, 0.225])
    ])
}

计算图形mean和std的函数:

def mean_std(dataset, data_loader):
    '''
    计算数据集图像的mean和std
    image_data: dataset
    '''
    print('Number of images are', len(dataset))
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in data_loader:
        for i in range(3):
            # b c h w
            mean[i] += X[:, i, :, :].mean()
            std[i] += X[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return list(mean.numpy()), list(std.numpy())
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值