Pytorch入门——用UNet网络做图像分割

这篇博客介绍如何使用PyTorch从头实现UNet网络进行图像分割,通过阅读相关paper并参考其他博主的代码,作者逐步解析了网络结构和运行流程,同时提供了数据集链接和关键代码段。
摘要由CSDN通过智能技术生成

最近看的paper里的pytorch代码太复杂,我之前也没接触过pytorch,遂决定先自己实现一个基础的裸代码,这样走一遍,对跑网络的基本流程和一些常用的基础函数的印象会更深刻。

本文的代码和数据主要来自pytorch笔记:05)UNet网络简单实现_Javis486的专栏-CSDN博客

附上该博主的github地址:https://github.com/JavisPeng/u_net_liver

并在自己的理解的基础上做了一些改动,以及加了大量注释。

如有错误,欢迎指出。

 unet.py(实现unet网络)

import torch.nn as nn
import torch

class DoubleConv(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DoubleConv,self).__init__()
        self.conv = nn.Sequential(
                nn.Conv2d(in_ch,out_ch,3,padding=1),#in_ch、out_ch是通道数
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace = True),
                nn.Conv2d(out_ch,out_ch,3,padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace = True)  
            )
    def forward(self,x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(UNet,self).__init__()
        self.conv1 = DoubleConv(in_ch,64)
        self.pool1 = nn.MaxPool2d(2)#每次把图像尺寸缩小一半
        self.conv2 = DoubleConv(64,128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128,256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256,512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512,1024)
        
以下是使用PyTorch实现Unet图像分割的基本步骤: 1. 准备数据集:包括训练集和验证集,每个样本包含输入图像和相应的标签图像。 2. 定义Unet模型:Unet模型是一种编码器-解码器结构,其中编码器部分由卷积和池化操作组成,解码器部分由卷积和反卷积操作组成。在PyTorch中,可以使用nn.Module类定义模型。 3. 定义损失函数:常用的图像分割损失函数包括交叉熵损失函数和Dice损失函数。在PyTorch中,可以使用torch.nn.functional中的函数定义损失函数。 4. 定义优化器:常用的优化器包括随机梯度下降(SGD)和Adam优化器。在PyTorch中,可以使用torch.optim中的类定义优化器。 5. 训练模型:使用上述定义的损失函数和优化器,以及训练集数据训练模型。 6. 验证模型:使用验证集数据验证模型的性能。 7. 预测结果:使用训练好的模型对新的图像进行分割。 这里是一个简单的示例代码,可以帮助你开始实现Unet图像分割: ```python import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim class Unet(nn.Module): def __init__(self): super(Unet, self).__init__() self.conv1 = nn.Conv2d(3, 64, 3, padding=1) self.conv2 = nn.Conv2d(64, 128, 3, padding=1) self.conv3 = nn.Conv2d(128, 256, 3, padding=1) self.conv4 = nn.Conv2d(256, 512, 3, padding=1) self.conv5 = nn.Conv2d(512, 1024, 3, padding=1) self.upconv6 = nn.ConvTranspose2d(1024, 512, 2, stride=2) self.conv7 = nn.Conv2d(1024, 512, 3, padding=1) self.upconv8 = nn.ConvTranspose2d(512, 256, 2, stride=2) self.conv9 = nn.Conv2d(512, 256, 3, padding=1) self.upconv10 = nn.ConvTranspose2d(256, 128, 2, stride=2) self.conv11 = nn.Conv2d(256, 128, 3, padding=1) self.upconv12 = nn.ConvTranspose2d(128, 64, 2, stride=2) self.conv13 = nn.Conv2d(128, 64, 3, padding=1) self.conv14 = nn.Conv2d(64, 1, 1) def forward(self, x): x1 = F.relu(self.conv1(x)) x2 = F.relu(self.conv2(F.max_pool2d(x1, 2))) x3 = F.relu(self.conv3(F.max_pool2d(x2, 2))) x4 = F.relu(self.conv4(F.max_pool2d(x3, 2))) x5 = F.relu(self.conv5(F.max_pool2d(x4, 2))) x6 = F.relu(self.conv7(torch.cat([self.upconv6(x5), x4], 1))) x7 = F.relu(self.conv9(torch.cat([self.upconv8(x6), x3], 1))) x8 = F.relu(self.conv11(torch.cat([self.upconv10(x7), x2], 1))) x9 = F.relu(self.conv13(torch.cat([self.upconv12(x8), x1], 1))) x10 = self.conv14(x9) return x10 model = Unet() criterion = nn.BCEWithLogitsLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) for epoch in range(10): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(trainloader))) correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data outputs = model(images) predicted = torch.round(torch.sigmoid(outputs)) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the test images: %d %%' % (100 * correct / total)) ``` 需要注意的是,这只是一个简单的示例代码,需要根据具体任务进行修改和优化。
评论 93
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值