# Preprocessing and data augmentation with torchvision.transforms
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
def load_dataset(batch_size, train_dir, vali_dir):
    """Build train/validation DataLoaders over ImageFolder directory trees.

    Args:
        batch_size: batch size for both loaders.
        train_dir: root directory of the training set (one subdir per class).
        vali_dir: root directory of the validation set (same layout).

    Returns:
        (train_loader, vali_loader, train_data, vali_data)
    """
    # Augmentations that operate on PIL images (flip, jitter) must run
    # before ToTensor; Normalize runs last, on the tensor.
    transform = transforms.Compose([
        transforms.Resize((36, 36)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
        transforms.ToTensor(),
        # ImageNet statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Validation gets no augmentation: deterministic resize + normalize only.
    val_transform = transforms.Compose([
        transforms.Resize((36, 36)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_data = torchvision.datasets.ImageFolder(train_dir, transform=transform)
    train_loader = DataLoader(train_data, batch_size, shuffle=True, num_workers=1)
    vali_data = torchvision.datasets.ImageFolder(vali_dir, transform=val_transform)
    # BUG FIX: validation must not be shuffled — evaluation order should be
    # deterministic and reproducible.
    vali_loader = DataLoader(vali_data, batch_size, shuffle=False, num_workers=1)
    return train_loader, vali_loader, train_data, vali_data
if __name__ == '__main__':
    # Smoke test: load both splits with batch size 1.
    train_path, vali_path = 'train', 'vali'
    train_loader, vali_loader, train_data, vali_data = load_dataset(
        1, train_path, vali_path)
    print(1)
# Data augmentation with albumentations — easier to extend
import math
import os

import cv2
import numpy as np
import torch
from albumentations import *
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
def creat_dataLoader(path, bs=1, pattern='train'):
    """Build a MyDataset over *path* and wrap it in a shuffling DataLoader.

    Returns the (dataset, dataloader) pair.
    """
    ds = MyDataset(path, pattern=pattern)
    loader = DataLoader(ds, batch_size=bs, num_workers=1, shuffle=True)
    return ds, loader
class MyDataset(Dataset):
    """ImageFolder-style dataset: each subdirectory of src_dir is a class.

    Attributes:
        imgs: list of image file paths.
        label: integer class id per entry of imgs.
        class2ids: mapping from class (directory) name to integer id.
        pattern: 'train' enables random augmentation in preprocessing.
    """

    def __init__(self, src_dir, pattern):
        self.imgs = []
        self.label = []
        self.pattern = pattern
        for root, dirs, files in os.walk(src_dir):
            # Sort in place: makes both the traversal order and the
            # class -> id mapping deterministic (os.walk order is
            # filesystem-dependent).
            dirs.sort()
            if root == src_dir:
                self.class2ids = {name: idx for idx, name in enumerate(dirs)}
            for fname in sorted(files):
                self.imgs.append(os.path.join(root, fname))
                # BUG FIX: original split on '\\', which only works with
                # Windows separators; basename is portable.
                self.label.append(self.class2ids[os.path.basename(root)])

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # _load_image reads from disk and applies _preprocess.
        img = _load_image(self, index)
        return torch.tensor(img).float(), torch.tensor(self.label[index]).long()
def resize(img, target_side=36):
    """Resize *img* (H, W, C ndarray) to a (target_side, target_side) square.

    Images with an aspect ratio more extreme than 2:1 are resized
    preserving aspect ratio and zero-padded symmetrically to a square;
    milder ratios are simply stretched.

    Args:
        img: BGR image array as returned by cv2.imread.
        target_side: output side length in pixels (default 36, the
            network's expected input size).
    """
    h, w, _ = img.shape
    if h / w < 0.5:
        # Very wide image: fit the width, pad top/bottom.
        ratio = w / target_side
        target_h = math.ceil(h / ratio)
        img = cv2.resize(img, (target_side, target_h))  # cv2 takes (w, h)
        top_pad = (target_side - target_h) // 2
        bottom_pad = target_side - target_h - top_pad
        # BUG FIX: the border colour must be passed as the keyword
        # `value`; positionally the tuple lands in the `dst` parameter.
        img = cv2.copyMakeBorder(img, top_pad, bottom_pad, 0, 0,
                                 cv2.BORDER_CONSTANT, value=(0, 0, 0))
    elif w / h < 0.5:
        # Very tall image: fit the height, pad left/right.
        ratio = h / target_side
        target_w = math.ceil(w / ratio)
        img = cv2.resize(img, (target_w, target_side))
        left_pad = (target_side - target_w) // 2
        right_pad = target_side - target_w - left_pad
        img = cv2.copyMakeBorder(img, 0, 0, left_pad, right_pad,
                                 cv2.BORDER_CONSTANT, value=(0, 0, 0))
    else:
        # Moderate aspect ratio: stretch directly to the square.
        img = cv2.resize(img, (target_side, target_side))
    return img
def _preprocess(img,pattern):
    """Resize, optionally augment, and normalise one image.

    Args:
        img: BGR (H, W, C) uint8 array from cv2.imread.
        pattern: 'train' applies the random augmentation pipeline;
            any other value skips augmentation.

    Returns:
        float array in CHW layout, normalised with ImageNet mean/std.

    NOTE(review): relies on a module-level `import numpy as np`.
    IAAAdditiveGaussianNoise / IAASharpen / IAAEmboss were removed in
    albumentations >= 1.0 — confirm the pinned version supports them.
    """
    img = resize(img)
    if pattern == 'train':
        img = Compose([
            HorizontalFlip(p=0.5),
            # Random additive noise.
            OneOf([
                IAAAdditiveGaussianNoise(),
                GaussNoise()],
                p=0.3),
            # Random blur.
            OneOf([
                MotionBlur(p=0.2),
                MedianBlur(blur_limit=3, p=0.15),
                Blur(blur_limit=3, p=0.15)],
                p=0.2),
            # Geometric distortion.
            OneOf([
                OpticalDistortion(p=0.3),
            ], p=0.3),
            # Contrast / sharpness variation.
            OneOf([
                CLAHE(clip_limit=2),
                IAASharpen(),
                IAAEmboss(),
                RandomBrightnessContrast(),
            ], p=0.3),
            # Colour jitter.
            HueSaturationValue(p=0.3)
        ], p=1)(image=img)['image']
    # Scale to [0, 1] then normalise with ImageNet statistics.
    img = (img/255.0-np.array([0.485, 0.456, 0.406]))/np.asarray([0.229, 0.224, 0.225])
    # HWC -> CHW for PyTorch.
    return np.transpose(img,(2,0,1))
def _load_image(self, id):
    """Read the image at index *id* from disk and preprocess it."""
    raw = cv2.imread(self.imgs[id])
    return _preprocess(raw, pattern=self.pattern)
if __name__ == '__main__':
    # Smoke test: build the dataset/loader and iterate one epoch.
    dataset, loader = creat_dataLoader('train')
    print(1)
    for batch_idx, (images, targets) in enumerate(loader):
        print(1)