U-Net搭建

Unet从0搭建

希望可以通过实操一遍Unet的搭建来熟悉深度学习代码的流程。

一、数据集的制作

个人认为搭建网络第一步是加载数据集,所以就从这里入手。

class chzdata(Dataset):  #这里就是搭建自己的数据集,首先要让chzdata继承Dataset
	#在这个里面要重写三个方法,分别是 初始化、长度、get(nn.Model才需要写super)
    def __init__(self, path, transform): #这里是给数据集传入 路径和一个预先的 transform
        self.path = path
        self.names = os.listdir(self.path)
        self.transform = transform

    def __len__(self):   #必须要写数据集的长度方法

        return len(self.names) 

    def __getitem__(self, index):  #这个是按索引取数据的方法
        img_path = os.path.join(self.path, self.names[index]) #获取了路径加文件名字,就指定到了某一个具体的文件
        img = Image.open(img_path) #首先得用PIl.Image打开图片
        img = self.transform(img) #然后把图片按预处理转换
        return img  #返回这个处理好的图片

因为图像的尺寸各不相同,所有大部分时候需要先对图像的尺寸整一下,比如把一个长方形的图片整成以长边为边的正方形,多余的内容补上黑色。
处理的代码:

def long_pro_size(image,size=(256,256)):
    temp=max(image.size) #获取长边
    mask=Image.new('RGB',(temp,temp),(0,0,0)) # 新建一个幕布 大小就是长边的正方形
    mask.paste(image,(0,0)) #把原来的图片粘贴到幕布上
    mask=mask.resize(size)  #将图片先扩展成正方形然后再resize为 256*256

    return mask

if __name__ == '__main__':
    img1=Image.open('./ants/0013035.jpg')
    img2=long_pro_size(img1)
    img2.save('./images/0.jpg')

原来的图片:W大于H,768512在这里插入图片描述
处理之后:
正方形,256
256在这里插入图片描述

二、搭建网络

搭建好数据加载之后就是搭建网罗,一般是按照网络的结构图搭建,比如U-net的结构图如下:
在这里插入图片描述
这里先看卷积层:conv_block
注意网络中的初始化的参数是要在建立这个网络的时候传的参数,而要传播的时候要传forward的参数,例如
下面搭好的 Conv_block ,我在令一个网络的时候要这样用:

model=Conv_block(3,16)

而当我令好了这个网络,我要使用它的时候,就是要让它前向传播了。就要这样再传参数:

img1=()#假设这里已经有了一个图片的tensor [1,3,256,256]
out=model(img1)

搭建完整U-net网络的代码:

class Conv_block(nn.Module):
    def __init__(self,in_channel,out_channel): #给定输入通道数和输出通道数
        super(Conv_block, self).__init__()  #搭建网络就要重写 super
        self.layer=nn.Sequential(  
            nn.Conv2d(in_channel,out_channel,kernel_size=3,stride=1,padding=1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.3),
            nn.LeakyReLU(),

            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.3),
            nn.LeakyReLU(),
        )

    def forward(self,x):
        return self.layer(x)

class DownSample(nn.Module):  #  下采样
    def __init__(self,channel):
        super(DownSample, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self,x):
        x=self.layer(x)
        return x

class UpSample(nn.Module):   #  上采样
    def __init__(self,channel):
        super(UpSample, self).__init__()
        self.layer=nn.Conv2d(channel,channel//2,1,1)  #1*1的卷积核不会特征提取,只会起到降低通道数的作用
    def forward(self,x,feature_map):
        up=F.interpolate(x,scale_factor=2,mode='nearest')  #插值
        out=self.layer(up)
        return torch.cat((out,feature_map),dim=1)

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1=Conv_block(3,64)
        self.down1=DownSample(64)
        self.c2=Conv_block(64,128)
        self.down2=DownSample(128)
        self.c3=Conv_block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_block(1024,512)
        self.u2=UpSample(512)
        self.c7=Conv_block(512,256)
        self.u3=UpSample(256)
        self.c8=Conv_block(256,128)
        self.u4=UpSample(128)
        self.c9=Conv_block(128,64)
        self.out=nn.Conv2d(64,3,3,1,1)
        self.th=nn.Sigmoid()

    def forward(self,x):  #(层级与上图中的U-net结构图一致)
        r1=self.c1(x)
        r2=self.down1(r1)
        r3=self.c2(r2)
        r4=self.down2(r3)
        r5=self.c3(r4)
        r6=self.d3(r5)
        r7=self.c4(r6)
        r8=self.d4(r7)
        r9=self.c5(r8)
        r10=self.u1(r9,r7)
        r11=self.c6(r10)
        r12=self.u2(r11,r5)
        r13=self.c7(r12)
        r14=self.u3(r13,r3)
        r15=self.c8(r14)
        r16=self.u4(r15,r1)
        r17=self.c9(r16)
        r18=self.out(r17)
        r19=self.th(r18)

        return r19  
  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值