1 导入库
import torch
import torchvision.transforms as transforms
import os
import numpy as np
from skimage import io, transform
from torch.utils.data import Dataset,DataLoader
2 数据增强【调用numpy的函数,数据类型是np.ndarray】
1)水平随机翻转
class RandomHorizontalFlip:
    """Randomly mirror a NumPy image left-to-right.

    ``np.fliplr`` reverses axis 1, which is the width axis for an ``H x W`` or
    ``H x W x C`` array (assumed image layout -- confirm against your loader).

    Args:
        p (float): probability of the image being flipped. Default value is 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image):
        """Flip ``image`` horizontally with probability ``self.p``.

        Args:
            image (np.ndarray): image to be flipped; needs at least 2 dimensions
                (``np.fliplr`` raises ``ValueError`` otherwise).

        Returns:
            np.ndarray: a flipped, C-contiguous copy of ``image`` when the flip is
            drawn; otherwise ``image`` itself, unchanged and not copied.
        """
        # Draw from torch's RNG so torch.manual_seed() controls the augmentation.
        if torch.rand(1) < self.p:
            # np.fliplr returns a view with a negative stride. torch.from_numpy
            # (and hence transforms.ToTensor) rejects such arrays, so
            # materialise a contiguous copy here.
            return np.fliplr(image).copy()
        return image
2)垂直随机翻转
class RandomVerticleFlip:
    """Randomly mirror a NumPy image top-to-bottom.

    The class name keeps its original spelling ("Verticle") so existing callers
    keep working. ``np.flipud`` reverses axis 0, which is the height axis for an
    ``H x W`` or ``H x W x C`` array (assumed image layout -- confirm against
    your loader).

    Args:
        p (float): probability of the image being flipped. Default value is 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image):
        """Flip ``image`` vertically with probability ``self.p``.

        Args:
            image (np.ndarray): image to be flipped; needs at least 1 dimension.

        Returns:
            np.ndarray: a flipped, C-contiguous copy of ``image`` when the flip is
            drawn; otherwise ``image`` itself, unchanged and not copied.
        """
        # Draw from torch's RNG so torch.manual_seed() controls the augmentation.
        if torch.rand(1) < self.p:
            # np.flipud (NumPy has no ``flipup``) returns a negative-stride view,
            # which torch.from_numpy / transforms.ToTensor reject; copy it.
            return np.flipud(image).copy()
        return image
3)改变尺寸,缩放
class Rescale(object):#1)
"use this class before 'To tensor'"
'''
size:接受一个元组(a,b)
factor:int或者float,<1表示缩小,>1表示放大
'''
#将其短边统一变成600
def __init__(