PyTorch数据预处理

一,数据加载

数据路径:

#coding:utf-8
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

class DogCat(data.Dataset):
    """Dataset over a flat directory of dog/cat image files.

    Every entry of ``path`` is treated as an image.  A sample is the tuple
    ``(img_path, img, label)`` where ``img`` is a tensor built from the
    decoded pixel array and ``label`` is 1 when the file name contains
    ``'dog'`` and 0 otherwise.

    Args:
        path: Directory that contains the image files.
    """

    def __init__(self, path):
        # Only paths are collected here; pixels are decoded lazily in
        # __getitem__.  os.listdir returns entries in arbitrary order, so
        # sort them to keep the index -> file mapping stable between runs.
        self.imgs_list_path = [
            os.path.join(path, name) for name in sorted(os.listdir(path))
        ]

    def __getitem__(self, index):
        """Return ``(img_path, img, label)`` for the sample at ``index``."""
        img_path = self.imgs_list_path[index]
        # dog -> 1, cat -> 0.  Use basename instead of splitting on '/' so
        # that a directory component such as "dogcat" can never be mistaken
        # for the file name when the path uses a different separator.
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        # The context manager closes the file handle; np.array copies the
        # pixels so the tensor is backed by writable memory.
        with Image.open(img_path) as pil_img:
            array = np.array(pil_img)
        img = t.from_numpy(array)
        return img_path, img, label

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.imgs_list_path)
if __name__ == '__main__':
    # Smoke test: list every sample of the dog/cat folder.
    dataset = DogCat('./data/dogcat/')
    # Indexing, e.g. dataset[0], is the same as calling dataset.__getitem__(0).
    sample_count = len(dataset)
    print('len(dataset)=', sample_count)
    for sample in dataset:
        img_path, img, label = sample
        mean_value = img.float().mean()
        print(img_path, img.size(), mean_value, label)

打印结果:

二,数据归一化 

PyTorch提供了torchvision。它是一个视觉工具包,提供了很多视觉图像处理的工具,其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。

对PIL Image的操作包括:

 

  • Resize(旧版本中名为Scale):调整图片尺寸,长宽比保持不变
  • CenterCrop、RandomCrop、RandomResizedCrop:裁剪图片
  • Pad:填充
  • ToTensor:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]
  • transforms.ColorJitter(0.3, 0.3, 0.2) 颜色抖动
  • transforms.RandomRotation(10)随机旋转

对Tensor的操作包括:

 

  • Normalize:标准化,即减均值,除以标准差
  • ToPILImage:将Tensor转为PIL Image对象
#coding:utf-8
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
from torchvision import transforms

