DRRG:Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection:代码解读(textnet)

前言

期待着期待着,DRRG的代码解读还没有。这是要等死我啊,唉,没办法,没人写就自己摸索吧。没错,我又来吹牛了。全网第一篇DRRG代码解读来了,万事不求人,自己动手丰衣足食,今天和大家一起学习DRRG,希望大家喜欢。

textnet

先说我们要解读哪部分,首先我们的学习是针对代码主体结构。就是说我们最终的目的是做算法优化,所以我们也从代码的网络结构开始,而经过我的苦苦寻找,这部分就在network文件夹下的textnet.py中。那我们就开始解读他吧。

代码全文

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers import GCN
from layers import KnnGraph
from RoIlayer import RROIAlign
from layers import Graph_RPN
from network.vgg import VggNet
from network.resnet import ResNet


class UpBlok(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, upsampled, shortcut):
        x = torch.cat([upsampled, shortcut], dim=1)
        x = self.conv1x1(x)
        x = F.relu(x)
        x = self.conv3x3(x)
        x = F.relu(x)
        x = self.deconv(x)
        return x


class FPN(nn.Module):

    def __init__(self, backbone='vgg_bn', is_training=True):
        super().__init__()

        self.is_training = is_training
        self.backbone_name = backbone
        self.class_channel = 6
        self.reg_channel = 2

        if backbone == "vgg" or backbone == 'vgg_bn':
            if backbone == 'vgg_bn':
                self.backbone = VggNet(name="vgg16_bn", pretrain=True)
            elif backbone == 'vgg':
                self.backbone = VggNet(name="vgg16", pretrain=True)

            self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
            self.merge4 = UpBlok(512 + 256, 128)
            self.merge3 = UpBlok(256 + 128, 64)
            self.merge2 = UpBlok(128 + 64, 32)
            self.merge1 = UpBlok(64 + 32, 32)

        elif backbone == 'resnet50' or backbone == 'resnet101':
            if backbone == 'resnet101':
                self.backbone = ResNet(name="resnet101", pretrain=True)
            elif backbone == 'resnet50':
                self.backbone = ResNet(name="resnet50", pretrain=True)

            self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
            self.merge4 = UpBlok(1024 + 256, 256)
            self.merge3 = UpBlok(512 + 256, 128)
            self.merge2 = UpBlok(256 + 128, 64)
            self.merge1 = UpBlok(64 + 64, 32)
        else:
            print("backbone is not support !")

    def forward(self, x):
        C1, C2, C3, C4, C5 = self.backbone(x)
        up5 = self.deconv5(C5)
        up5 = F.relu(up5)

        up4 = self.merge4(C4, up5)
        up4 = F.relu(up4)

        up3 = self.merge3(C3, up4)
        up3 = F.relu(up3)

        up2 = self.merge2(C2, up3)
        up2 = F.relu(up2)

        up1 = self.merge1(C1, up2)

        return up1, up2, up3, up4, up5


