本专栏内容是学习 深度学习麋了鹿 的《图像分割UNet硬核讲解》(带你手撸unet代码)部分笔记。
内容包括从数据集→网络结构→训练→测试。(附代码)
本节是 UNet 数据集制作及代码实现 笔记。
首先我们先看一下 UNet 的网络结构:
Unet网络非常的简单,前半部分就是特征提取,后半部分是上采样。在一些文献中把这种结构叫做编码器-解码器结构,由于网络的整体结构是一个大些的英文字母U,所以叫做 U-net 。(第二部分再细讲)
接下来,我们尝试制作一个数据集:
1.VCO2007
(注意:麋了鹿博主在这方面并没有细说,所以我只能按照我自己的理解)
在开始之前,我们先准备好一个数据集:VCO2007 网上有下载,里面包含三部分:devkitVOCdevkit_08-Jun-2007.jar
、测试集 VOCtest_06-Nov-2007.jar
、训练集/验证集VOCtrainval_06-Nov-2007.jar
,将它们分别解压得到一个总的文件夹 VOCdevkit
,包含如下图所示的文件:
测试集 VOCtest_06-Nov-2007.jar
和训练集/验证集 VOCtrainval_06-Nov-2007.jar
中的文件均为一个 VOC2007 文件夹,将里面的数据进行整合,打开 VOC2007 文件:
红色方框框住的三个文件里面的数据是测试集 VOCtest_06-Nov-2007.jar
和训练集/验证集VOCtrainval_06-Nov-2007.jar的整合。
2.要点思想
准备好 VCO2007 后,我们需要明确知道:
① 我们网络输入是需要固定大小的,但是各个原图片的大小不同,所以我们要进行等比缩放。(为何需要等比缩放?因为直接缩放会变形)
等比缩放的思想:取每张图的最长边,做一个 mask 正矩形,然后将这张图贴在那张正矩形上面,然后就变成了一个以最长边为边长的正方形,接下里进行等比的 resize ,这样就不会变形了。
对应代码:utils01.py
② 我们的图片要转化成png才能用作U-net的训练集
3.代码
① utils01.py
from PIL import Image
def keep_image_size_open(path, size=(512, 512)): # 512可以更改
img = Image.open(path) # 读取图片
temp = max(img.size) # 获取图片的最长边
mask = Image.new('RGB', (temp, temp), (0, 0, 0)) # 做 mask 掩码,正方形,颜色全黑
mask.paste(img, (0, 0)) # 将原图粘贴上来,放在左上角
mask = mask.resize(size) # 对它进行resize,缩放到想要的,传入的那个size
return mask
② Data01.py
# 导入我们所需要的库
import os # 这是导入os模块到当前程序,我们想要获取所有图片的地址的时候就要用到它
from torch.utils.data import Dataset # 从torch的常用工具数据区引入 Dataset
from torchvision import transforms
from 图像分割UNet硬核讲解.utils01 import keep_image_size_open
transform = transforms.Compose([transforms.ToTensor()])
# 创建了一个类 MyData 然后让它继承了 Dataset
class MyDataset(Dataset):
# 进行初始化
def __init__(self, path):
"""这是在进行初始化. 比如说我们根据这个类去创建一个特定实例的时候,它就要运行一个函数,函数里面会为整个 class 提供全局变量,
以及为后续的函数提供一些量,例如 __getitem__方法和长度类 __len__ """
self.path = path # 这里的操作就是让某函数中一个变量变为这个类里面的全局变量
# os.listdir(): 列出路径下所有的文件
# os.path.join(): 连接文件的作用
self.name = os.listdir(os.path.join(self.path, 'SegmentationClass')) # 将这些所有地址弄成一个列表
def __len__(self):
return len(self.name) # 这是这个列表的长度
# 数据集的制作
def __getitem__(self, index):
"""def __getitem__(self, item): 中 item 是默认,但我们通常会改成idx,也即是index"""
segment_name = self.name[index] # 这是为了读取某一个图片的名称,注意此时图片格式是 .png
segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
image_path = os.path.join(self.path, 'JPEGImages', segment_name.replace('png', 'jpg')) # 将 .png 转换成 .jpg格式
# 将原图和标签都送进keep_image_size_open函数里面进行resize
segment_image = keep_image_size_open(segment_path)
image = keep_image_size_open(image_path)
# 归一化
return transform(image), transform(segment_image)
if __name__ == '__main__':
data = MyDataset('D:\\STUDY1\\MILELU\\图像分割UNet硬核讲解\\VOC\\VOCdevkit\\VOC2007')
print(data[0][0].shape)
print(data[0][1].shape)
验证结果:
下节内容UNet 网络结构及代码实现。