UNet网络学习记录

如下图所示,整个网络结构包括两部分,编码结构和解码结构,编码结构是对特征进行提取,解码结构是对特征进行还原;如下右图所示,这个步骤包括数据集的加载,网络的搭建,训练网络(调用网络)
在这里插入图片描述
网络的结构解析:
输入图片,进行两次卷积,进行一次下采样,重复4次
最下面1层开始,经过两次卷积,再上采样
从图中可以看出,上采样后,通道数发生了改变,这里使用了1x1的卷积
在每次的上采样之后,都与对应的下采样部分进行cat,进行拼接
在整个网络的搭建过程中,是通过模块化编程来进行实现,
比如分为:模块1定义卷积,其中包括两次卷积;模块2定义下采样;模块3定义上采样
data.py的代码如下所示:主要用于对训练集测试集的图像进行处理等

import os

from torch.utils.data import Dataset  #导入所需要的包(utils下的dataset)
from utils1 import *
from torchvision import transforms #这里的归一化在pytorch中封装成了包
#所有的图片都需要处理,归一化处理
transform=transforms.Compose([
    transforms.ToTensor
])   #这里仅仅使用了totensor

class MyDataset(Dataset): #定义一个自己数据集的类,继承自Dataset
    def __init__(self,path): #初始化,并传入地址
        self.path=path #把实际地址传入到变量上,下一步是获取文件所有的文件名
        self.name=os.listdir(os.path.join(path,'seg'))#这部分是使用os下的path指令拿到path路径并进行拼接(path/seg) #w外部的这个os是拿到每一个文件夹下的图片,因为文件放在多个文件夹

    def __len__(self): #返回文件名的数量,那也就是数据集的数量
        return len(self.name)

    def __getitem__(self, index): #这部分内容是数据集的制作,传入的是数据集的下标索引
        segment_name=self.name[index] #xx.png 这里是获取名字的下标
        segment_path=os.path.join(self.path,'seg',segment_name) #这里是拿到每个分割后图片的路径,数据集路径+文件夹路径+每张图片的名字
        image_path=os.path.join(self.path,'JPG',segment_name.replace('png','jpg')) #拿到原图的路径,并把原图png转为jpg这种格式,保持格式一致
        #输出的图片大小一般是不一致的,需要对图片进行缩放,这部分代码封装成了模块utils

        segment_image=keep_image_size_open(segment_path)  #这里是把分割图片统一大小
        image=keep_image_size_open(image_path) #把真实图片统一大小

        return transforms(image),transform(segment_image) #把分割图片和原图都进行归一化,转换为0-1的数,这里返回的是对应的一组图片

if __name__ == '__main__': #上面是定义函数,这里是主函数,程序的入口,如果直接运行这个代码,就会执行,是这些代码的入口
    data=MyDataset('给个地址')
    print(data[0][0].shape)  #这里是打印第0张原图的形状,设置的是3x256X256
        #注意:在计算机中data返回的数据类型是元组(不可以增删改)
        #Data[0][0].shape  打印出来的结果是[3,256,256]

注意:在计算机中data返回的数据类型是元组(不可以增删改)
Data[0][0].shape 打印出来的结果是[3,256,256].第一个0表示第0张,第二个0表示原图,第二个0更改为1就是分割的图片形状;第一个[]表示索引,也就是图片的序号,第二个[]表示pair中的第一个,即返回值的第一个
**utils.py部分的代码如下图所示:**这部分代码一般是定义用于图像处理的工具

#注意,utils表示的是一个工具类
#在这里表示的是对图片进行处理
from PIL import Image

def keep_image_size_open(path,size=(256,256)): #这部分统一图像大小,需要的参数是图像路径,缩放图片大小为256x256
    img=Image.open(path) #使用image打开图片并且传入到img变量
    temp=max(img.size) #读取图片的最长边
    mask=Image.new('RGB',(temp,temp),(0,0,0)) #做一个掩码,就是读取图片的最长边,做一个黑色的正方形
    mask.paste(img,(0,0)) #把数据的图片粘贴上去,从(0,0)开始
    mask=mask.resize(size) #把掩码图片缩放为想要的尺寸(256x256)
    return mask #返回mask

