代码:https://github.com/ishin-pie/east-mobilenet
网络结构采用 MobileNetV3-Large 作为骨干网络,
并取其中几层的输出作为特征:
class MobileNetV3(nn.Module):
    """Backbone that chains the MobileNetV3-Large extractor with the U-shaped merge.

    The extractor is expected to return four feature maps; they are forwarded
    unchanged, in the same order, to ``Ushape.forward``.
    """

    def __init__(self):
        super(MobileNetV3, self).__init__()
        # Feature extractor returning (deepest, skip1, skip2, skip3).
        self.extractor = MobileNetV3_Large()
        # Decoder that upsamples and fuses the four feature maps.
        self.merge = Ushape()

    def forward(self, x):
        """Extract multi-scale features from ``x`` and return the merged map."""
        x, x1, x2, x3 = self.extractor(x)
        return self.merge(x, x1, x2, x3)
简单来说,就是取出 MobileNetV3 中若干层的特征:先对深层特征做上采样,再与前一层(更浅层)的特征拼接融合。
def forward(self, x):
    """Run the extractor stages and return the final map plus three skip features.

    Returns:
        tuple: ``(out, out1, out2, out3)`` where ``out`` is the output of the
        last ``linear`` stage and ``out1``, ``out2``, ``out3`` are the outputs
        of ``bneck1``, ``bneck2`` and ``bneck3`` respectively.

    The shapes quoted below are the ones recorded by the original author's
    debug prints (batch size 2); they are examples, not constraints enforced
    by this method.
    """
    stem = self.hs1(self.bn1(self.conv1(x)))
    # recorded shape: [2, 16, 256, 256]
    out1 = self.bneck1(stem)
    # recorded shape: [2, 72, 128, 128] -- returned as a skip feature
    out2 = self.bneck2(out1)
    # recorded shape: [2, 240, 64, 64] -- returned as a skip feature
    out3 = self.bneck3(out2)
    # recorded shape: [2, 672, 32, 32] -- returned as a skip feature
    out = self.bneck4(out3)
    # recorded shape: [2, 160, 16, 16]
    out = self.hs2(self.bn2(self.conv2(out)))
    # recorded shape: [2, 960, 16, 16]
    out = self.linear(self.bn3(self.conv3(out)))
    # recorded shape: [2, 640, 16, 16] -- deepest feature, returned first
    return out, out1, out2, out3