class TextNet(nn.Module):

    def __init__(self, backbone='vgg', is_training=True):
        super().__init__()
        self.k_at_hop = [8, 4]
        self.post_dim = 120
        self.active_connection = 3
        self.is_training = is_training
        self.backbone_name = backbone
        self.fpn = FPN(self.backbone_name, self.is_training)
        self.gcn_model = GCN(600, 32)  # 600 = 480 + 120
        self.pooling = RROIAlign((3, 4), 1.0 / 1)  # (32+8)*3*4 =480

        # ##class and regression branch
        self.out_channel = 8
        self.predict = nn.Sequential(
            #nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, self.out_channel, kernel_size=1, stride=1, padding=0)
        )

        # ## gcn branch
        if is_training:
            self.graph = KnnGraph(self.k_at_hop, self.active_connection, self.pooling, 120, self.is_training)
        else:
            self.graph = Graph_RPN(self.pooling, 120)

    def load_model(self, model_path):
        print('Loading from {}'.format(model_path))
        state_dict = torch.load(model_path)
        self.load_state_dict(state_dict['model'])

    def forward(self, x, roi_data=None, to_device=None):
        up1, up2, up3, up4, up5 = self.fpn(x)
        predict_out = self.predict(up1)

        graph_feat = torch.cat([up1, predict_out], dim=1)
        feat_batch, adj_batch, h1id_batch, gtmat_batch = self.graph(graph_feat, roi_data)
        gcn_pred = self.gcn_model(feat_batch, adj_batch, h1id_batch)

        return predict_out, (gcn_pred, to_device(gtmat_batch))

    def forward_test(self, img):
        up1, up2, up3, up4, up5 = self.fpn(img)
        predict_out = self.predict(up1)

        return predict_out

    def forward_test_graph(self, img):
        up1, up2, up3, up4, up5 = self.fpn(img)
        predict_out = self.predict(up1)

        graph_feat = torch.cat([up1, predict_out], dim=1)

        flag, datas = self.graph(img, predict_out, graph_feat)
        feat, adj, cid, h1id, node_list, proposals, output = datas
        if flag:

            return None, None, None, output

        adj, cid, h1id = map(lambda x: x.cuda(), (adj, cid, h1id))
        gcn_pred = self.gcn_model(feat, adj, h1id)

        pred = F.softmax(gcn_pred, dim=1)

        edges = list()
        scores = list()
        node_list = node_list.long().squeeze().cpu().numpy()
        bs = feat.size(0)

        for b in range(bs):
            cidb = cid[b].int().item()
            nl = node_list[b]
            for j, n in enumerate(h1id[b]):
                n = n.item()
                edges.append([nl[cidb], nl[n]])
                scores.append(pred[b * (h1id.shape[1]) + j, 1].item())

        edges = np.asarray(edges)
        scores = np.asarray(scores)

        return edges, scores, proposals, output

对应结构

这部分代码对应的结构是整个网络结构,也就是这张图:在这里插入图片描述
万事预备,那么我们开始了哦。先说一点,任何py文件都是这样的,先写子函数、然后是主函数,所以我们从子函数开始,最后再讲主函数(代码顺序就是这样)下面进入今天第一个模块:UpBlok

UpBlok

class UpBlok(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, upsampled, shortcut):
        x = torch.cat([upsampled, shortcut], dim=1)
        x = self.conv1x1(x)
        x = F.relu(x)
        x = self.conv3x3(x)
        x = F.relu(x)
        x = self.deconv(x)
        return x

首先这个模块对应的是图中的Up block部分,就是这个:
在这里插入图片描述
先说一下这部分做了什么:这部分主要做的是特征融合,先在通道维度上将两个不同的特征图进行融合,然后经过3x3卷积和1X1卷积改变通道数,最后经过反卷积扩大特征图大小,以备下一次特征融合。让我们对应一下代码吧:
首先:

def __init__(self, in_channels, out_channels):

这定义的是输入通道和输出通道数。

nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

这个就是pytorch的卷积函数,参数对应的是:输入通道数、输出通道数、卷积和尺寸、步长和padding。

nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

这个函数是反卷积函数,作用是扩大特征图,参数对应和上面相同。
所以这块就是定义了三个函数:
1X1卷积、3X3卷积、和反卷积。
然后我们看一下函数的返回:forward,看一下函数的具体运行流程。及函数的返回值

首先这个函数:

 x = torch.cat([upsampled, shortcut], dim=1)

这个函数做的是特征图的拼接(沿通道维度)对应CONCAT。
然后是1X1卷积,激活函数、3X3卷积、激活函数、反卷积。这里说一下为什么顺序和图中不同,也就是为什么要先进行1X1卷积。这种做法可以先整合通道信息,增加通道数,再3X3卷积整合空间信息,这样可以有效地减少计算量。正常应该是这样的,但是他1x1卷积竟然没有改变通道数,这起我就有点不理解,后期理解了再来改变,好了讲完了,相信大家也可以对应上了,让我们开始下一部分:FPN

FPN

