[10] Deep Learning with PyTorch - transforms Image Operations and Custom Methods (with a Data Augmentation Walkthrough)

0. Previous Posts

[1] Deep Learning with PyTorch - Tensor Definition and Tensor Creation

[2] Deep Learning with PyTorch - Tensor Operations: Concatenation, Splitting, Indexing and Reshaping

[3] Deep Learning with PyTorch - Tensor Math Operations

[4] Deep Learning with PyTorch - Linear Regression

[5] Deep Learning with PyTorch - Computational Graphs and the Dynamic Graph Mechanism

[6] Deep Learning with PyTorch - autograd and Logistic Regression

[7] Deep Learning with PyTorch - DataLoader and Dataset (with an RMB Binary Classification Walkthrough)

[8] Deep Learning with PyTorch - Image Preprocessing with transforms

[9] Deep Learning with PyTorch - transforms Image Augmentation (Cropping, Flipping, Rotation)

[10] Deep Learning with PyTorch - transforms Image Operations and Custom Methods


1. transforms - Image Transformations

1.1 transforms.Pad(padding, fill=0, padding_mode='constant')

transforms.Pad(padding, fill=0, padding_mode='constant')

(1) Function: pads the edges of an image.
(2) Parameters
padding: padding size:
I. padding=a pads a pixels on the left, right, top and bottom;
II. padding=(a, b) pads a pixels on the left and right, and b pixels on the top and bottom;
III. padding=(a, b, c, d) pads the left, top, right and bottom with a, b, c and d pixels respectively.
padding_mode: padding mode, one of four options:
I. constant: the padded pixels take the value given by fill;
II. edge: the padded pixels take the value of the nearest edge pixel;
III. reflect: mirror padding in which the edge pixel itself is not repeated, e.g. [1,2,3,4] --> [3,2,1,2,3,4,3,2];
To the left: 1 is not mirrored, so 2 and 3 are reflected on the left.
To the right: 4 is not mirrored, so 3 and 2 are reflected on the right.

IV. symmetric: mirror padding in which the edge pixel is repeated, e.g. [1,2,3,4] --> [2,1,1,2,3,4,4,3];
To the left: 1 and 2 are reflected.
To the right: 4 and 3 are reflected.

fill: the pixel value used for padding when padding_mode='constant', given as (R, G, B) or (Gray).

(3) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 1 Pad
    transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
    transforms.Pad(padding=(8, 64), fill=(255, 0, 0), padding_mode='constant'),
    transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='constant'),
    transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='symmetric'),

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

1.2 transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

(1) Function: randomly adjusts the brightness, contrast, saturation and hue of an image.
(2) Parameters
brightness: brightness jitter factor; brightness > 1 brightens the image, brightness < 1 darkens it.
I. brightness=a samples a factor uniformly from [max(0, 1-a), 1+a];
II. brightness=(a, b) samples a factor uniformly from [a, b].
contrast: contrast jitter factor, same form as brightness; the lower the contrast, the grayer the image looks.
saturation: saturation jitter factor, same form as brightness; the lower the saturation, the duller the image looks.
hue: hue jitter factor.
I. hue=a samples a value uniformly from [-a, a], where a must satisfy 0 ≤ a ≤ 0.5;
II. hue=(a, b) samples a value uniformly from [a, b], where -0.5 ≤ a ≤ b ≤ 0.5.
(3) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 2 ColorJitter
    transforms.ColorJitter(brightness=0.5),
    transforms.ColorJitter(contrast=0.5),
    transforms.ColorJitter(saturation=0.5),
    transforms.ColorJitter(hue=0.3),

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

1.3 transforms.Grayscale(num_output_channels)

transforms.Grayscale(num_output_channels)

(1) Function: converts an image to grayscale.
(2) Parameters
num_output_channels: number of output channels; it can only be 1 or 3.
(3) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 3 Grayscale
    transforms.Grayscale(num_output_channels=3),

    transforms.ToTensor(), 
    transforms.Normalize(norm_mean, norm_std),
])

1.4 transforms.RandomGrayscale(p=0.1)

transforms.RandomGrayscale(p=0.1)

(1) Function: converts an image to grayscale with a given probability;
(2) Parameters
p: the probability of converting the image to grayscale. The grayscale output keeps the input's number of channels (a 3-channel input yields three identical channels), so no num_output_channels argument is needed here.
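(3) Code example (a minimal sketch, not from the original post; it reuses norm_mean and norm_std from the snippets above, and p=0.5 is only illustrative):

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # RandomGrayscale: convert the image to grayscale with probability p
    transforms.RandomGrayscale(p=0.5),

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])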

1.5 transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)

transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)

(1) Function: applies a random affine transform to the image. An affine transform is a 2D linear transform built from five elementary operations: rotation, translation, scaling, shear and flipping.
(2) Parameters
degrees: rotation angle range. The rotation is about the image center. degrees must always be given; set degrees=0 if no rotation is wanted.
I. degrees=a samples the angle uniformly from (-a, a);
II. degrees=(a, b) samples the angle uniformly from (a, b).
translate: maximum translation fractions (a, b), where a applies to the width and b to the height; the horizontal shift dx is sampled from -img_width * a < dx < img_width * a and the vertical shift dy from -img_height * b < dy < img_height * b.
scale: scaling factor range, e.g. (a, b); a scale factor is sampled uniformly from this interval.
fillcolor: fill color for the area left uncovered by the transformed image.
shear: shear angle range, with horizontal (x-axis) and vertical (y-axis) shear;
I. shear=a applies only an x-axis shear, with the angle sampled from (-a, a);
II. shear=(a, b) applies only an x-axis shear, with the angle sampled from (a, b);
III. shear=(a, b, c, d) applies an x-axis shear sampled from (a, b) and a y-axis shear sampled from (c, d), which is why the y-axis-only example below uses shear=(0, 0, 0, 45).
resample: resampling method, one of NEAREST, BILINEAR, BICUBIC.
(3) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 4 Affine
    transforms.RandomAffine(degrees=30),
    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), fillcolor=(255, 0, 0)), # degrees must be given; use degrees=0 if no rotation is wanted
    transforms.RandomAffine(degrees=0, scale=(0.7, 0.7)), # the uncovered area is filled with black
    transforms.RandomAffine(degrees=0, shear=(0, 0, 0, 45)), # shear along the y axis
    transforms.RandomAffine(degrees=0, shear=90, fillcolor=(255, 0, 0)),

    transforms.ToTensor(), 
    transforms.Normalize(norm_mean, norm_std),
])

(Figure: original image, x-axis shear, y-axis shear)
The shear is applied along whichever axis the sheared edges stay parallel to.

1.6 transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0)

transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)

(1) Function: randomly erases a rectangular region of the image;
(2) Parameters
p: the probability of erasing a region;
scale: range of the erased area, as a fraction of the image area;
ratio: range of the aspect ratio of the erased region;
value: pixel value used to fill the erased region, e.g. (R, G, B) or (Gray);
(3) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 5 Erasing
    # RandomErasing operates on tensors, so ToTensor() must come first
    transforms.ToTensor(),
    transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=(254/255, 0, 0)), # value=(254/255, 0, 0): the input is already a tensor, so the value must be scaled to the 0-1 range (divide by 255)
    transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='1234'), # value='1234': when value is a string, the erased region is filled with random pixel values
 
    transforms.Normalize(norm_mean, norm_std),
])

RandomErasing operates on tensors, so ToTensor() must be applied before it;
when value is a string such as '1234', the erased region is filled with random pixel values (newer torchvision releases only accept the string 'random' here).

1.7 transforms.Lambda(lambd)

transforms.Lambda(lambd)

(1) Function: applies a user-defined lambda as a transform;
(2) Parameters
lambd: an anonymous lambda function, with the following syntax:

lambda [arg1 [, arg2, ..., argn]]: expression

(3) Code example:

transforms.FiveCrop(112), # FiveCrop cannot be used on its own here (it returns a tuple of 5 crops, which the later transforms cannot handle); it must be paired with the Lambda on the next line
# everything before the lambda's colon is its input (crops); everything after the colon is its return value
transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])), # ToTensor() is applied here, so ToTensor() and Normalize are not needed afterwards

Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])) already applies ToTensor(), so ToTensor() and Normalize are not needed after it.
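Putting the two lines together, a minimal sketch of a complete Compose built around FiveCrop and Lambda could look as follows (the Resize size and the crop size 112 are carried over from the snippet above, not prescribed by the post). Each dataset item then comes out with shape (5, C, 112, 112) instead of (C, H, W), so the training loop has to fold the crop dimension into the batch dimension before the forward pass.

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # FiveCrop returns a tuple of 5 PIL crops: the four corners plus the center
    transforms.FiveCrop(112),
    # stack the 5 crops into one tensor of shape (5, C, 112, 112);
    # ToTensor() is applied inside the lambda, so no separate ToTensor()/Normalize follows
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])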

2. transforms Method Operations

2.1 transforms.RandomChoice([transforms1, transforms2, transforms3])

transforms.RandomChoice([transforms1, transforms2, transforms3])

(1) Function: randomly selects one transform from a list and applies it (exactly one is executed);
(2) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 1 RandomChoice
    transforms.RandomChoice([transforms.RandomVerticalFlip(p=1), transforms.RandomHorizontalFlip(p=1)]),

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

2.2 transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

(1) Function: applies the whole group of transforms with probability p (either the entire group is executed, or none of it);
(2) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 2 RandomApply
    transforms.RandomApply([transforms.RandomAffine(degrees=0, shear=45, fillcolor=(255, 0, 0)), 
                            transforms.Grayscale(num_output_channels=3)], p=0.5),

    transforms.ToTensor(), 
    transforms.Normalize(norm_mean, norm_std),
])

2.3 transforms.RandomOrder([transforms1, transforms2, transforms3])

transforms.RandomOrder([transforms1, transforms2, transforms3])

(1) Function: shuffles the order of a group of transforms and applies all of them in the shuffled order;
(2) Code example

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 3 RandomOrder
    transforms.RandomOrder([transforms.RandomRotation(15),
                            transforms.Pad(padding=32),
                            transforms.RandomAffine(degrees=0, translate=(0.01, 0.1), scale=(0.9, 1.1))]),

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

2.4 Complete Code Example

transforms_methods_2.py

# -*- coding: utf-8 -*-
"""
# @file name  : transforms_methods_2.py
# @brief      : transforms methods, part 2
"""
import os
import numpy as np
import torch
import random
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert



def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed(1)  # set the random seed

# hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}


# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 1 Pad
    # transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
    # transforms.Pad(padding=(8, 64), fill=(255, 0, 0), padding_mode='constant'),
    # transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='constant'),
    # transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='symmetric'),

    # 2 ColorJitter
    # transforms.ColorJitter(brightness=0.5),
    # transforms.ColorJitter(contrast=0.5),
    # transforms.ColorJitter(saturation=0.5),
    # transforms.ColorJitter(hue=0.3),

    # 3 Grayscale
    # transforms.Grayscale(num_output_channels=3),

    # 4 Affine
    # transforms.RandomAffine(degrees=30),
    # transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), fillcolor=(255, 0, 0)), # degrees must be given; use degrees=0 if no rotation is wanted
    # transforms.RandomAffine(degrees=0, scale=(0.7, 0.7)), # the uncovered area is filled with black
    # transforms.RandomAffine(degrees=0, shear=(0, 0, 0, 45)), # shear along the y axis
    # transforms.RandomAffine(degrees=0, shear=90, fillcolor=(255, 0, 0)),

    # 5 Erasing
    # RandomErasing operates on tensors, so uncomment the ToTensor() below and comment out the ToTensor() at the end of this Compose
    # transforms.ToTensor(),
    # transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=(254/255, 0, 0)), # value=(254/255, 0, 0): the input is already a tensor, so the value must be scaled to the 0-1 range (divide by 255)
    # transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='1234'), # value='1234': when value is a string, the erased region is filled with random pixel values

    # 1 RandomChoice
    # transforms.RandomChoice([transforms.RandomVerticalFlip(p=1), transforms.RandomHorizontalFlip(p=1)]),

    # 2 RandomApply
    # transforms.RandomApply([transforms.RandomAffine(degrees=0, shear=45, fillcolor=(255, 0, 0)), 
    #                         transforms.Grayscale(num_output_channels=3)], p=0.5),
    # 3 RandomOrder
    # transforms.RandomOrder([transforms.RandomRotation(15),
    #                         transforms.Pad(padding=32),
    #                         transforms.RandomAffine(degrees=0, translate=(0.01, 0.1), scale=(0.9, 1.1))]),

    transforms.ToTensor(), # comment this line out if RandomErasing is used above
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)


# ============================ step 5/5 training ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):

        inputs, labels = data   # B C H W

        img_tensor = inputs[0, ...]     # C H W
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()

3. Custom transforms Methods

3.1 Requirements for a Custom transform

(1) It must accept exactly one argument and return exactly one argument;
(2) Pay attention to the output of the upstream transform and the input of the downstream one, e.g. whether it is a PIL Image or a Tensor.

3.2 Passing Multiple Parameters via a Class

Extra parameters are passed in through __init__ when the transform object is constructed, while __call__ still takes a single image and returns a single image, as the sketch below shows.
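A minimal sketch of this class pattern (the class name MyTransform and its parameters are placeholders, not from the original post):

class MyTransform(object):
    """Template for a custom transform: parameters go into __init__, __call__ takes one image and returns one."""
    def __init__(self, param_a, param_b):
        # store any extra parameters on the instance
        self.param_a = param_a
        self.param_b = param_b

    def __call__(self, img):
        # img is whatever the previous transform produced (a PIL Image if placed before ToTensor());
        # apply the operation using self.param_a / self.param_b, then return exactly one object
        return img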

3.3 Salt-and-Pepper Noise

Salt-and-pepper noise randomly replaces pixels with white dots (salt) or black dots (pepper). The signal-to-noise ratio (SNR) is the fraction of pixels left unchanged; the remaining 1 - SNR fraction is split evenly between salt and pepper pixels, which is exactly what the AddPepperNoise class below implements.

3.4 Code Example

my_transforms.py

# -*- coding: utf-8 -*-
"""
# @file name  : my_transforms.py
# @brief      : a custom transforms method
"""
import os
import numpy as np
import torch
import random
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed(1)  # set the random seed

# hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}


class AddPepperNoise(object):
    """增加椒盐噪声
    Args:
        snr (float): Signal Noise Rate 信噪比
        p (float): 概率值,依概率执行该操作
    """

    def __init__(self, snr, p=0.9):
        assert isinstance(snr, float) or (isinstance(p, float))
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy() # convert the PIL Image to an ndarray
            h, w, c = img_.shape # height, width, number of channels
            signal_pct = self.snr # fraction of signal (unchanged) pixels
            noise_pct = (1 - self.snr) # fraction of noise pixels
            mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.]) # 0 keeps the original pixel, 1 marks salt noise, 2 marks pepper noise; (0, 1, 2) is only used to build the mask
            mask = np.repeat(mask, c, axis=2)
            img_[mask == 1] = 255   # salt noise: white
            img_[mask == 2] = 0     # pepper noise: black
            return Image.fromarray(img_.astype('uint8')).convert('RGB') # convert the ndarray back to a PIL Image
        else:
            return img


# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddPepperNoise(0.9, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)


# ============================ step 5/5 training ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):

        inputs, labels = data   # B C H W

        img_tensor = inputs[0, ...]     # C H W
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()

4. Data Augmentation in Practice

The script below extends the RMB classification training script from the earlier DataLoader/Dataset post by adding data augmentation (RandomCrop and RandomGrayscale) to the training transform.

RMB_data_augmentation.py

# -*- coding: utf-8 -*-
"""
# @file name  : RMB_data_augmentation.py
# @brief      : data augmentation experiment for the RMB classification model
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.9),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])


valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # set the learning-rate decay schedule

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss_val)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # the validation curve records one loss per epoch, so the x positions are converted to iteration counts
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100

    img_tensor = inputs[0, ...]  # C H W
    img = transform_invert(img_tensor, train_transform)
    plt.imshow(img)
    plt.title("LeNet got {} Yuan".format(rmb))
    plt.show()
    plt.pause(0.5)
    plt.close()

tools/my_dataset.py

# -*- coding: utf-8 -*-
"""
# @file name  : dataset.py
# @brief      : Dataset definitions for the datasets used in this series
"""

import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"1": 0, "100": 1}


class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # apply the transform here (conversion to tensor, etc.)

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over class folders
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over images
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info

tools/common_tools.py

# -*- coding: utf-8 -*-
"""
# @file name  : common_tools.py
# @brief      : common utility functions
"""

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image


def transform_invert(img_, transform_train):
    """
    将data 进行反transfrom操作
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    if 'ToTensor' in str(transform_train):
        img_ = np.array(img_) * 255

    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

    return img_