目前还没有进行训练，YOLO 验证无误。
基于此文章对 CCNet 进行修改：
class CCBottleneck(nn.Module):
    """Bottleneck block followed by recurrent criss-cross attention (CCNet).

    The input goes through ``cv1`` (1x1) and ``cv2`` (3x3). The result is then
    refined ``recurrence`` times by criss-cross attention: every position
    attends to all positions in its own row and its own column. Each pass adds
    ``gamma * attention`` to its input. ``gamma`` is initialised to 0, so the
    block initially behaves like a plain bottleneck. The same query/key/value
    convolutions are reused for every pass.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        shortcut: Add the block input to the output (only when ``c1 == c2``).
        g: Groups for the 3x3 convolution ``cv2``.
        e: Expansion ratio for the hidden channels between ``cv1`` and ``cv2``.
        recurrence: Number of criss-cross attention passes.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, recurrence=2):
        super().__init__()
        self.recurrence = recurrence
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        # The attention operates on cv2's output, which has c2 channels.
        # Sizing these layers with c1 made the block crash whenever c1 != c2.
        self.in_channels = c2
        # Reduced query/key width; clamp to >= 1 so narrow layers stay valid.
        self.channels = max(c2 // 8, 1)
        self.ConvQuery = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvKey = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvValue = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1)
        self.SoftMax = nn.Softmax(dim=3)
        self.gamma = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _self_mask(n, ref):
        """Return an (n, n) matrix with -inf on the diagonal and 0 elsewhere.

        It is added to the column (H) energies so a pixel does not attend to
        itself twice; it is already part of its own row (W) energies. The mask
        is created on ``ref``'s device and dtype, so it works on any GPU and in
        half or double precision.
        """
        diag = torch.full((n,), float("-inf"), device=ref.device, dtype=ref.dtype)
        return torch.diag(diag)

    def _criss_cross(self, x, mask):
        """Apply one criss-cross attention pass to ``x`` of shape [b, c, h, w].

        ``mask`` is the (h, h) matrix from ``_self_mask``. It is broadcast over
        the b*w column sequences.
        """
        b, _, h, w = x.size()
        query = self.ConvQuery(x)  # [b, c', h, w]
        key = self.ConvKey(x)  # [b, c', h, w]
        value = self.ConvValue(x)  # [b, c, h, w]

        # Column (H) direction: one length-h sequence per (batch, column).
        # [b, c', h, w] -> [b, w, c', h] -> [b*w, c', h] -> [b*w, h, c']
        query_H = query.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)
        key_H = key.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)  # [b*w, c', h]
        value_H = value.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)  # [b*w, c, h]
        # Row (W) direction: one length-w sequence per (batch, row).
        # [b, c', h, w] -> [b, h, c', w] -> [b*h, c', w] -> [b*h, w, c']
        query_W = query.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)
        key_W = key.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)  # [b*h, c', w]
        value_W = value.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)  # [b*h, c, w]

        # [b*w, h, c'] x [b*w, c', h] -> [b*w, h, h] -> [b, w, h, h] -> [b, h, w, h]
        energy_H = (torch.bmm(query_H, key_H) + mask).view(b, w, h, h).permute(0, 2, 1, 3)
        # [b*h, w, c'] x [b*h, c', w] -> [b*h, w, w] -> [b, h, w, w]
        energy_W = torch.bmm(query_W, key_W).view(b, h, w, w)
        # One softmax jointly over the h + w positions of the criss-cross path.
        attention = self.SoftMax(torch.cat([energy_H, energy_W], 3))  # [b, h, w, h+w]

        # [b, h, w, h] -> [b, w, h, h] -> [b*w, h, h]
        attention_H = attention[:, :, :, 0:h].permute(0, 2, 1, 3).contiguous().view(b * w, h, h)
        # [b, h, w, w] -> [b*h, w, w]
        attention_W = attention[:, :, :, h:h + w].contiguous().view(b * h, w, w)

        # [b*w, c, h] x [b*w, h, h] -> [b*w, c, h] -> [b, w, c, h] -> [b, c, h, w]
        out_H = torch.bmm(value_H, attention_H.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        # [b*h, c, w] x [b*h, w, w] -> [b*h, c, w] -> [b, h, c, w] -> [b, c, h, w]
        out_W = torch.bmm(value_W, attention_W.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + x

    def forward(self, x):
        """Run cv1 -> cv2, then ``recurrence`` criss-cross attention passes.

        Returns ``x + out`` when the shortcut is active, otherwise ``out``.
        """
        out = self.cv2(self.cv1(x))
        # The attention keeps the spatial size, so the mask is built only once.
        mask = self._self_mask(out.size(2), out)
        for _ in range(self.recurrence):
            out = self._criss_cross(out, mask)
        return x + out if self.add else out
class C3CC(C3):
    """C3 module whose inner bottlenecks are CCBottleneck (criss-cross attention).

    ``c1``, ``c2``, ``n``, ``shortcut``, ``g`` and ``e`` are passed to ``C3``
    unchanged. ``C3`` builds its usual inner stack, which is then replaced by
    ``n`` CCBottleneck blocks operating on ``int(c2 * e)`` hidden channels.

    Args:
        recurrence: Number of criss-cross attention passes in each
            CCBottleneck. The default of 2 matches the previous hard-wired value.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, recurrence=2):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        # Forward g (previously accepted but silently ignored) and recurrence
        # to every CCBottleneck in the stack.
        self.m = nn.Sequential(
            *(CCBottleneck(c_, c_, shortcut, g, recurrence=recurrence) for _ in range(n))
        )