如上图所示
class Ushape(nn.Module):
    """U-shaped decoder that fuses a deep feature map with three skip features.

    Each stage upsamples the running feature map to the spatial size of the
    next skip feature, concatenates the two along the channel axis, and
    applies a 1x1 conv followed by a 3x3 conv (each with BatchNorm + ReLU).
    A final 3x3 conv refines the 64-channel result.

    Expected channel counts (fixed by the conv definitions below):
    ``x`` + ``x3`` = 1312, stage-1 output (320) + ``x2`` = 560,
    stage-2 output (160) + ``x1`` = 232.
    """

    def __init__(self):
        super(Ushape, self).__init__()
        # Stage 1: fuse the upsampled deepest feature with x3.
        self.conv1 = nn.Conv2d(1312, 320, 1)
        self.bn1 = nn.BatchNorm2d(320)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(320, 320, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(320)
        self.relu2 = nn.ReLU()
        # Stage 2: fuse with x2.
        self.conv3 = nn.Conv2d(560, 160, 1)
        self.bn3 = nn.BatchNorm2d(160)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(160, 160, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(160)
        self.relu4 = nn.ReLU()
        # Stage 3: fuse with x1.
        self.conv5 = nn.Conv2d(232, 64, 1)
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        self.relu6 = nn.ReLU()
        # Final refinement conv.
        self.conv7 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(64)
        self.relu7 = nn.ReLU()

        # Kaiming init for convs, identity-like init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _upsample_to(feature, reference):
        """Bilinearly resize ``feature`` to the spatial size of ``reference``.

        Resizing to the skip feature's size (instead of a fixed 2x factor)
        keeps the following ``torch.cat`` valid when the input resolution is
        not an exact multiple of the network stride. With
        ``align_corners=True`` the result is the same as ``scale_factor=2``
        whenever ``reference`` is exactly twice as large.
        """
        return F.interpolate(feature, size=reference.shape[-2:],
                             mode='bilinear', align_corners=True)

    def forward(self, x, x1, x2, x3):
        """Merge the deepest feature ``x`` with skips ``x3``, ``x2``, ``x1``.

        Args:
            x: deepest feature map (640 channels).
            x1: shallowest skip feature (72 channels).
            x2: middle skip feature (240 channels).
            x3: deepest skip feature (672 channels).

        Returns:
            A 64-channel feature map with the spatial size of ``x1``.

        Shapes in the comments are the example values noted by the original
        author for a batch of one.
        """
        # e.g. x: 1 x 640 x 16 x 16
        y = self._upsample_to(x, x3)           # 1 x 640 x 32 x 32
        y = torch.cat((y, x3), 1)              # 1 x 1312 x 32 x 32
        y = self.relu1(self.bn1(self.conv1(y)))
        y = self.relu2(self.bn2(self.conv2(y)))  # 1 x 320 x 32 x 32

        y = self._upsample_to(y, x2)           # 1 x 320 x 64 x 64
        y = torch.cat((y, x2), 1)              # 1 x 560 x 64 x 64
        y = self.relu3(self.bn3(self.conv3(y)))
        y = self.relu4(self.bn4(self.conv4(y)))  # 1 x 160 x 64 x 64

        y = self._upsample_to(y, x1)           # 1 x 160 x 128 x 128
        y = torch.cat((y, x1), 1)              # 1 x 232 x 128 x 128
        y = self.relu5(self.bn5(self.conv5(y)))
        y = self.relu6(self.bn6(self.conv6(y)))  # 1 x 64 x 128 x 128

        y = self.relu7(self.bn7(self.conv7(y)))
        return y
整体网络在上述特征融合层的基础上,再加上下面几个输出层:
class East(BaseModel):
    """Detection model: MobileNetV3 backbone plus three 1x1 convolutional heads.

    The heads read the 64-channel backbone output and produce a 1-channel
    score map, a 4-channel geometry map and a 1-channel angle map.
    """

    def __init__(self, config):
        """Build the backbone and heads.

        Args:
            config: mapping providing ``config['data_loader']['input_size']``,
                used to scale the geometry outputs.
        """
        super().__init__(config)
        self.backbone = MobileNetV3()
        self.score_map = nn.Conv2d(64, 1, kernel_size=1)
        self.geo_map = nn.Conv2d(64, 4, kernel_size=1)
        self.angle_map = nn.Conv2d(64, 1, kernel_size=1)
        self.scale = config['data_loader']['input_size']

    def forward(self, inputs):
        """Return ``(score, geometry)`` for a batch of images.

        ``score`` is a sigmoid output in (0, 1). ``geometry`` concatenates,
        along the channel axis, four sigmoid outputs scaled to
        (0, ``self.scale``) and one angle channel in (-pi/4, pi/4).
        """
        features = self.backbone(inputs)
        score = torch.sigmoid(self.score_map(features))
        # NOTE(review): presumably the four channels are distances to the text
        # box edges, as in the EAST RBOX formulation -- confirm against the loss.
        geo_map = torch.sigmoid(self.geo_map(features)) * self.scale
        # sigmoid in (0, 1) -> shift to (-0.5, 0.5) -> scale to (-pi/4, pi/4).
        angle_map = (torch.sigmoid(self.angle_map(features)) - 0.5) * math.pi / 2
        geometry = torch.cat([geo_map, angle_map], dim=1)
        return score, geometry
不过还没完全理解:angle_map 表示的是文本框的旋转(偏移)角度吗?从代码看,它的取值被限制在 (-π/4, π/4) 之间。
实际跑一下代码,应该就能理解了。
好想参加比赛呀,队友,