提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
1.概述
UNet是医学图像分割领域的经典论文,因其网络结构形似字母U而得名。本文的代码在其他博主代码的基础上做了一些细节上的修改,并增加了测试代码。
下面是该博主的链接,包含了预训练模型:
UNet的Pytorch实现_Natuski_的博客-CSDN博客_pytorch unet
一、dataset.py
import os
import torchvision
from PIL import Image
from torch.utils.data import Dataset
import torch
class SEGData(Dataset):
    """Paired image/mask dataset for segmentation training.

    Images and masks are paired purely by position after sorting the two
    directory listings, so both directories must hold the same number of
    files whose sorted orders correspond (e.g. identical base names).
    """

    def __init__(self, path1, path2, size=(512, 512)):
        """
        :param path1: directory containing the input images
        :param path2: directory containing the label masks
        :param size: (height, width) every sample is resized to. Larger sizes
            generally train better but more slowly. Defaults to (512, 512).
        :raises ValueError: if the two directories hold different numbers of files
        """
        self.img_path = path1
        self.label_path = path2
        self.images = sorted(os.listdir(self.img_path))
        self.labels = sorted(os.listdir(self.label_path))
        # Pairing is by sorted position, so a count mismatch means every pair
        # after the first extra file is wrong; fail loudly instead.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"{len(self.images)} images in {path1!r} but "
                f"{len(self.labels)} labels in {path2!r}"
            )
        self.totensor = torchvision.transforms.ToTensor()
        self.resizer = torchvision.transforms.Resize(size)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        """Return the i-th (image, label) pair as float tensors of the configured size.

        Source images differ in size, so each image and its mask are first
        pasted into the top-left corner of a black square whose side is the
        image's longest edge (keeps the aspect ratio), then resized to
        ``size``. Without a common size the DataLoader could not batch them.

        :raises ValueError: if an image and its mask differ in size
        """
        img_file = os.path.join(self.img_path, self.images[i])
        label_file = os.path.join(self.label_path, self.labels[i])
        # Context managers close the underlying file handles; paste() reads
        # the pixel data, so nothing is needed from the files afterwards.
        with Image.open(img_file) as img, Image.open(label_file) as label:
            if img.size != label.size:
                raise ValueError(
                    f"size mismatch: {img_file} is {img.size}, {label_file} is {label.size}"
                )
            w, h = img.size
            # Square canvas sized by the longest edge, all zeros (black).
            side = max(h, w)
            to_pil = torchvision.transforms.ToPILImage()
            black_img = to_pil(torch.zeros(3, side, side))
            black_label = to_pil(torch.zeros(3, side, side))
            # Pasting at the top-left corner is as good as centring it.
            black_img.paste(img, (0, 0, int(w), int(h)))
            black_label.paste(label, (0, 0, int(w), int(h)))
        img = self.totensor(self.resizer(black_img))
        label = self.totensor(self.resizer(black_label))
        return img, label
二、Model.py
from __future__ import print_function, division
import torch
import torch.nn as nn
class UNet(nn.Module):
    """UNet with four stride-2 encoder stages and four transposed-conv decoder stages.

    The output passes through a Sigmoid, so it is meant to be trained with
    ``nn.BCELoss``. Input height and width must both be divisible by 16
    (four halvings followed by four doublings).

    :param in_channels: channels of the input image (default 3)
    :param num_classes: channels of the output map (default 3)
    """

    def __init__(self, in_channels=3, num_classes=3):
        super(UNet, self).__init__()
        widths = [2 ** (i + 6) for i in range(5)]  # [64, 128, 256, 512, 1024]
        # Encoder (attribute names are state_dict keys; keep them stable).
        self.d1 = DownsampleLayer(in_channels, widths[0])  # in-64
        self.d2 = DownsampleLayer(widths[0], widths[1])  # 64-128
        self.d3 = DownsampleLayer(widths[1], widths[2])  # 128-256
        self.d4 = DownsampleLayer(widths[2], widths[3])  # 256-512
        # Decoder: each stage outputs its upsampled features concatenated
        # with the matching encoder skip, hence the doubled input widths.
        self.u1 = UpSampleLayer(widths[3], widths[3])  # 512-1024-512
        self.u2 = UpSampleLayer(widths[4], widths[2])  # 1024-512-256
        self.u3 = UpSampleLayer(widths[3], widths[1])  # 512-256-128
        self.u4 = UpSampleLayer(widths[2], widths[0])  # 256-128-64
        # Output head: 128 (64 upsampled + 64 skip) -> num_classes probabilities.
        self.o = nn.Sequential(
            nn.Conv2d(widths[1], widths[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], widths[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], num_classes, 3, 1, 1),
            nn.Sigmoid(),  # pairs with BCELoss
        )

    def forward(self, x):
        """Return per-pixel probabilities with the same spatial size as ``x``.

        :raises ValueError: if the height or width of ``x`` is not divisible by 16
        """
        height, width = x.shape[-2], x.shape[-1]
        # Otherwise an odd size at some level makes torch.cat fail with an
        # opaque shape error deep inside the decoder.
        if height % 16 or width % 16:
            raise ValueError(
                f"UNet input height/width must be divisible by 16, got {height}x{width}"
            )
        out_1, out1 = self.d1(x)
        out_2, out2 = self.d2(out1)
        out_3, out3 = self.d3(out2)
        out_4, out4 = self.d4(out3)
        out5 = self.u1(out4, out_4)
        out6 = self.u2(out5, out_3)
        out7 = self.u3(out6, out_2)
        out8 = self.u4(out7, out_1)
        return self.o(out8)