class FPN(nn.Module):

    def __init__(self, backbone='vgg_bn', is_training=True):
        super().__init__()

        self.is_training = is_training
        self.backbone_name = backbone
        self.class_channel = 6
        self.reg_channel = 2

        if backbone == "vgg" or backbone == 'vgg_bn':
            if backbone == 'vgg_bn':
                self.backbone = VggNet(name="vgg16_bn", pretrain=True)
            elif backbone == 'vgg':
                self.backbone = VggNet(name="vgg16", pretrain=True)

            self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
            self.merge4 = UpBlok(512 + 256, 128)
            self.merge3 = UpBlok(256 + 128, 64)
            self.merge2 = UpBlok(128 + 64, 32)
            self.merge1 = UpBlok(64 + 32, 32)

        elif backbone == 'resnet50' or backbone == 'resnet101':
            if backbone == 'resnet101':
                self.backbone = ResNet(name="resnet101", pretrain=True)
            elif backbone == 'resnet50':
                self.backbone = ResNet(name="resnet50", pretrain=True)

            self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
            self.merge4 = UpBlok(1024 + 256, 256)
            self.merge3 = UpBlok(512 + 256, 128)
            self.merge2 = UpBlok(256 + 128, 64)
            self.merge1 = UpBlok(64 + 64, 32)
        else:
            print("backbone is not support !")

    def forward(self, x):
        C1, C2, C3, C4, C5 = self.backbone(x)
        up5 = self.deconv5(C5)
        up5 = F.relu(up5)

        up4 = self.merge4(C4, up5)
        up4 = F.relu(up4)

        up3 = self.merge3(C3, up4)
        up3 = F.relu(up3)

        up2 = self.merge2(C2, up3)
        up2 = F.relu(up2)

        up1 = self.merge1(C1, up2)

        return up1, up2, up3, up4, up5


这块对应的就是整体的FPN部分,如下图:
在这里插入图片描述
这块主要做的是将特征金字塔的5个不同大小的特征图进行合并:具体流程如上图。现在我们逐句进行解读:
init部分还是参数的设置具体参数如下:

backbone='vgg_bn' 

默认的特征提取网络:VGG16

self.is_training = is_training
self.backbone_name = backbone
self.class_channel = 6
self.reg_channel = 2

第一个参数不用在意,第二个参数是主干网络,第三个参数是分类通道数:6,第四个参数为回归参数通道数:2。加一起正好8个。
接下来的俩段是重复的,就是提取网络不同所以只用残差网络举例子:

        elif backbone == 'resnet50' or backbone == 'resnet101':
            if backbone == 'resnet101':
                self.backbone = ResNet(name="resnet101", pretrain=True)
            elif backbone == 'resnet50':
                self.backbone = ResNet(name="resnet50", pretrain=True)

这块仍然是一个特征提取网络的选择,在resnet50和resnet101之间进行选择。
然后就是对整体的UpBlok函数的更改,做成一个整体的融合函数,并且赋予参数通道数。先说一下各个特征图的通道数:
c1:64、c2:256、c3:512、c4:1024、c5:2048,且这些特征图大小都是2倍关系c1最大。
现在讲合并关系,中的通道数。

			self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
            self.merge4 = UpBlok(1024 + 256, 256)
            self.merge3 = UpBlok(512 + 256, 128)
            self.merge2 = UpBlok(256 + 128, 64)
            self.merge1 = UpBlok(64 + 64, 32)

第一个是反卷积,是用来把最后一个特征图扩大到和倒数第二个一样大小,输出通道数为:256,然后用来合并的。
第二句定义的是最底层的融合,融合c4和反卷积的c5,c4通道数为1024,融合后输出通道数为256.
第三句定义的是倒数第二层的融合,融合c3和上一步的结果,c3通道数为512,融合后输出通道数为128。
第四句定义的是第二层的融合,融合c2和上一步的结果,c2通道数为128,融合后输出通道数为64。
第五句定义的是第一层的融合,融合c1和上一步的结果,c1通道数为64,融合后输出通道数为32。
然后我们看一下返回结果(前面只是定义和顺序关系不大,这才是真正的过程):

C1, C2, C3, C4, C5 = self.backbone(x)

这块是通过主干网络输出5个特征图。

up5 = self.deconv5(C5)
up5 = F.relu(up5)

这是将C5 反卷积,然后通过激活函数。输出记为up5。

up4 = self.merge4(C4, up5)
up4 = F.relu(up4)

融合up5和C4,然后经过激活函数。输出记为up4。

        up3 = self.merge3(C3, up4)
        up3 = F.relu(up3)

