概述
attention-unet主要的贡献就是提出了attention gate,它即插即用,可以直接集成到unet模型当中,作用在于抑制输入图像中的不相关区域,同时突出特定局部区域的显著特征,并且它用soft-attention 代替hard-attention,所以attention权重可以由网络学习,并且不需要额外的label,只增加少量的计算量。
细节
结构
核心还是unet的结构,但是在做skip-connection的时候,中间加了一个attention gate,经过这个ag之后,再进行concat操作。因为encoder中的细粒度信息相对多一点,但是很多是不需要的冗余的,ag相当于是对encoder的当前层进行了一个过滤,抑制图像中的无关信息,突出局部的重要特征。
attention gate
两个输入分别是encoder的当前层
x
l
x^l
xl和decoder的下一层
g
g
g,他们经过1x1的卷积(将通道数变为一致之后),再做逐元素的相加,然后经过relu,1x1的卷积(将通道数降为1)和sigmoid得到注意力系数,然后再经过一个resample模块将尺寸还原回来,最后就可以使用注意力系数对特征图进行加权了。
注:这里是3D的,2D理解的话,直接去掉最后一个维度就好了。
一些解释
:
为什么要两个输入做加法而不是直接根据encoder的当前层得到注意力系数呢?
可能是因为,首先处理完成之后的两张相同尺寸和通道数的特征图,提取的特征是不同的。那么这么操作能够使相同的感兴趣区域的信号加强,同时各自不同的区域也能作为辅助,两份加起来辅助信息也会更多。或者说是对核心信息的进一步强调,同时又不忽视那些细节信息。
为什么需要resample呢?
因为
x
l
与
g
x^l与g
xl与g的尺寸是不相同的,显然
g
g
g的尺寸是
x
l
x^l
xl的一半,他们不可能进行逐元素的相加,所以需要将两个尺寸变得一致,要么大的下采样要么小的上采样,实验出来是大的下采样效果好。但是这个操作之后得到的就是注意力系数了,要和
x
l
x^l
xl做加权肯定要尺寸相同的,所以还得重新上采样。
attention
Attention函数的本质可以被描述为一个查询(query)到一系列(键key-值value)对的映射
在计算attention时主要分为三步:
- 第一步是将query和每个key进行相似度计算得到权重,常用的相似度函数有点积,拼接,感知机等;
- 第二步一般是使用一个softmax函数对这些权重进行归一化;
- 最后将权重和相应的键值value进行加权求和得到最后的attention。
hard-attention
:一次选择一个图像的一个区域作为注意力,设成1,其他设为0。他是不能微分的,无法进行标准的反向传播,因此需要蒙特卡洛采样来计算各个反向传播阶段的精度。 考虑到精度取决于采样的完成程度,因此需要其 他技术(例如强化学习)。
soft-attention
:加权图像的每个像素。 高相关性区域乘以较大的权重,而低相关性区域标记为较小的权重。权重范围是(0-1)。他是可微的,可以正常进行反向传播。
简单实现
import paddle
import paddle.nn as nn
# 两次卷积操作
# 卷积计算公式:
# 输出大小 = (输入大小 − Filter + 2Padding )/Stride+1
class VGGBlock(nn.Layer):
def __init__(self,in_channels,out_channels):
super(VGGBlock, self).__init__()
self.layer=nn.Sequential(
nn.Conv2D(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU(),
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU()
)
def forward(self,x):
return self.layer(x)
# 先让encoder当前层通过attention_gate
# 然后再将decoder当前层上采样并且和encoder当前层 做concat
# 反卷积(转置卷积)计算公式:
# 输出大小 = (输入大小 − 1) * Stride + Filter - 2 * Padding
# 当前这种设置使得输入输出尺寸相同
class Up(nn.Layer):
def __init__(self,in_channels,out_channels):
super(Up, self).__init__()
self.layer=nn.Sequential(
nn.Conv2DTranspose(in_channels, out_channels, 4, 2, 1)
)
def forward(self,x1,x2,attention_gate):
x1=self.layer(x1)
x2=attention_gate(x1,x2)
# 因为tensor是ncwh的 我们需要在c维度上concat 所以axis是1
return paddle.concat([x2,x1],axis=1)
# 软注意力机制 相当于是一个筛子 对encoder的当前层进行筛选
class AttentionGate(nn.Layer):
# 因为我们是将encoder的当前层 以及decoder的下一层上采样之后的结果送入Attention_Block的
# 所以他们的尺寸以及通道数适相同的
def __init__(self,in_channels,out_channels):
super(AttentionGate, self).__init__()
self.w_g=nn.Sequential(
nn.Conv2D(in_channels,out_channels,1,1,0),
nn.BatchNorm2D(out_channels)
)
self.w_x=nn.Sequential(
nn.Conv2D(in_channels, out_channels, 1, 1, 0),
nn.BatchNorm2D(out_channels)
)
self.relu=nn.LeakyReLU()
self.psi=nn.Sequential(
nn.Conv2D(out_channels, 1, 1, 1, 0),
nn.BatchNorm2D(1),
nn.Sigmoid()
)
def forward(self, g,x):
g1=self.w_g(g)
x1=self.w_x(x)
psi=self.relu(g1+x1)
psi=self.psi(psi)
return x*psi
class AttentionUNet(nn.Layer):
def __init__(self,num_classes=2):
super(AttentionUNet, self).__init__()
filters = [64, 128, 256, 512, 1024]
self.pool = nn.MaxPool2D(2)
## -------------encoder-------------
self.encoder_1 = VGGBlock(3, filters[0])
self.encoder_2 = VGGBlock(filters[0], filters[1])
self.encoder_3 = VGGBlock(filters[1], filters[2])
self.encoder_4 = VGGBlock(filters[2], filters[3])
self.encoder_5 = VGGBlock(filters[3], filters[4])
## -------------decoder-------------
self.up_4 = Up(filters[4], filters[3])
self.up_3 = Up(filters[3], filters[2])
self.up_2 = Up(filters[2], filters[1])
self.up_1 = Up(filters[1], filters[0])
self.decoder_4 = VGGBlock(filters[4], filters[3])
self.decoder_3 = VGGBlock(filters[3], filters[2])
self.decoder_2 = VGGBlock(filters[2], filters[1])
self.decoder_1 = VGGBlock(filters[1], filters[0])
self.attention_gate4=AttentionGate(512,256)
self.attention_gate3=AttentionGate(256,128)
self.attention_gate2=AttentionGate(128,64)
self.attention_gate1=AttentionGate(64,32)
self.final = nn.Sequential(
nn.Conv2D(64,num_classes,3,1,1),
)
def forward(self,x):
## -------------encoder-------------
encoder_1 = self.encoder_1(x)
encoder_2 = self.encoder_2(self.pool(encoder_1))
encoder_3 = self.encoder_3(self.pool(encoder_2))
encoder_4 = self.encoder_4(self.pool(encoder_3))
encoder_5 = self.encoder_5(self.pool(encoder_4))
## -------------decoder-------------
decoder_4 = self.up_4(encoder_5, encoder_4,self.attention_gate4)
decoder_4 = self.decoder_4(decoder_4)
decoder_3 = self.up_3(decoder_4, encoder_3,self.attention_gate3)
decoder_3 = self.decoder_3(decoder_3)
decoder_2 = self.up_2(decoder_3, encoder_2,self.attention_gate2)
decoder_2 = self.decoder_2(decoder_2)
decoder_1 = self.up_1(decoder_2, encoder_1,self.attention_gate1)
decoder_1 = self.decoder_1(decoder_1)
output = self.final(decoder_1)
return output
if __name__ == '__main__':
# x=paddle.randn(shape=[2,3,256,256])
unet=AttentionUNet()
# print(net(x).shape)
paddle.summary(unet, (1,3,256,256))