图像去雨:pytorch 输入和标签都是图片的数据集

背景

在研究使用深度学习的方法进诸如图像去噪、图像去雾、图像去马赛克等需求的时候,经常是受污染图片和干净图片互相对应,此时就需要制作这种成对应关系的数据集。本文使用的读取图片的方法是 PIL 库里的 Image.open(),初步学习,不对的地方还请指正。

简单步骤

  1. 获取两个文件夹下的所有图片(建议图片名称对应完全一样,且没有汉字);
  2. 保持对应关系分别获取每个图片的具体路径;
  3. 读取图片数据;
  4. 进行必要的变换操作(根据读取图片的方法不同有异)
  5. 返回受污染图片和相应的标签图片。
  6. 开始炼丹。

代码实现

  1. 以下代码在DataTrain.py里,功能是对图片数据进行读入和一定操作,方便后续使用torch.utils.data.DataLoader()方法获取数据集。
import torchvision
from torch.utils.data import  Dataset
import os
from PIL import Image
class MyDataset(Dataset): # 继承 Dataset 类
    def __init__(self, input_path, label_path):
        self.input_path = input_path # 受污染图片所在文件夹
        self.input_path_image = os.listdir(input_path) # 文件夹下的所有图片对象

        self.label_path = label_path # 干净图片所在文件夹
        self.label_path_image = os.listdir(label_path)
		
		# 定义要对图片进行的变换
        self.transforms = torchvision.transforms.Compose([
       		 # 中心裁剪64*64大小作为pacth
            torchvision.transforms.CenterCrop([64, 64]), 
            
            # 将读入的数据归一化[0, 1]之间并变为张量类型
            torchvision.transforms.ToTensor(), 
            ])

    def __len__(self):
        return len(self.input_path_image) # 返回长度
 
    def __getitem__(self, index):
    	# index 索引对应的受污染图片完整路径
        input_image_path = os.path.join(self.input_path, self.input_path_image[index])
        # 利用PIL.Image 读入图片数据并转换通道结构
        input_image = Image.open(input_image_path).convert('RGB')

        label_image_path = os.path.join(self.label_path, self.label_path_image[index])
        label_image = Image.open(label_image_path).convert('RGB')

		# 对读入的图片进行固定的变换
        input = self.transforms(input_image)
        label = self.transforms(label_image)

        return  (input, label) # 返回适合在网络中训练的图片数据
  1. 准备好图片数据之后就可以准备送入网络进行训练了,以下代码在 train.py 里用来获取数据集用作训练。
from DataTrain import MyDataset # 从前述 py 文件里导入 MyDataset 类
from torch.utils.data import DataLoader
from torch.autograd import Variable

BATCH_SIZE = 100 # 参与每次训练的数量,就是将数据集按照 BATCH_SIZE 大小进行拆分

dataset = MyDataset(input_path, label_path) # 实现前述定义的 MyDataset 类
data_train = DataLoader(dataset, batch_size=BATCH_SIZE,shuffle=True) # 获取训练数据集,拆分,打乱

# 这个循环一般在 epoch 循环下表示一次训练,x,y 对应前述返回的 input,label
for j, (x, y) in enumerate(data_train): 
	input = Variable(x).cuda() # 转换数据为 GPU 变量
	label = Variable(y).cuda()

# 接下来 input,label 就可以送入网络进行训练了

其他

PIL 库读取图片

  1. 图片格式:
    PIL 库读取图片后本就是 RGB 形式,但有时训练时会报错,具体原因不清楚,故获取图片后一般再转成.convert('RGB')一下。
  2. 数据格式
    PIL库读取图片数据是 uint8 类型的,不适合直接用作网络训练,要进行归一化转为 float 类型。

无法从其他 py 文件里导入

在工程文件下新建一个 python package,将需要导入的 .py 文件放进去,之后,右击该包,如下操作即可。
标记为 sources root 即可

为什么CenterCrop(),而不是Resize()?

在我阅读的大部分源码中,很多人都进行的是Resize()操作,我认为在图像去雨,去噪声这一类问题里并不适合,因为缩放图片会丢失原来的信息,在测试集会表现得很差,而在图像分类这一类问题中可以使用。

batch_size 大小怎么确定?

一方面根据训练需要来定,一方面根据硬件配置来定。

  1. 一般 batch_size 大,网络会相对收敛快一些;
  2. 显卡配置不够就设置小一点,比如我只有 30;
  3. 另外看过一篇论文专门探究 batch_size 的适合大小,建议好像是 0~32 之间,但我个人看一些相关源码基本都是 100 左右。
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

听 风、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值