PyTorch dataloaders for my own use

Preprocessing and data augmentation with torchvision.transforms

from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

def load_dataset(batch_size, train_dir, vali_dir):
    # Training pipeline: resize and augment on the PIL image, then convert to a
    # tensor and normalize with the ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize((36, 36)),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Validation pipeline: deterministic, no augmentation.
    val_transform = transforms.Compose([
        transforms.Resize((36, 36)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # ImageFolder expects one sub-directory per class.
    train_data = torchvision.datasets.ImageFolder(train_dir, transform=transform)
    train_loader = DataLoader(train_data, batch_size, shuffle=True, num_workers=1)

    vali_data = torchvision.datasets.ImageFolder(vali_dir, transform=val_transform)
    vali_loader = DataLoader(vali_data, batch_size, shuffle=True, num_workers=1)

    return train_loader, vali_loader, train_data, vali_data
if __name__ == '__main__':
    dir1 = 'train'
    dir2 = 'vali'
    train_loader, vali_loader, train_data, vali_data = load_dataset(1, dir1, dir2)
    print(len(train_data), len(vali_data))
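
A quick way to check that load_dataset is wired up correctly is to pull one batch and look at the shapes. A minimal sketch, assuming the same 'train' / 'vali' folder layout as above (one sub-directory per class):

# Sanity check: pull one batch from the training loader and inspect shapes.
# Assumes 'train' and 'vali' each contain one sub-directory per class.
train_loader, vali_loader, train_data, vali_data = load_dataset(8, 'train', 'vali')
print('classes:', train_data.classes)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # with Resize((36, 36)): [8, 3, 36, 36] and [8]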

Data augmentation with albumentations, which is easier to extend (a short sketch of that is shown after the script)

import math
import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# Note: the IAA* transforms used below come from older albumentations releases
# (they were removed around version 1.0; GaussNoise, Sharpen and Emboss are the
# modern replacements).
from albumentations import *

def create_dataloader(path, bs=1, pattern='train'):
    dataset = MyDataset(path, pattern=pattern)
    dataloader = DataLoader(dataset, batch_size=bs, num_workers=1, shuffle=True)
    return dataset, dataloader


class MyDataset(Dataset):
    def __init__(self, src_dir, pattern):
        self.imgs = []
        self.label = []
        self.pattern = pattern
        # Expect one sub-directory per class; map folder names to integer ids
        # (sorted so the mapping is deterministic across filesystems).
        for roots, dirs, files in os.walk(src_dir):
            if roots == src_dir:
                self.class2ids = {d: i for i, d in enumerate(sorted(dirs))}
            for file in files:
                file_path = os.path.join(roots, file)
                self.imgs.append(file_path)
                # The class id comes from the name of the folder the file sits in.
                self.label.append(self.class2ids[os.path.basename(roots)])

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, id):
        img = _load_image(self, id)
        return torch.tensor(img).float(), torch.tensor(self.label[id]).long()


def resize(img):
    # Resize to a 36x36 square; strongly elongated images are letterboxed with
    # black borders instead of being stretched.
    h, w, _ = img.shape
    target_side = 36
    if h / w < 0.5:
        # Wide image: fit the width, pad top/bottom.
        ratio = w / target_side
        target_h = math.ceil(h / ratio)
        img = cv2.resize(img, (target_side, target_h))
        top_pad_side = (target_side - target_h) // 2
        bottom_pad_side = target_side - target_h - top_pad_side
        img = cv2.copyMakeBorder(img, top_pad_side, bottom_pad_side, 0, 0,
                                 cv2.BORDER_CONSTANT, value=(0, 0, 0))
    elif w / h < 0.5:
        # Tall image: fit the height, pad left/right.
        ratio = h / target_side
        target_w = math.ceil(w / ratio)
        img = cv2.resize(img, (target_w, target_side))
        left_pad_side = (target_side - target_w) // 2
        right_pad_side = target_side - target_w - left_pad_side
        img = cv2.copyMakeBorder(img, 0, 0, left_pad_side, right_pad_side,
                                 cv2.BORDER_CONSTANT, value=(0, 0, 0))
    else:
        # Moderate aspect ratio: stretch directly to the target square.
        img = cv2.resize(img, (target_side, target_side))
    return img


def _preprocess(img, pattern):
    img = resize(img)
    if pattern == 'train':
        # Random augmentation, applied only to the training split.
        img = Compose([
            HorizontalFlip(p=0.5),
            # Additive noise.
            OneOf([
                IAAAdditiveGaussianNoise(),
                GaussNoise()],
                p=0.3),
            # Blur.
            OneOf([
                MotionBlur(p=0.2),
                MedianBlur(blur_limit=3, p=0.15),
                Blur(blur_limit=3, p=0.15)],
                p=0.2),
            # Geometric distortion.
            OneOf([
                OpticalDistortion(p=0.3),
            ], p=0.3),
            # Contrast / sharpness.
            OneOf([
                CLAHE(clip_limit=2),
                IAASharpen(),
                IAAEmboss(),
                RandomBrightnessContrast(),
            ], p=0.3),
            # Color jitter.
            HueSaturationValue(p=0.3)
        ], p=1)(image=img)['image']
    # Scale to [0, 1] and normalize with the ImageNet statistics (note that
    # cv2.imread returns BGR while these are the usual RGB statistics).
    img = (img / 255.0 - np.array([0.485, 0.456, 0.406])) / np.asarray([0.229, 0.224, 0.225])
    # HWC -> CHW, as expected by PyTorch.
    return np.transpose(img, (2, 0, 1))


def _load_image(self, id):
    # Plain helper (called with the dataset instance) so the image loading and
    # augmentation code stays outside the Dataset class.
    img = cv2.imread(self.imgs[id])
    img = _preprocess(img, pattern=self.pattern)
    return img


if __name__ == '__main__':
    data_set, data_loader = create_dataloader('train')
    print(len(data_set), data_set.class2ids)
    for i, (imgs, labels) in enumerate(data_loader):
        print(imgs.shape, labels.shape)
        break
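
What makes the albumentations version easier to extend is that the pipeline is just a Python list handed to Compose. The sketch below illustrates one way to exploit that; build_train_augmentations and the extra parameter are hypothetical names, not part of the script above, and only transforms available in current albumentations releases are used.

# Illustrative sketch: building the training augmentation from a plain list
# makes it easy to append new albumentations transforms from the caller.
from albumentations import Compose, HorizontalFlip, GaussNoise, HueSaturationValue, ShiftScaleRotate

def build_train_augmentations(extra=None):
    base = [
        HorizontalFlip(p=0.5),
        GaussNoise(p=0.3),
        HueSaturationValue(p=0.3),
    ]
    if extra:
        base.extend(extra)          # drop in new transforms without editing _preprocess
    return Compose(base, p=1)

# Example: add a random shift/scale/rotate on top of the base pipeline.
augment = build_train_augmentations(extra=[ShiftScaleRotate(p=0.3)])
# augmented = augment(image=img)['image']   # img is an HWC uint8 numpy array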