融合up4和C3,然后经过激活函数。输出记为up3。

        up2 = self.merge2(C2, up3)
        up2 = F.relu(up2)

融合up3和C2,然后经过激活函数。输出记为up2。

up1 = self.merge1(C1, up2)

融合up2和C1,输出记为up1。
最后返回主体结果,特征提取和合并的部分就完成了,下面是整体部分:TextNet

TextNet

这一部分涉及到原论文没有提及的部分,所以我研究的也不是特别明白,所以先写着,然后等后续理解清楚再来补充。但是想理解结构的话暂时够用。
首先init部分还是初始化参数我们还是略过,一会用到的具体去谈。
同样我们先去解释几个重要的参数我们先抛出两个重要的图:
在这里插入图片描述
在这里插入图片描述
这里比较重要的参数就是:

self.pooling = RROIAlign((3, 4), 1.0 / 1)  # (32+8)*3*4 =480

这里的(32+8)显然就是第一张图的输出FN然后的480就是RROIAlign的输出。输出的计算方法FN34然后做成一个1FN3*4二维张量,好吧具体我们不管他你只要知道他是RROIAlign所需的。
然后我们继续往下看,看比较重要的部分:

self.predict = nn.Sequential(
            #nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, self.out_channel, kernel_size=1, stride=1, padding=0)
        )

这个函数具体的操作行为你不需要知道,你可以简单的把它认为是把输出进行了卷积得到一个8通道的输出,最后和原来的32输出相加就得到了原来FN的40通道输出。

        if is_training:
            self.graph = KnnGraph(self.k_at_hop, self.active_connection, self.pooling, 120, self.is_training)
        else:
            self.graph = Graph_RPN(self.pooling, 120)

这块对应的是第二张图的第三个路径的操作,也看出了参数is_training的作用,但是实际上这是块的操作还多一点,我觉得是这部分:
在这里插入图片描述
如果不明白为什么还有融合这一步,可以去参看一下faster rcnn的rpn原理。接下来我们看一下函数的其他部分,看完以后你可能会理解这部分。

    def load_model(self, model_path):
        print('Loading from {}'.format(model_path))
        state_dict = torch.load(model_path)
        self.load_state_dict(state_dict['model'])

这块是模型的加载,了解即可。
然后我们看forward函数可以对应第三张图:
首先:

up1, up2, up3, up4, up5 = self.fpn(x)

通过前面的FPN网络算出几个输出,UP1是32通道的(最顶层的合并结果),然后通过刚才提到的函数:

 predict_out = self.predict(up1)

得到一个8通道的输出。

graph_feat = torch.cat([up1, predict_out], dim=1)

32通道输出和8通道的输出结合成为新的输出FN。
然后我们回到那张图:

feat_batch, adj_batch, h1id_batch, gtmat_batch = self.graph(graph_feat, roi_data)

这个做的就是这个图:在这里插入图片描述

做的事,得到GCN的输入。
最后输入GCN网络:

  gcn_pred = self.gcn_model(feat_batch, adj_batch, h1id_batch)

得到最后的结果。
类似的,我们可以看这个函数:

    def forward_test(self, img):
        up1, up2, up3, up4, up5 = self.fpn(img)
        predict_out = self.predict(up1)

        return predict_out

可以看到这里是单独的得到8通道的输出。
至于最后一个函数,都是大同小异,这里说几个没提过的东西:

   feat, adj, cid, h1id, node_list, proposals, output = datas

这个叫拆包,这样可以把一个变量的信息按照顺序分配给多个变量,是一种很常见的操作

pred = F.softmax(gcn_pred, dim=1)

这就是softmax计算,最计算机视觉的应该很熟悉。

node_list = node_list.long().squeeze().cpu().numpy()

这几个函数连在一起其实就是把一个多维向量拉直成为成为一长串,最后做个数据类型转换。
最后就是edges, scores的计算,大家了解一下就行,原文没想详细提,我也说不太清楚,如果以后我了解了会回来补充。

最后

textnet就写这么多,因为有的东西,要联系别的地方解释,有的东西本人也不是特别(本人也是个菜鸟)清楚所以很抱歉,先写这些后期还会有补充,也会推出其他部分注解,希望能帮到大家。

  • 7
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值