背景
在研究使用深度学习的方法进诸如图像去噪、图像去雾、图像去马赛克等需求的时候,经常是受污染图片和干净图片互相对应,此时就需要制作这种成对应关系的数据集。本文使用的读取图片的方法是 PIL 库里的 Image.open()
,初步学习,不对的地方还请指正。
简单步骤
- 获取两个文件夹下的所有图片(建议图片名称对应完全一样,且没有汉字);
- 保持对应关系分别获取每个图片的具体路径;
- 读取图片数据;
- 进行必要的变换操作(根据读取图片的方法不同有异);
- 返回受污染图片和相应的标签图片。
- 开始炼丹。
代码实现
- 以下代码在
DataTrain.py
里,功能是对图片数据进行读入和一定操作,方便后续使用torch.utils.data.DataLoader()
方法获取数据集。
import torchvision
from torch.utils.data import Dataset
import os
from PIL import Image
class MyDataset(Dataset): # 继承 Dataset 类
def __init__(self, input_path, label_path):
self.input_path = input_path # 受污染图片所在文件夹
self.input_path_image = os.listdir(input_path) # 文件夹下的所有图片对象
self.label_path = label_path # 干净图片所在文件夹
self.label_path_image = os.listdir(label_path)
# 定义要对图片进行的变换
self.transforms = torchvision.transforms.Compose([
# 中心裁剪64*64大小作为pacth
torchvision.transforms.CenterCrop([64, 64]),
# 将读入的数据归一化[0, 1]之间并变为张量类型
torchvision.transforms.ToTensor(),
])
def __len__(self):
return len(self.input_path_image) # 返回长度
def __getitem__(self, index):
# index 索引对应的受污染图片完整路径
input_image_path = os.path.join(self.input_path, self.input_path_image[index])
# 利用PIL.Image 读入图片数据并转换通道结构
input_image = Image.open(input_image_path).convert('RGB')
label_image_path = os.path.join(self.label_path, self.label_path_image[index])
label_image = Image.open(label_image_path).convert('RGB')
# 对读入的图片进行固定的变换
input = self.transforms(input_image)
label = self.transforms(label_image)
return (input, label) # 返回适合在网络中训练的图片数据
- 准备好图片数据之后就可以准备送入网络进行训练了,以下代码在
train.py
里用来获取数据集用作训练。
from DataTrain import MyDataset # 从前述 py 文件里导入 MyDataset 类
from torch.utils.data import DataLoader
from torch.autograd import Variable
BATCH_SIZE = 100 # 参与每次训练的数量,就是将数据集按照 BATCH_SIZE 大小进行拆分
dataset = MyDataset(input_path, label_path) # 实现前述定义的 MyDataset 类
data_train = DataLoader(dataset, batch_size=BATCH_SIZE,shuffle=True) # 获取训练数据集,拆分,打乱
# 这个循环一般在 epoch 循环下表示一次训练,x,y 对应前述返回的 input,label
for j, (x, y) in enumerate(data_train):
input = Variable(x).cuda() # 转换数据为 GPU 变量
label = Variable(y).cuda()
# 接下来 input,label 就可以送入网络进行训练了
其他
PIL 库读取图片
- 图片格式:
PIL 库读取图片后本就是 RGB 形式,但有时训练时会报错,具体原因不清楚,故获取图片后一般再转成.convert('RGB')
一下。 - 数据格式
PIL库读取图片数据是 uint8 类型的,不适合直接用作网络训练,要进行归一化转为 float 类型。
无法从其他 py 文件里导入
在工程文件下新建一个 python package,将需要导入的 .py 文件放进去,之后,右击该包,如下操作即可。
为什么CenterCrop()
,而不是Resize()
?
在我阅读的大部分源码中,很多人都进行的是Resize()
操作,我认为在图像去雨,去噪声这一类问题里并不适合,因为缩放图片会丢失原来的信息,在测试集会表现得很差,而在图像分类这一类问题中可以使用。
batch_size 大小怎么确定?
一方面根据训练需要来定,一方面根据硬件配置来定。
- 一般 batch_size 大,网络会相对收敛快一些;
- 显卡配置不够就设置小一点,比如我只有 30;
- 另外看过一篇论文专门探究 batch_size 的适合大小,建议好像是 0~32 之间,但我个人看一些相关源码基本都是 100 左右。