# 下采样
class DownsampleLayer(nn.Module):
    """One UNet encoder stage.

    Two size-preserving 3x3 conv-BN-ReLU units produce the skip features.
    A stride-2 3x3 conv-BN-ReLU then halves the spatial size (rounding up)
    for the next stage.
    """

    @staticmethod
    def _conv_bn_relu(cin, cout, stride):
        """Return [Conv2d(3x3, padding 1), BatchNorm2d, ReLU] as a list of modules."""
        return [
            nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(),
        ]

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names and Sequential indices form the state_dict keys of
        # saved checkpoints, so they (and construction order) stay fixed.
        self.Conv_BN_ReLU_2 = nn.Sequential(
            *self._conv_bn_relu(in_ch, out_ch, 1),
            *self._conv_bn_relu(out_ch, out_ch, 1),
        )
        self.downsample = nn.Sequential(*self._conv_bn_relu(out_ch, out_ch, 2))

    def forward(self, x):
        """Return ``(skip, down)``.

        ``skip`` is the full-resolution feature map handed to the decoder;
        ``down`` is the halved map fed to the next encoder stage.
        """
        skip = self.Conv_BN_ReLU_2(x)
        return skip, self.downsample(skip)
# 上采样
class UpSampleLayer(nn.Module):
    """One UNet decoder stage.

    ``in_ch`` -> ``2*out_ch`` through two 3x3 conv-BN-ReLU units. A 3x3
    stride-2 transposed conv-BN-ReLU then doubles the spatial size while
    reducing to ``out_ch`` channels. The result is concatenated
    channel-wise with the encoder skip tensor.

    Channel flow as used by UNet:
    512-1024-512, 1024-512-256, 512-256-128, 256-128-64.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        mid = out_ch * 2
        # Attribute names, layer order and Sequential indices are checkpoint
        # (state_dict) keys; they must not change.
        convs = [
            nn.Conv2d(in_channels=in_ch, out_channels=mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(in_channels=mid, out_channels=mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
        ]
        self.Conv_BN_ReLU_2 = nn.Sequential(*convs)
        # output_padding=1 makes the transposed conv exactly double H and W.
        up = [
            nn.ConvTranspose2d(in_channels=mid, out_channels=out_ch, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.upsample = nn.Sequential(*up)

    def forward(self, x, out):
        """Upsample ``x`` and concatenate it with the skip tensor ``out``.

        :param x: output of the previous (deeper) stage
        :param out: encoder skip tensor; it must already have twice the spatial size of ``x``
        :return: ``cat([upsampled x, out], dim=1)``
        """
        upsampled = self.upsample(self.Conv_BN_ReLU_2(x))
        return torch.cat((upsampled, out), dim=1)
三、train.py
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
import os
from torchvision.utils import save_image
from min_unet.Model import UNet
from min_unet.dataset import SEGData
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def main(path1, path2, EPOCH, Batch, ckpt_path=r'SAVE/Unet.pt', log_dir=r'Log',
         img_dir=r"Log_imgs"):
    """Train UNet on the image/mask pairs in ``path1`` / ``path2``.

    Resumes from ``ckpt_path`` when that file exists. The per-batch BCE loss
    is logged to TensorBoard and the weights are saved after every epoch.
    Every 10 epochs, the prediction for one fixed training sample is
    written to ``img_dir``.

    :param path1: directory of training images
    :param path2: directory of training masks
    :param EPOCH: number of epochs to train
    :param Batch: batch size
    :param ckpt_path: checkpoint file to resume from and save to
    :param log_dir: TensorBoard log directory
    :param img_dir: directory for the periodic sample predictions
    """
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet().to(device)
    optimizer = torch.optim.Adam(net.parameters())
    loss_func = nn.BCELoss()
    data = SEGData(path1, path2)
    dataloader = DataLoader(data, batch_size=Batch, shuffle=True, num_workers=0, drop_last=True)
    summary = SummaryWriter(log_dir)

    # Resume only when a checkpoint exists; loading unconditionally made the
    # very first training run crash.
    if os.path.exists(ckpt_path):
        print('load net')
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        print('load success')
    else:
        print('no checkpoint found, training from scratch')
    # torch.save / save_image do not create missing directories.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)

    sample_idx = min(2, len(data) - 1)
    # Global step keeps growing across epochs so the loss curves don't overlap.
    global_step = 0
    try:
        for epoch in range(EPOCH):
            print('开始第{}轮'.format(epoch))
            # Must be re-entered each epoch: the sample prediction below puts
            # the net in eval mode (BatchNorm then uses running statistics).
            net.train()
            loss = None
            for img, label in dataloader:
                img = img.to(device)
                label = label.to(device)
                img_out = net(img)
                loss = loss_func(img_out, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                summary.add_scalar('bceloss', loss.item(), global_step)
                global_step += 1
            # Once per epoch instead of once per batch.
            torch.save(net.state_dict(), ckpt_path)
            if epoch % 10 == 0 and len(data) > 0:
                net.eval()
                sample, _ = data[sample_idx]
                with torch.no_grad():  # inference only; don't build a graph
                    out = net(torch.unsqueeze(sample, dim=0).to(device))
                save_image(out, os.path.join(img_dir, 'segimg_{}——.png'.format(epoch)),
                           nrow=1, scale_each=True)
            # loss stays None when drop_last leaves no full batch.
            if loss is not None:
                print(f"第{epoch}轮train_loss={loss.item()}")
            print('第{}轮结束'.format(epoch))
    finally:
        summary.close()
if __name__ == '__main__':
    # Training images and their masks (paired by sorted filename order).
    train_img_dir = r'../data/imgs/'
    train_mask_dir = r'../data/masks/'
    main(train_img_dir, train_mask_dir, EPOCH=11, Batch=2)
四、test.py
import torch
import torchvision
import os
from torchvision.utils import save_image
from min_unet.Model import UNet
from PIL import Image
def test(input_path, weight=r'SAVE/Unet.pt', output_dir=r"Test_imgs", size=512):
    """Segment every image in ``input_path`` and save the predicted masks.

    Each image is preprocessed as in training: it is pasted top-left into a
    black square whose side is its longest edge, then resized to ``size`` x
    ``size``. The UNet also needs sides divisible by 16. The prediction is
    resized back to the square and cropped to the original width and height,
    so it aligns with the input image.

    :param input_path: directory of images to segment
    :param weight: path to the trained state_dict
    :param output_dir: directory the predictions are written to (created if missing)
    :param size: inference resolution; should match the training resolution (512)
    :raises FileNotFoundError: if ``weight`` does not exist
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet().to(device)
    # Running with random weights produces meaningless masks, so fail loudly.
    if not os.path.exists(weight):
        raise FileNotFoundError(f"weight file not found: {weight}")
    net.load_state_dict(torch.load(weight, map_location=device))
    print("successful")
    net.eval()
    os.makedirs(output_dir, exist_ok=True)

    to_pil = torchvision.transforms.ToPILImage()
    to_tensor = torchvision.transforms.ToTensor()
    resizer = torchvision.transforms.Resize((size, size))
    for file in os.listdir(input_path):
        path = os.path.join(input_path, file)
        if not os.path.isfile(path):
            continue
        # splitext keeps dotted names like "a.b.png" from colliding on "a".
        stem = os.path.splitext(file)[0]
        with Image.open(path) as img:
            w, h = img.size
            side = max(h, w)
            black_img = to_pil(torch.zeros(3, side, side))
            # Top-left placement, same as training.
            black_img.paste(img, (0, 0, int(w), int(h)))
        x = torch.unsqueeze(to_tensor(resizer(black_img)), dim=0).to(device)
        with torch.no_grad():
            out = net(x)
        # Undo the resize, then crop away the black padding.
        out = torch.nn.functional.interpolate(out, size=(side, side), mode="bilinear",
                                              align_corners=False)
        out = out[:, :, :h, :w]
        save_image(out, os.path.join(output_dir, f'segimg_{stem}.png'), nrow=1, scale_each=True)
if __name__ == "__main__":
    # Directory holding the images to segment.
    test(r"../data/test_imgs")