文字识别east

代码:https://github.com/ishin-pie/east-mobilenet
网络结构采用 MobileNetV3-Large 作为主干(backbone),
并取出其中几层的特征图用于后续融合:

class MobileNetV3(nn.Module):
    """Backbone wrapper: feature extractor followed by a U-shaped merge.

    ``extractor`` produces four feature maps and ``merge`` fuses them into
    a single tensor.
    """

    def __init__(self):
        super().__init__()
        self.extractor = MobileNetV3_Large()
        self.merge = Ushape()

    def forward(self, x):
        """Run the extractor on ``x`` and return the merged feature map."""
        # The extractor must yield exactly four maps: the deepest one first,
        # then the three skip features in the order the merge expects them.
        deepest, skip1, skip2, skip3 = self.extractor(x)
        merged = self.merge(deepest, skip1, skip2, skip3)
        return merged

整体思路是:把 MobileNetV3 最深层的特征上采样,与前一级(更浅层)的特征在通道维上拼接后再做卷积融合,逐级重复。主干网络的 forward 如下:

def forward(self, x):
        """Run the backbone and return the final map plus three skip maps.

        Returns a 4-tuple ``(final, skip1, skip2, skip3)`` where ``skip1``,
        ``skip2`` and ``skip3`` are the outputs of ``bneck1``, ``bneck2`` and
        ``bneck3``, and ``final`` is the result of the remaining layers.

        The shape comments below are the values recorded in the original
        notes for one example batch of 2 images.
        """
        # Stem: conv -> batch norm -> activation.
        stem = self.conv1(x)
        stem = self.bn1(stem)
        stem = self.hs1(stem)           # [2, 16, 256, 256]

        # Bottleneck groups; the first three outputs are kept as skips.
        skip1 = self.bneck1(stem)       # [2, 72, 128, 128]
        skip2 = self.bneck2(skip1)      # [2, 240, 64, 64]
        skip3 = self.bneck3(skip2)      # [2, 672, 32, 32]
        deep = self.bneck4(skip3)       # [2, 160, 16, 16]

        # Head: conv -> batch norm -> activation, then a second conv/bn/linear.
        deep = self.conv2(deep)
        deep = self.bn2(deep)
        deep = self.hs2(deep)           # [2, 960, 16, 16]

        final = self.conv3(deep)
        final = self.bn3(final)
        final = self.linear(final)      # [2, 640, 16, 16]

        return final, skip1, skip2, skip3

如上面的代码及注释所示,forward 返回的四个特征图随后送入下面的 Ushape 融合模块:

class Ushape(nn.Module):
    """U-shaped decoder that fuses four backbone feature maps into one.

    Starting from the deepest map, each of three stages doubles the spatial
    size with bilinear interpolation, concatenates a skip feature along the
    channel axis, and applies a 1x1 conv followed by a 3x3 conv (each with
    BatchNorm and ReLU). A final 3x3 conv/BN/ReLU produces 64 channels.

    Layers are registered as ``conv1..conv7``, ``bn1..bn7`` and
    ``relu1..relu7``.
    """

    # (in_channels, out_channels, kernel_size) for conv1 .. conv7, in order.
    _LAYER_SPECS = (
        (1312, 320, 1),
        (320, 320, 3),
        (560, 160, 1),
        (160, 160, 3),
        (232, 64, 1),
        (64, 64, 3),
        (64, 64, 3),
    )

    def __init__(self):
        super().__init__()

        for idx, (c_in, c_out, ksize) in enumerate(self._LAYER_SPECS, start=1):
            # ksize // 2 is 0 for the 1x1 convs and 1 for the 3x3 convs, so
            # every conv keeps the spatial size of its input.
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, ksize, padding=ksize // 2))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))
            setattr(self, f"relu{idx}", nn.ReLU())

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for convs; unit scale / zero shift for BN."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
                continue
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def _conv_bn_relu(self, idx, tensor):
        """Apply ``conv{idx}`` -> ``bn{idx}`` -> ``relu{idx}`` to ``tensor``."""
        conv = getattr(self, f"conv{idx}")
        norm = getattr(self, f"bn{idx}")
        act = getattr(self, f"relu{idx}")
        return act(norm(conv(tensor)))

    def forward(self, x, x1, x2, x3):
        """Fuse the deepest map ``x`` with skips ``x3``, ``x2``, ``x1``.

        After each concatenation the channel count must equal the input
        width of conv1 / conv3 / conv5 (1312 / 560 / 232). In the example
        from the original notes the shapes are:
        x (1, 640, 16, 16), x3 (1, 672, 32, 32), x2 (1, 240, 64, 64),
        x1 (1, 72, 128, 128), and the result is (1, 64, 128, 128).
        """
        y = x
        # Each stage: (index of its first conv, skip feature to merge in).
        for first_conv, skip in ((1, x3), (3, x2), (5, x1)):
            y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
            y = torch.cat((y, skip), 1)
            y = self._conv_bn_relu(first_conv, y)
            y = self._conv_bn_relu(first_conv + 1, y)

        # Final 3x3 refinement at the highest resolution.
        return self._conv_bn_relu(7, y)

完整的网络在上述特征融合层的基础上,再加上下面几个 1×1 卷积输出层:

class East(BaseModel):
    """EAST head on top of the MobileNetV3 backbone.

    Three 1x1 convolutions turn the 64-channel backbone output into a score
    map (1 channel), a geometry map (4 channels) and an angle map
    (1 channel).
    """

    def __init__(self, config):
        """Build the model; reads ``config['data_loader']['input_size']``."""
        super().__init__(config)
        self.backbone = MobileNetV3()
        self.score_map = nn.Conv2d(64, 1, kernel_size=1)
        self.geo_map = nn.Conv2d(64, 4, kernel_size=1)
        self.angle_map = nn.Conv2d(64, 1, kernel_size=1)
        # Multiplier applied to the sigmoid-squashed geometry output.
        self.scale = config['data_loader']['input_size']

    def forward(self, inputs):
        """Return ``(score, geometry)`` for a batch of inputs.

        ``score`` is the sigmoid of the score head. ``geometry`` stacks the
        four geometry channels and the angle channel along dim 1.
        """
        features = self.backbone(inputs)

        score = torch.sigmoid(self.score_map(features))

        # Sigmoid output in (0, 1) stretched to (0, self.scale).
        distances = torch.sigmoid(self.geo_map(features)) * self.scale

        # Sigmoid shifted to (-0.5, 0.5), then scaled to (-pi/4, pi/4).
        centered = torch.sigmoid(self.angle_map(features)) - 0.5
        angle = centered * math.pi / 2

        geometry = torch.cat([distances, angle], dim=1)
        return score, geometry

但是还不太理解:这里的 angle_map 表示的是文本框的旋转(偏移)角度吗?从代码看,它的取值被限制在 (-π/4, π/4) 之间。
应该实际跑一下就能理解了。
好想参加比赛呀,队友,

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值