Open the common.py file in the models folder and add the following module:
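(Once the full PSA_Channel class below has been pasted into common.py, it can be smoke-tested on its own. This is a minimal sketch, not part of the original steps: it assumes you run it from the YOLOv5 repository root so that models.common is importable, and the input size 1x64x20x20 is an arbitrary test value.)

import torch
from models.common import PSA_Channel

x = torch.randn(1, 64, 20, 20)   # dummy feature map: batch=1, 64 channels, 20x20
m = PSA_Channel(64)              # channel-attention module built for 64 input channels
y = m(x)
print(y.shape)                   # the attention output keeps the input shape: [1, 64, 20, 20]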
##### PSA attention mechanism #####
class PSA_Channel(nn.Module):
    def __init__(self, c1) -> None:
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1)   # 1x1 conv: project input down to c1/2 channels
        self.cv2 = nn.Conv2d(c1, 1, 1)    # 1x1 conv: project input down to a single channel
        self.cv3 = nn.Conv2d(c_, c1, 1)   # 1x1 conv: restore the original channel count
        self.reshape1 = nn.Flatten(start_dim=-2, end_dim=-1)  # flatten only the spatial dims (H, W)
        self.reshape2 = nn.Flatten()      # flatten everything after the batch dim
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(1)      # softmax over the flattened spatial positions
        self.layernorm = nn.LayerNorm([c1, 1, 1])

    def forward(self, x):  # shape(batch, channel, height, width)
        x1 = self.reshape1(self.cv1(x))  # shape(batch, channel/2, height*width)
        x2 = se