net.py代码如下图所示:这部分代码是网络的搭建过程

import torch
from torch import nn
from  torch.nn import functional as F

class Conv_Block(nn.Module):  #定义卷积模块,继承自己module
    def __int__(self,in_channel,out_channel): #初始化,需要传入输入和输出
        super(Conv_Block,self).__init__() #初始化父类方法
        self.layer=nn.Sequential( #定义层结构,使用Sequ实现连续性,即经过的每一步卷积过程
            nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),nn.Conv2d(out_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),

        )

    def forward(self,x): #定义前向传播,传入x
        return self.layer(x)

class DownSample(nn.Module): #定义下采样模块
    def __int__(self,channel):
        super(DownSample,self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

class UpSample(nn.Module): #定义上采样模块
    def __int__(self,channel):
        super(UpSample,self).__int__()
        self.layer=nn.Conv2d(channel.channel//2,1,1)

    def forward(self,x,feature_map): #外加拼接部分
        up=F.interpolate(x,scale_factor=2,mode='nearest')
        out=self.layer(up)
        return torch.cat((out,feature_map),dim=1)

#上述的过程都是定义相关的模块,下面开始搭建网络,就是把模块给组合起来
class UNet(nn.Module):
    def __int__(self):
        super(UNet,self).__init__()
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64,128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)

        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out=nn.Conv2d(64,3,3,1,1)
        self.Th=nn.Sigmoid()

    def forward(self,x):
        R1=self.c1(x)
        R2=self.c2(self.d1(R1))
        R3 = self.c3(self.d1(R2))
        R4 = self.c4(self.d1(R3))
        R5 = self.c5(self.d1(R4))

        o1=self.c6(self.u1(R5,R4))
        o2 = self.c7(self.u2(o1, R3))
        o3 = self.c8(self.u3(o2, R2))
        o4 = self.c9(self.u4(o3, R1))

        return self.Th(self.out(o4))

if __name__ == '__main__':
    x=torch.randn(2,3,256,256)
    net=UNet()
    print(net(x).shape)

**train.py的代码如下图所示:**这部分代码主要是用来调用数据集,网络结构,定义训练优化器,损失函数,得到的结果以及可视化等等

from torch import nn,optim
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
import os
from torchvision.utils import save_image
from utils1 import *

device=torch.device('cuda'if torch.cuda.is_available()else'cup') #指定数据集
weight_path='params/unet.pth' #设置保存权重路径
data_path=r'数据集地址' #设置数据集路径
save_path='train_image'

if __name__ == '__main__':
    data_loader=DataLoader(MyDataset(data_path),batch_size=4,shuffle=True) #调用原来写的data处理方式模块
    net=UNet().to(device) #把网络放在设备上
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))#如果权重存在打印加载成功
        print('sucessful load weight!')
    else:
        print('not sucessful load weight')

    opt=optim.Adam(net.parameters()) #定义优化器,选择ADAM,把网络的参数放进去
    loss_fun=nn.BCELoss() #定义损失计算方式

    epoch=1 #定义训练的轮数
    while True: #一直训练
        for i, (image, segment_image) in enumerate(data_loader):
            image,segment_image=image.to(device),segment_image.to(device) #把原图和分割图都放在设备上面

            out_image=net(image) #经过网络输出的图片为out_image
            train_loss=loss_fun(out_image,segment_image) #传入输出图和标签图,计算损失
            opt.zero_grad()#更新梯度
            train_loss.backward()#把损失反向传播,指导网络往损失小的方向下降
            opt.step()#使用优化器

            if i%5==0:#打印训练过程损失的变化
                print(f'训练的轮数:train_loss====>>{train_loss.item}')

            if i%50==0:#保存训练过程的参数
                torch.save(net.state_dict(),weight_path)#就是刚刚设置的保存路径

            #为了看训练过程中图的变化情况,对比原图+标签图+输出图,
            _image=image[0] #取第一张图作为,_image
            _segment_image=segment_image[0]
            _out_image=out_image[0]

            img=torch.stack([_image,_segment_image,_out_image],dim=0) #把得到的三张图进行拼接
            save_image(img,f'{save_path}/i.png’)

        epoch=epoch+1 #循环轮数
  • 8
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值