使用UNet进行图像分割(利用Pytorch搭建)
简述
这里介绍一下如何使用Pytorch搭建一个UNet的图像分割模型,并训练出效果,论文中的一些trick这里没有使用。
只包含简单的几个模块,并且大部分代码都有注释。
环境
平台:Windows
python版本:3.7
Pytorch版本:torch:1.3.0,torchvision:0.4.0
准备
在搭建模型之前,我们还需要做些准备工作,那就是搜集数据,这里我提供一份眼球毛细血管数据集和一份VOC2012数据集来进行训练,如果有特殊的需求还是要自己搜集数据。
VOC数据集是 PASCAL VOC 挑战赛这个比赛使用的数据,里面包含了 目标分类、目标检测、目标分割、姿态识别、行为分类 所需要的数据与标签,我这里只使用分割的部分就可以了。
VOC数据集下载地址:
官方: https://pjreddie.com/projects/pascal-voc-dataset-mirror/
或者:https://pan.baidu.com/s/1yfUILB185VvlgQ8bXk536w 提取码:geir
图像样式
原始图片:
![](https://i-blog.csdnimg.cn/blog_migrate/d99b9531ba6fc7f668a8d54a43affacd.jpeg)
![](https://i-blog.csdnimg.cn/blog_migrate/416c09ab49f92b07a28a826d3434e555.jpeg)
标签:
![](https://i-blog.csdnimg.cn/blog_migrate/f57193bb961ab9adada517f749b06409.png)
![](https://i-blog.csdnimg.cn/blog_migrate/2e0ad1b2b0d58eb759874353ed89cb7e.png)
毛细血管数据集下载地址:
https://pan.baidu.com/s/1C06ERcImDpXlTneTrVuXPg 提取码:vmq0
图片样式
原始图片:
![](https://i-blog.csdnimg.cn/blog_migrate/7f4b28a7312e34ce6ab74ad7dbe16d78.png)
标签:
![](https://i-blog.csdnimg.cn/blog_migrate/572aa1e6764f5cb09295a96e02a3cfcd.jpeg)
毛细血管数据集与VOC的用法差不多,就只是读取方式有点区别。
代码
总共三个py文件,文件名分别为:dataset(数据集)、unet(网络模型)、train(训练模块)
数据集
下载好数据集之后,解压完毕,可以看到这些文件:
![](https://i-blog.csdnimg.cn/blog_migrate/3732578c30bf7782b74ce2c409574e9a.png)
这里我们只使用 JPEGImages 和 SegmentationClass 下的图片来进行语义分割,总共2913张图片。
首先是数据集部分的代码,没有使用数据增强,VOC部分:
import os
import cv2
import torchvision
from torch.utils.data import Dataset
from torchvision.utils import save_image
# 简单的数据集,没有进行数据增强
class Datasets(Dataset):
def __init__(self, path):
self.path = path
# 语义分割需要的图片的加载进来,做标签,总共2913张图片
self.name = os.listdir(os.path.join(path, "SegmentationClass"))
self.trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
def __len__(self):
return len(self.name)
# 简单的正方形转换,把图片和标签转为正方形
# 图片会置于中央,两边会填充为黑色,不会失真
def __trans__(self, img, size):
# 图片的宽高
h, w = img.shape[0:2]
# 需要的尺寸
_w = _h = size
# 不改变图像的宽高比例
scale = min(_h / h, _w / w)
h = int(h * scale)
w = int(w * scale)
# 缩放图像
img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
# 上下左右分别要扩展的像素数
top = (_h - h) // 2
left = (_w - w) // 2
bottom = _h - h - top
right = _w - w - left
# 生成一个新的填充过的图像,这里用纯黑色进行填充(0,0,0)
new_img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
return new_img
def __getitem__(self, index):
# 拿到的图片
name = self.name[index]
# 把标签名的格式改成jpg,与原始图片一致
name2jpg = name[:-3] + "jpg"
# 所有的原始图片和标签
img_path = [os.path.join(self.path, i) for i in ("JPEGImages", "SegmentationClass")]
# 读取原始图片和标签,并转RGB
img_o = cv2.imread(os.path.join(img_path[0], name2jpg))
img_l = cv2.imread(os.path.join(img_path[1], name))
img_o = cv2.cvtColor(img_o, cv2.COLOR_BGR2RGB)
img_l = cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB)
# 转成网络需要的正方形
img_o = self.__trans__(img_o,