# Preprocessing pipeline applied, in order, to each PIL image.
transform = transforms.Compose(
    [
        # Resize so that the shorter side is 224 px; aspect ratio is kept.
        transforms.Resize(224),
        # Cut the central 224x224 patch out of the resized image.
        transforms.CenterCrop(224),
        # PIL Image -> Tensor, values rescaled to [0, 1].
        transforms.ToTensor(),
        # Per channel: input[channel] = (input[channel] - mean[channel]) / std[channel]
        # With mean = std = 0.5 this maps [0, 1] onto [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

class DogCat(data.Dataset):
    """Dataset over a flat directory of dog/cat image files.

    Every entry of ``path`` is treated as an image.  A sample is the tuple
    ``(img_path, img, label)`` where ``label`` is 1 when the file name
    contains ``'dog'`` and 0 otherwise.

    Args:
        path: Directory that contains the image files.
        transforms: Optional callable applied to the opened PIL image.  If
            it returns a tensor, that tensor is returned as ``img``; any
            other result is converted through a numpy array.
    """

    def __init__(self, path, transforms=None):
        # Only paths are collected here; pixels are decoded lazily in
        # __getitem__.  os.listdir returns entries in arbitrary order, so
        # sort them to keep the index -> file mapping stable between runs.
        self.imgs_list_path = [
            os.path.join(path, name) for name in sorted(os.listdir(path))
        ]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(img_path, img, label)`` for the sample at ``index``."""
        img_path = self.imgs_list_path[index]
        # dog -> 1, cat -> 0.  Use basename instead of splitting on '/' so
        # that a directory component such as "dogcat" can never be mistaken
        # for the file name when the path uses a different separator.
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        # Everything that touches the PIL image happens inside the context
        # manager so the pixels are read before the file handle is closed.
        with Image.open(img_path) as pil_img:
            img = self.transforms(pil_img) if self.transforms else pil_img
            if not isinstance(img, t.Tensor):
                # No tensor yet (no transform, or one that returns an
                # image): copy the pixels into a writable array first.
                img = t.from_numpy(np.array(img))
        return img_path, img, label

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.imgs_list_path)
if __name__ == '__main__':
    # Smoke test: list every sample after the preprocessing pipeline ran.
    dataset = DogCat('./data/dogcat/', transforms=transform)
    # Indexing, e.g. dataset[0], is the same as calling dataset.__getitem__(0).
    sample_count = len(dataset)
    print('len(dataset)=', sample_count)
    for sample in dataset:
        img_path, img, label = sample
        mean_value = img.float().mean()
        print(img_path, img.size(), mean_value, label)

三,利用fer2013数据集进行预处理

数据集地址:https://download.csdn.net/download/fanzonghao/11183885

''' Fer2013 Dataset class'''
from __future__ import print_function
from PIL import Image
import numpy as np
import h5py
import torch.utils.data as data
import cv2
import torchvision.transforms as transforms

# Preprocessing applied to every sample before it is returned.
transform = transforms.Compose(
    [
        # PIL Image -> Tensor, values rescaled to the 0~1 range.
        transforms.ToTensor(),
        # Normalise each of the three channels with mean 0.5 and std 0.5.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
class FER2013(data.Dataset):
    """FER2013 dataset read from an HDF5 file.

    The file must contain the datasets ``<name>_pixel`` and ``<name>_label``
    for the selected split, where ``<name>`` is ``Training``, ``PublicTest``
    or ``PrivateTest``.  Pixels are reshaped to ``(N, 48, 48)``.

    Both arrays are copied into memory and the file is closed again, so the
    dataset object holds plain numpy arrays only.  They are exposed as
    ``train_data``/``train_labels``, ``PublicTest_data``/``PublicTest_labels``
    or ``PrivateTest_data``/``PrivateTest_labels`` depending on the split.

    Args:
        path (str): Path of the HDF5 file.
        split (str, optional): ``'Training'`` or ``'PublicTest'``; any other
            value selects the ``PrivateTest`` split.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version. E.g,
            ``transforms.RandomCrop``
    """

    def __init__(self, path, split='Training', transform=None):
        self.transform = transform
        self.split = split

        # Map the split onto the HDF5 key prefix and the attribute prefix.
        if split == 'Training':
            h5_prefix, attr_prefix = 'Training', 'train'
        elif split == 'PublicTest':
            h5_prefix = attr_prefix = 'PublicTest'
        else:
            h5_prefix = attr_prefix = 'PrivateTest'

        # Read everything while the file is open, then let it close: no
        # h5py handle is kept on the instance.
        with h5py.File(path, 'r', driver='core') as h5_file:
            pixels = np.asarray(h5_file[h5_prefix + '_pixel'])
            labels = np.asarray(h5_file[h5_prefix + '_label'])

        # -1 lets numpy infer the sample count instead of hard-coding it.
        setattr(self, attr_prefix + '_data', pixels.reshape((-1, 48, 48)))
        setattr(self, attr_prefix + '_labels', labels)

    def _arrays(self):
        """Return ``(images, labels)`` of the active split."""
        if self.split == 'Training':
            return self.train_data, self.train_labels
        if self.split == 'PublicTest':
            return self.PublicTest_data, self.PublicTest_labels
        return self.PrivateTest_data, self.PrivateTest_labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        images, labels = self._arrays()
        img, target = images[index], labels[index]

        # Replicate the single 48x48 plane into three identical channels and
        # wrap it in a PIL Image, so that it is consistent with all other
        # datasets and PIL-based transforms can be applied.
        img = np.stack((img, img, img), axis=2)
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(self._arrays()[0])

if __name__ == '__main__':
    # Demo: load the training split, batch it and dump one image to disk.
    train_data = FER2013(path='./data/data.h5', split='Training', transform=transform)

    train_loader = data.DataLoader(
        dataset=train_data,
        batch_size=8,
        shuffle=True,
        num_workers=2,
    )

    print(len(train_data))
    # Alternative kept for reference: iterate the dataset directly instead
    # of going through the DataLoader.
    # for i,(img,label) in enumerate(train_data):
    #     if i<1:
    #         img=np.transpose(np.array(img),(1,2,0))
    #         print(img.shape)
    #         img=(img*0.5+0.5)*255
    #         cv2.imwrite('1.jpg',img)
    #         print(label.shape)
    # NOTE(review): the loop visits every batch although only the first one
    # is used; a `break` after the write would avoid the extra passes.
    for batch_idx, (img, label) in enumerate(train_loader):
        if batch_idx >= 1:
            continue
        print('train')
        # First image of the batch, channels moved to the last axis.
        first_image = np.array(img)[0].transpose((1, 2, 0))
        # Undo Normalize(0.5, 0.5) and rescale to the 0~255 range.
        first_image = (first_image * 0.5 + 0.5) * 255
        cv2.imwrite('2.jpg', first_image)

结果:

 

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值