空间注意力 PAM_Module:
- input A:(C,H,W),先分别通过三个卷积层得到B、C、D三个特征图,B reshape+transpose为(HxW,C),C reshape为(C,HxW),D reshape为(C,HxW)
- 然后B与C相乘得到(HxW,HxW),再经过softmax得到S
- 然后D与S做乘法,得到(C,HxW),再乘以尺度系数α,再reshape为(C,H,W)后与input相加得到输出
- 初始化α系数为0,注意这是可学习参数,逐渐学习得到更大的权重
举例:
code:
class PAM_Module(nn.Module):
    """Position (spatial) attention module (DANet-style).

    Each spatial position's output is a weighted sum of the value features
    at every position, where the weights come from a softmax over the
    query/key affinity matrix. The attended result is scaled by a learnable
    coefficient ``gamma`` (initialized to 0, so the module starts as an
    identity mapping) and added back to the input.

    Input/Output: a 5-D tensor whose trailing dimension is squeezed on
    entry and restored on exit — assumes shape (B, C, H, W, 1); TODO
    confirm against callers.
    """

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim
        # Query/key are projected down to in_dim // 8 channels to reduce
        # the cost of the (H*W) x (H*W) affinity computation.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable residual scale; zero-init makes the block start as identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply spatial self-attention.

        Args:
            x: tensor of shape (B, C, H, W, 1).

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x.squeeze(-1)  # (B, C, H, W, 1) -> (B, C, H, W)
        m_batchsize, C, height, width = x.size()
        # query: (B, C//8, H, W) -> (B, C//8, H*W) -> (B, H*W, C//8)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        # key: (B, C//8, H*W)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        # energy: (B, H*W, H*W) — affinity of each position to every other position
        energy = torch.bmm(proj_query, proj_key)
        # Softmax over the last dim: each row is a distribution over positions.
        attention = self.softmax(energy)
        # value: (B, C, H*W)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
        # (B, C, H*W) @ (B, H*W, H*W) -> (B, C, H*W): attention-weighted values
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        # Scaled residual connection; restore the trailing singleton dim.
        out = (self.gamma * out + x).unsqueeze(-1)
        return out
通道注意力 CAM_Module:
- 对A分别做reshape得到B(C,HxW)和C(HxW,C),B与C做乘法得到(C,C),再经过softmax得到X
- 然后X与A的reshape结果(C,HxW)做乘法得到(C,HxW),再乘以尺度系数β 最后reshape为(C,H,W)
- 最后与input相加
- β初始化为0,并逐渐学习得到更大的权重,是可学习参数
举例:
code:
class CAM_Module(nn.Module):
    """Channel attention module (DANet-style).

    Computes a (C, C) channel-affinity matrix directly from the input
    (no projection convs), softmaxes it, and uses it to re-weight the
    channels. The attended result is scaled by a learnable coefficient
    ``gamma`` (zero-initialized, so the module starts as an identity
    mapping) and added back to the input.

    Input/Output: a 5-D tensor (B, C, H, W, channel); the three trailing
    dims are flattened together for the attention math and restored on
    exit.
    """

    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim
        # Learnable residual scale; zero-init makes the block start as identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel self-attention.

        Args:
            x: tensor of shape (B, C, H, W, channel).

        Returns:
            Tensor of the same shape as ``x``.
        """
        m_batchsize, C, height, width, channel = x.size()
        proj_query = x.view(m_batchsize, C, -1)                  # (B, C, N)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)   # (B, N, C)
        energy = torch.bmm(proj_query, proj_key)                 # (B, C, C)
        # Subtract energy from each row's max before the softmax
        # (max - energy): numerically stable and inverts the affinities,
        # as in the original DANet implementation.
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)                     # (B, C, C)
        proj_value = x.view(m_batchsize, C, -1)                  # (B, C, N)
        out = torch.bmm(attention, proj_value)                   # (B, C, N)
        out = out.view(m_batchsize, C, height, width, channel)
        # Scaled residual connection.
        out = self.gamma * out + x
        return out