Unet从0搭建
希望可以通过实操一遍Unet的搭建来熟悉深度学习代码的流程。
一、数据集的制作
个人认为搭建网络第一步是加载数据集,所以就从这里入手。
class chzdata(Dataset): #这里就是搭建自己的数据集,首先要让chzdata继承Dataset
#在这个里面要重写三个方法,分别是 初始化、长度、get(nn.Model才需要写super)
def __init__(self, path, transform): #这里是给数据集传入 路径和一个预先的 transform
self.path = path
self.names = os.listdir(self.path)
self.transform = transform
def __len__(self): #必须要写数据集的长度方法
return len(self.names)
def __getitem__(self, index): #这个是按索引取数据的方法
img_path = os.path.join(self.path, self.names[index]) #获取了路径加文件名字,就指定到了某一个具体的文件
img = Image.open(img_path) #首先得用PIl.Image打开图片
img = self.transform(img) #然后把图片按预处理转换
return img #返回这个处理好的图片
因为图像的尺寸各不相同,所有大部分时候需要先对图像的尺寸整一下,比如把一个长方形的图片整成以长边为边的正方形,多余的内容补上黑色。
处理的代码:
def long_pro_size(image,size=(256,256)):
temp=max(image.size) #获取长边
mask=Image.new('RGB',(temp,temp),(0,0,0)) # 新建一个幕布 大小就是长边的正方形
mask.paste(image,(0,0)) #把原来的图片粘贴到幕布上
mask=mask.resize(size) #将图片先扩展成正方形然后再resize为 256*256
return mask
if __name__ == '__main__':
img1=Image.open('./ants/0013035.jpg')
img2=long_pro_size(img1)
img2.save('./images/0.jpg')
原来的图片:W大于H,768512
处理之后:
正方形,256256
二、搭建网络
搭建好数据加载之后就是搭建网罗,一般是按照网络的结构图搭建,比如U-net的结构图如下:
这里先看卷积层:conv_block
注意网络中的初始化的参数是要在建立这个网络的时候传的参数,而要传播的时候要传forward的参数,例如
下面搭好的 Conv_block ,我在令一个网络的时候要这样用:
model=Conv_block(3,16)
而当我令好了这个网络,我要使用它的时候,就是要让它前向传播了。就要这样再传参数:
img1=()#假设这里已经有了一个图片的tensor [1,3,256,256]
out=model(img1)
搭建完整U-net网络的代码:
class Conv_block(nn.Module):
def __init__(self,in_channel,out_channel): #给定输入通道数和输出通道数
super(Conv_block, self).__init__() #搭建网络就要重写 super
self.layer=nn.Sequential(
nn.Conv2d(in_channel,out_channel,kernel_size=3,stride=1,padding=1,padding_mode='reflect',bias=False),
nn.BatchNorm2d(out_channel),
nn.Dropout(0.3),
nn.LeakyReLU(),
nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),
nn.BatchNorm2d(out_channel),
nn.Dropout(0.3),
nn.LeakyReLU(),
)
def forward(self,x):
return self.layer(x)
class DownSample(nn.Module): # 下采样
def __init__(self,channel):
super(DownSample, self).__init__()
self.layer=nn.Sequential(
nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),
nn.BatchNorm2d(channel),
nn.LeakyReLU()
)
def forward(self,x):
x=self.layer(x)
return x
class UpSample(nn.Module): # 上采样
def __init__(self,channel):
super(UpSample, self).__init__()
self.layer=nn.Conv2d(channel,channel//2,1,1) #1*1的卷积核不会特征提取,只会起到降低通道数的作用
def forward(self,x,feature_map):
up=F.interpolate(x,scale_factor=2,mode='nearest') #插值
out=self.layer(up)
return torch.cat((out,feature_map),dim=1)
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.c1=Conv_block(3,64)
self.down1=DownSample(64)
self.c2=Conv_block(64,128)
self.down2=DownSample(128)
self.c3=Conv_block(128,256)
self.d3=DownSample(256)
self.c4=Conv_block(256,512)
self.d4=DownSample(512)
self.c5=Conv_block(512,1024)
self.u1=UpSample(1024)
self.c6=Conv_block(1024,512)
self.u2=UpSample(512)
self.c7=Conv_block(512,256)
self.u3=UpSample(256)
self.c8=Conv_block(256,128)
self.u4=UpSample(128)
self.c9=Conv_block(128,64)
self.out=nn.Conv2d(64,3,3,1,1)
self.th=nn.Sigmoid()
def forward(self,x): #(层级与上图中的U-net结构图一致)
r1=self.c1(x)
r2=self.down1(r1)
r3=self.c2(r2)
r4=self.down2(r3)
r5=self.c3(r4)
r6=self.d3(r5)
r7=self.c4(r6)
r8=self.d4(r7)
r9=self.c5(r8)
r10=self.u1(r9,r7)
r11=self.c6(r10)
r12=self.u2(r11,r5)
r13=self.c7(r12)
r14=self.u3(r13,r3)
r15=self.c8(r14)
r16=self.u4(r15,r1)
r17=self.c9(r16)
r18=self.out(r17)
r19=self.th(r18)
return r19