【22】Unet网络的复现和理解

【1】网络结构

UNet网络模型图

 

Unet包括两部分:

1  特征提取部分,每经过一个池化层就降低一个尺度,包括原图尺度在内一共有5个尺度。

2  上采样部分,每上采样一次,就和特征提取部分中尺度相同的特征图融合,但是融合之前要将其crop。这里的融合方式是通道维度上的拼接。

该网络由收缩路径(contracting path)和扩张路径(expanding path)组成。其中,收缩路径用于获取上下文信息(context),扩张路径用于精确定位(localization)。

【1.1】网络优点

(1) overlap-tile策略

(2)数据增强(data augmentation)

(3)加权loss

【1.2】网络缺点

U-Net++作者分析U-Net不足并如何做改进:https://zhuanlan.zhihu.com/p/44958351

参考文献:https://zhuanlan.zhihu.com/p/118540575

【2】网络训练

代码以及权重下载地址:https://github.com/JavisPeng/u_net_liver

data and trained weight link: https://pan.baidu.com/s/1dgGnsfoSmL1lbOUwyItp6w code: 17yr

all dataset you can access from: https://competitions.codalab.org/competitions/15595

【2.1】代码展示

文件夹介绍

(1)data文件夹中放的是训练和测试的图片

(2)dowoload是下载的权重文件

Unetmodel.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   Unetmodel.py
@Time    :   2021/03/23 20:09:25
@Author  :   Jian Song 
@Contact :   1248975661@qq.com
@Desc    :   None
'''
# here put the import lib

import torch.nn as nn
import torch
from torch import autograd

'''
文件介绍:定义了unet网络模型,
******pytorch定义网络只需要定义模型的具体参数,不需要将数据作为输入定义到网络中。
仅需要在使用时实例化这个网络,然后将数据输入。
******tensorflow定义网络时则需要将输入张量输入到模型中,即用占位符完成输入数据的输入。
'''
# (Conv => BatchNorm => ReLU) twice: the basic building unit of U-Net.
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.

    NOTE: padding=1 keeps the spatial size unchanged (unlike the original
    U-Net paper, whose unpadded convolutions shrink the feature map by 4);
    only the channel count changes, from in_ch to out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),  # BN added on top of the paper's design
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        return self.conv(input)

class Unet(nn.Module):
    """U-Net encoder/decoder for image segmentation.

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of output channels (e.g. 1 for a binary mask).

    forward() returns raw logits (no sigmoid).  The previous version
    applied nn.Sigmoid() inside forward, which double-applied the sigmoid:
    training uses nn.BCEWithLogitsLoss (which applies sigmoid internally)
    and inference calls .sigmoid() on the output again.  Returning logits
    makes both of those call sites correct.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Contracting path: 5 scales, channels double at each pooling step.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Expanding path: transposed conv (k=stride=2) doubles the spatial
        # size; after concatenating the skip connection the channel count
        # doubles, hence DoubleConv halves it again.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv maps the 64 features to out_ch logits per pixel.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Encoder: keep c1..c4 for the skip connections.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder: upsample, concatenate the matching encoder feature map
        # along the channel dimension (dim=1), then convolve.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # Return logits; callers apply sigmoid (or use BCEWithLogitsLoss).
        return c10

if __name__ == '__main__':
    # Smoke test: build the model and print its layer summary.
    net = Unet(1, 1)
    print(net)

main.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

'''
(1)参考文献:UNet网络简单实现
https://blog.csdn.net/jiangpeng59/article/details/80189889

(2)FCN和unet的区别
https://zhuanlan.zhihu.com/p/118540575

