从零手写Unet代码及数据

在这里插入图片描述
数据集资源已上传,不需要积分(可能还在审核)。
下载链接

一.网络模型(model.py)

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(DoubleConv, self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,out_channels,3,1,1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        return self.conv(x)



class Unet(nn.Module):
    def __init__(self,in_channels=3,out_channels=1,features=[64,128,256,512],):
        super(Unet,self).__init__()
        self.ups = nn.ModuleList()
        self.downs=nn.ModuleList()
        self.pool=nn.MaxPool2d(kernel_size=2,stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels,feature))
            in_channels=feature


        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2,)
            )
            self.ups.append(DoubleConv(feature*2,feature))


        self.bottleneck=DoubleConv(features[-1],features[-1]*2)
        self.final_conv=nn.Conv2d(features[0],out_channels,kernel_size=1)


    def forward(self,x):
        skip_connections=[]

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x=self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0,len(self.ups),2):
            x=self.ups[idx](x)
            skip_connection=skip_connections[idx//2]</
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值