'''
import torch
import argparse
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from  Unetmodel import Unet
from  setdata import LiverDataset
from  setdata import *


# Select GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-processing for the input images: ToTensor scales pixels to [0, 1],
# then per-channel normalization with mean/std 0.5 maps them to [-1, 1]
# (3 channels, so the input images are expected to be RGB).
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# The mask (label) only needs conversion to a tensor.
y_transforms = transforms.ToTensor()

def train_model(model, criterion, optimizer, dataload, num_epochs=5):
    """Train `model` for `num_epochs` epochs over `dataload`.

    Args:
        model:      the network to train (already on its target device).
        criterion:  loss function, e.g. nn.BCEWithLogitsLoss.
        optimizer:  optimizer over model.parameters().
        dataload:   DataLoader yielding (image, mask) batches.
        num_epochs: number of passes over the dataset.

    Returns the trained model.  The final state_dict is saved to
    'weights_<last_epoch>.pth'.
    """
    # Derive the device from the model itself instead of relying on a
    # module-level `device` global, so the function is self-contained.
    device = next(model.parameters()).device
    # Enable training mode (affects BatchNorm / Dropout layers).
    model.train()
    dt_size = len(dataload.dataset)
    steps_per_epoch = (dt_size - 1) // dataload.batch_size + 1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0.0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # Standard step: clear grads, forward, loss, backward, update.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() extracts the Python float from the 0-d loss tensor.
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, steps_per_epoch, loss.item()))
        # Guard against an empty dataloader (ZeroDivisionError otherwise).
        if step:
            print("epoch %d loss:%0.3f" % (epoch, epoch_loss / step))
    # Save only if at least one epoch ran; previously `epoch` was read
    # after the loop and would be undefined for num_epochs == 0.
    if num_epochs > 0:
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model

# Train the U-Net on the liver dataset.
def train(batch_size):
    """Build the model, loss, optimizer and data pipeline, then train.

    Args:
        batch_size: mini-batch size for the training DataLoader.
    """
    # 3 input channels (RGB image), 1 output channel (binary liver mask).
    model = Unet(3, 1).to(device)
    # BCEWithLogitsLoss applies the sigmoid internally, so the model
    # output is consumed as logits.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    # (Removed the redundant `batch_size = batch_size` self-assignment.)
    liver_dataset = LiverDataset("data/train", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)

# Visualize the model's predictions on the validation set.
def test(ckptpath):
    """Load weights from `ckptpath` and display predicted masks one by one."""
    model = Unet(3, 1)
    # The weights may have been saved on GPU; map them onto the CPU.
    model.load_state_dict(torch.load(ckptpath, map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    # One image per batch so each prediction can be shown individually.
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    # Switch BatchNorm/Dropout to inference behaviour.
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode: refresh the figure without blocking
    with torch.no_grad():
        for x, _ in dataloaders:
            prediction = model(x).sigmoid()
            # squeeze drops the singleton batch/channel dims for imshow.
            mask = torch.squeeze(prediction).numpy()
            plt.imshow(mask)
            plt.pause(0.01)
        plt.show()


if __name__ == '__main__':
    # NOTE(review): the original argparse-based CLI and thin wrapper
    # functions were left commented out; removed as dead code.  Training
    # and testing run back to back with hard-coded values instead.
    batchsize = 10
    train(batchsize)
    ckptpath = './dowoload/weights_19.pth'
    test(ckptpath)

setdata.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from torch.utils.data import Dataset
import PIL.Image as Image
import os

# Build the list of (image_path, mask_path) pairs under `root`.
def make_dataset(root):
    """Return [(image, mask), ...] path tuples for a dataset directory.

    Files are assumed to be named 000.png / 000_mask.png, 001.png /
    001_mask.png, ... so half of the directory entries are images and the
    other half are their masks.
    """
    pair_count = len(os.listdir(root)) // 2
    # "%03d" zero-pads the index to three digits, e.g. 12 -> "012".
    return [
        (os.path.join(root, "%03d.png" % i),
         os.path.join(root, "%03d_mask.png" % i))
        for i in range(pair_count)
    ]


class LiverDataset(Dataset):
    """Liver segmentation dataset yielding (image, mask) pairs."""

    def __init__(self, root, transform=None, target_transform=None):
        # transform: preprocessing for the input image;
        # target_transform: preprocessing for the mask image.
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index]
        image = Image.open(img_path)
        mask = Image.open(mask_path)
        # Apply each preprocessing pipeline only if one was supplied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)

【3】UNet模型参数展示

Unet(
  (conv1): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv5): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up6): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
  (conv6): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up7): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (conv7): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up8): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
  (conv8): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up9): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  (conv9): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (conv10): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
PS F:\PytorchTest\torchdeeplearnmodel\Unet> & G:/Anaconda3/envs/tensorflow/python.exe f:/PytorchTest/torchdeeplearnmodel/Unet/Unetmodel.py
Unet(
  (conv1): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv5): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up6): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
  (conv6): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up7): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (conv7): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up8): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
  (conv8): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (up9): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  (conv9): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (conv10): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)

【4】参考文献

(1)pytorch笔记:05)UNet网络简单实现

(2)UNet网络简单实现

(3)FCN和unet的区别

 

  • 9
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
好的,我可以为您提供一个简单的Unet网络的实现,以下是代码: ```python import torch import torch.nn as nn # 定义编码器 class Encoder(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, batch_norm=True): super().__init__() if batch_norm: self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) else: self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), ) def forward(self, x): x = self.conv(x) return x # 定义解码器 class Decoder(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, batch_norm=True): super().__init__() if batch_norm: self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) else: self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), ) def forward(self, x): x = self.conv(x) return x # 定义UNet网络 class UNet(nn.Module): def __init__(self, in_channels=3, out_channels=1): super().__init__() # 编码器部分 self.enc1 = Encoder(in_channels, 64, batch_norm=False) self.pool1 = nn.MaxPool2d(2, 2) self.enc2 = Encoder(64, 128) self.pool2 = nn.MaxPool2d(2, 2) self.enc3 = Encoder(128, 256) self.pool3 = nn.MaxPool2d(2, 2) self.enc4 = Encoder(256, 512) self.pool4 = nn.MaxPool2d(2, 2) # 解码器部分 self.dec4 = Decoder(512, 256) self.up4 = nn.ConvTranspose2d(256, 256, 2, stride=2) self.dec3 = Decoder(256 + 256, 128) 
self.up3 = nn.ConvTranspose2d(128, 128, 2, stride=2) self.dec2 = Decoder(128 + 128, 64) self.up2 = nn.ConvTranspose2d(64, 64, 2, stride=2) self.dec1 = Decoder(64 + 64, out_channels, batch_norm=False) def forward(self, x): # 编码器部分 enc1 = self.enc1(x) enc2 = self.enc2(self.pool1(enc1)) enc3 = self.enc3(self.pool2(enc2)) enc4 = self.enc4(self.pool3(enc3)) # 解码器部分 dec4 = self.dec4(self.pool4(enc4)) up4 = self.up4(dec4) dec3 = self.dec3(torch.cat([up4, enc3], dim=1)) up3 = self.up3(dec3) dec2 = self.dec2(torch.cat([up3, enc2], dim=1)) up2 = self.up2(dec2) dec1 = self.dec1(torch.cat([up2, enc1], dim=1)) return dec1 ``` 这个实现是一个简单的UNet网络,包括编码器和解码器部分,您可以根据您的需求调整网络的深度和宽度。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值