前言
期待着期待着,DRRG的代码解读还没有。这是要等死我啊,唉,没办法,没人写就自己摸索吧。没错,我又来吹牛了。全网第一篇DRRG代码解读来了,万事不求人,自己动手丰衣足食,今天和大家一起学习DRRG,希望大家喜欢。
textnet
先说我们要解读哪部分,首先我们的学习是针对代码主体结构。就是说我们最终的目的是做算法优化,所以我们也从代码的网络结构开始,而经过我的苦苦寻找,这部分就在network文件夹下的textnet.py中。那我们就开始解读他吧。
代码全文
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers import GCN
from layers import KnnGraph
from RoIlayer import RROIAlign
from layers import Graph_RPN
from network.vgg import VggNet
from network.resnet import ResNet
class UpBlok(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, upsampled, shortcut):
x = torch.cat([upsampled, shortcut], dim=1)
x = self.conv1x1(x)
x = F.relu(x)
x = self.conv3x3(x)
x = F.relu(x)
x = self.deconv(x)
return x
class FPN(nn.Module):
def __init__(self, backbone='vgg_bn', is_training=True):
super().__init__()
self.is_training = is_training
self.backbone_name = backbone
self.class_channel = 6
self.reg_channel = 2
if backbone == "vgg" or backbone == 'vgg_bn':
if backbone == 'vgg_bn':
self.backbone = VggNet(name="vgg16_bn", pretrain=True)
elif backbone == 'vgg':
self.backbone = VggNet(name="vgg16", pretrain=True)
self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
self.merge4 = UpBlok(512 + 256, 128)
self.merge3 = UpBlok(256 + 128, 64)
self.merge2 = UpBlok(128 + 64, 32)
self.merge1 = UpBlok(64 + 32, 32)
elif backbone == 'resnet50' or backbone == 'resnet101':
if backbone == 'resnet101':
self.backbone = ResNet(name="resnet101", pretrain=True)
elif backbone == 'resnet50':
self.backbone = ResNet(name="resnet50", pretrain=True)
self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
self.merge4 = UpBlok(1024 + 256, 256)
self.merge3 = UpBlok(512 + 256, 128)
self.merge2 = UpBlok(256 + 128, 64)
self.merge1 = UpBlok(64 + 64, 32)
else:
print("backbone is not support !")
def forward(self, x):
C1, C2, C3, C4, C5 = self.backbone(x)
up5 = self.deconv5(C5)
up5 = F.relu(up5)
up4 = self.merge4(C4, up5)
up4 = F.relu(up4)
up3 = self.merge3(C3, up4)
up3 = F.relu(up3)
up2 = self.merge2(C2, up3)
up2 = F.relu(up2)
up1 = self.merge1(C1, up2)
return up1, up2, up3, up4, up5
class TextNet(nn.Module):
def __init__(self, backbone='vgg', is_training=True):
super().__init__()
self.k_at_hop = [8, 4]
self.post_dim = 120
self.active_connection = 3
self.is_training = is_training
self.backbone_name = backbone
self.fpn = FPN(self.backbone_name, self.is_training)
self.gcn_model = GCN(600, 32) # 600 = 480 + 120
self.pooling = RROIAlign((3, 4), 1.0 / 1) # (32+8)*3*4 =480
# ##class and regression branch
self.out_channel = 8
self.predict = nn.Sequential(
#nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
nn.Conv2d(32, self.out_channel, kernel_size=1, stride=1, padding=0)
)
# ## gcn branch
if is_training:
self.graph = KnnGraph(self.k_at_hop, self.active_connection, self.pooling, 120, self.is_training)
else:
self.graph = Graph_RPN(self.pooling, 120)
def load_model(self, model_path):
print('Loading from {}'.format(model_path))
state_dict = torch.load(model_path)
self.load_state_dict(state_dict['model'])
def forward(self, x, roi_data=None, to_device=None):
up1, up2, up3, up4, up5 = self.fpn(x)
predict_out = self.predict(up1)
graph_feat = torch.cat([up1, predict_out], dim=1)
feat_batch, adj_batch, h1id_batch, gtmat_batch = self.graph(graph_feat, roi_data)
gcn_pred = self.gcn_model(feat_batch, adj_batch, h1id_batch)
return predict_out, (gcn_pred, to_device(gtmat_batch))
def forward_test(self, img):
up1, up2, up3, up4, up5 = self.fpn(img)
predict_out = self.predict(up1)
return predict_out
def forward_test_graph(self, img):
up1, up2, up3, up4, up5 = self.fpn(img)
predict_out = self.predict(up1)
graph_feat = torch.cat([up1, predict_out], dim=1)
flag, datas = self.graph(img, predict_out, graph_feat)
feat, adj, cid, h1id, node_list, proposals, output = datas
if flag:
return None, None, None, output
adj, cid, h1id = map(lambda x: x.cuda(), (adj, cid, h1id))
gcn_pred = self.gcn_model(feat, adj, h1id)
pred = F.softmax(gcn_pred, dim=1)
edges = list()
scores = list()
node_list = node_list.long().squeeze().cpu().numpy()
bs = feat.size(0)
for b in range(bs):
cidb = cid[b].int().item()
nl = node_list[b]
for j, n in enumerate(h1id[b]):
n = n.item()
edges.append([nl[cidb], nl[n]])
scores.append(pred[b * (h1id.shape[1]) + j, 1].item())
edges = np.asarray(edges)
scores = np.asarray(scores)
return edges, scores, proposals, output
对应结构
这部分代码对应的结构是整个网络结构,也就是这张图:
万事预备,那么我们开始了哦。先说一点,任何py文件都是这样的,先写子函数、然后是主函数,所以我们从子函数开始,最后再讲主函数(代码顺序就是这样)下面进入今天第一个模块:UpBlok
UpBlok
class UpBlok(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, upsampled, shortcut):
x = torch.cat([upsampled, shortcut], dim=1)
x = self.conv1x1(x)
x = F.relu(x)
x = self.conv3x3(x)
x = F.relu(x)
x = self.deconv(x)
return x
首先这个模块对应的是图中的Up block部分,就是这个:
先说一下这部分做了什么:这部分主要做的是特征融合,先在通道维度上将两个不同的特征图进行融合,然后经过3x3卷积和1X1卷积改变通道数,最后经过反卷积扩大特征图大小,以备下一次特征融合。让我们对应一下代码吧:
首先:
def __init__(self, in_channels, out_channels):
这定义的是输入通道和输出通道数。
nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
这个就是pytorch的卷积函数,参数对应的是:输入通道数、输出通道数、卷积和尺寸、步长和padding。
nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
这个函数是反卷积函数,作用是扩大特征图,参数对应和上面相同。
所以这块就是定义了三个函数:
1X1卷积、3X3卷积、和反卷积。
然后我们看一下函数的返回:forward,看一下函数的具体运行流程。及函数的返回值
首先这个函数:
x = torch.cat([upsampled, shortcut], dim=1)
这个函数做的是特征图的拼接(沿通道维度)对应CONCAT。
然后是1X1卷积,激活函数、3X3卷积、激活函数、反卷积。这里说一下为什么顺序和图中不同,也就是为什么要先进行1X1卷积。这种做法可以先整合通道信息,增加通道数,再3X3卷积整合空间信息,这样可以有效地减少计算量。正常应该是这样的,但是他1x1卷积竟然没有改变通道数,这起我就有点不理解,后期理解了再来改变,好了讲完了,相信大家也可以对应上了,让我们开始下一部分:FPN
FPN
class FPN(nn.Module):
def __init__(self, backbone='vgg_bn', is_training=True):
super().__init__()
self.is_training = is_training
self.backbone_name = backbone
self.class_channel = 6
self.reg_channel = 2
if backbone == "vgg" or backbone == 'vgg_bn':
if backbone == 'vgg_bn':
self.backbone = VggNet(name="vgg16_bn", pretrain=True)
elif backbone == 'vgg':
self.backbone = VggNet(name="vgg16", pretrain=True)
self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
self.merge4 = UpBlok(512 + 256, 128)
self.merge3 = UpBlok(256 + 128, 64)
self.merge2 = UpBlok(128 + 64, 32)
self.merge1 = UpBlok(64 + 32, 32)
elif backbone == 'resnet50' or backbone == 'resnet101':
if backbone == 'resnet101':
self.backbone = ResNet(name="resnet101", pretrain=True)
elif backbone == 'resnet50':
self.backbone = ResNet(name="resnet50", pretrain=True)
self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
self.merge4 = UpBlok(1024 + 256, 256)
self.merge3 = UpBlok(512 + 256, 128)
self.merge2 = UpBlok(256 + 128, 64)
self.merge1 = UpBlok(64 + 64, 32)
else:
print("backbone is not support !")
def forward(self, x):
C1, C2, C3, C4, C5 = self.backbone(x)
up5 = self.deconv5(C5)
up5 = F.relu(up5)
up4 = self.merge4(C4, up5)
up4 = F.relu(up4)
up3 = self.merge3(C3, up4)
up3 = F.relu(up3)
up2 = self.merge2(C2, up3)
up2 = F.relu(up2)
up1 = self.merge1(C1, up2)
return up1, up2, up3, up4, up5
这块对应的就是整体的FPN部分,如下图:
这块主要做的是将特征金字塔的5个不同大小的特征图进行合并:具体流程如上图。现在我们逐句进行解读:
init部分还是参数的设置具体参数如下:
backbone='vgg_bn'
默认的特征提取网络:VGG16
self.is_training = is_training
self.backbone_name = backbone
self.class_channel = 6
self.reg_channel = 2
第一个参数不用在意,第二个参数是主干网络,第三个参数是分类通道数:6,第四个参数为回归参数通道数:2。加一起正好8个。
接下来的俩段是重复的,就是提取网络不同所以只用残差网络举例子:
elif backbone == 'resnet50' or backbone == 'resnet101':
if backbone == 'resnet101':
self.backbone = ResNet(name="resnet101", pretrain=True)
elif backbone == 'resnet50':
self.backbone = ResNet(name="resnet50", pretrain=True)
这块仍然是一个特征提取网络的选择,在resnet50和resnet101之间进行选择。
然后就是对整体的UpBlok函数的更改,做成一个整体的融合函数,并且赋予参数通道数。先说一下各个特征图的通道数:
c1:64、c2:256、c3:512、c4:1024、c5:2048,且这些特征图大小都是2倍关系c1最大。
现在讲合并关系,中的通道数。
self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
self.merge4 = UpBlok(1024 + 256, 256)
self.merge3 = UpBlok(512 + 256, 128)
self.merge2 = UpBlok(256 + 128, 64)
self.merge1 = UpBlok(64 + 64, 32)
第一个是反卷积,是用来把最后一个特征图扩大到和倒数第二个一样大小,输出通道数为:256,然后用来合并的。
第二句定义的是最底层的融合,融合c4和反卷积的c5,c4通道数为1024,融合后输出通道数为256.
第三句定义的是倒数第二层的融合,融合c3和上一步的结果,c3通道数为512,融合后输出通道数为128。
第四句定义的是第二层的融合,融合c2和上一步的结果,c2通道数为128,融合后输出通道数为64。
第五句定义的是第一层的融合,融合c1和上一步的结果,c1通道数为64,融合后输出通道数为32。
然后我们看一下返回结果(前面只是定义和顺序关系不大,这才是真正的过程):
C1, C2, C3, C4, C5 = self.backbone(x)
这块是通过主干网络输出5个特征图。
up5 = self.deconv5(C5)
up5 = F.relu(up5)
这是将C5 反卷积,然后通过激活函数。输出记为up5。
up4 = self.merge4(C4, up5)
up4 = F.relu(up4)
融合up5和C4,然后经过激活函数。输出记为up4。
up3 = self.merge3(C3, up4)
up3 = F.relu(up3)
融合up4和C3,然后经过激活函数。输出记为up3。
up2 = self.merge2(C2, up3)
up2 = F.relu(up2)
融合up3和C2,然后经过激活函数。输出记为up2。
up1 = self.merge1(C1, up2)
融合up2和C1,输出记为up1。
最后返回主体结果,特征提取和合并的部分就完成了,下面是整体部分:TextNet
TextNet
这一部分涉及到原论文没有提及的部分,所以我研究的也不是特别明白,所以先写着,然后等后续理解清楚再来补充。但是想理解结构的话暂时够用。
首先init部分还是初始化参数我们还是略过,一会用到的具体去谈。
同样我们先去解释几个重要的参数我们先抛出两个重要的图:
这里比较重要的参数就是:
self.pooling = RROIAlign((3, 4), 1.0 / 1) # (32+8)*3*4 =480
这里的(32+8)显然就是第一张图的输出FN然后的480就是RROIAlign的输出。输出的计算方法FN34然后做成一个1FN3*4二维张量,好吧具体我们不管他你只要知道他是RROIAlign所需的。
然后我们继续往下看,看比较重要的部分:
self.predict = nn.Sequential(
#nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
nn.Conv2d(32, self.out_channel, kernel_size=1, stride=1, padding=0)
)
这个函数具体的操作行为你不需要知道,你可以简单的把它认为是把输出进行了卷积得到一个8通道的输出,最后和原来的32输出相加就得到了原来FN的40通道输出。
if is_training:
self.graph = KnnGraph(self.k_at_hop, self.active_connection, self.pooling, 120, self.is_training)
else:
self.graph = Graph_RPN(self.pooling, 120)
这块对应的是第二张图的第三个路径的操作,也看出了参数is_training的作用,但是实际上这是块的操作还多一点,我觉得是这部分:
如果不明白为什么还有融合这一步,可以去参看一下faster rcnn的rpn原理。接下来我们看一下函数的其他部分,看完以后你可能会理解这部分。
def load_model(self, model_path):
print('Loading from {}'.format(model_path))
state_dict = torch.load(model_path)
self.load_state_dict(state_dict['model'])
这块是模型的加载,了解即可。
然后我们看forward函数可以对应第三张图:
首先:
up1, up2, up3, up4, up5 = self.fpn(x)
通过前面的FPN网络算出几个输出,UP1是32通道的(最顶层的合并结果),然后通过刚才提到的函数:
predict_out = self.predict(up1)
得到一个8通道的输出。
graph_feat = torch.cat([up1, predict_out], dim=1)
32通道输出和8通道的输出结合成为新的输出FN。
然后我们回到那张图:
feat_batch, adj_batch, h1id_batch, gtmat_batch = self.graph(graph_feat, roi_data)
这个做的就是这个图:
做的事,得到GCN的输入。
最后输入GCN网络:
gcn_pred = self.gcn_model(feat_batch, adj_batch, h1id_batch)
得到最后的结果。
类似的,我们可以看这个函数:
def forward_test(self, img):
up1, up2, up3, up4, up5 = self.fpn(img)
predict_out = self.predict(up1)
return predict_out
可以看到这里是单独的得到8通道的输出。
至于最后一个函数,都是大同小异,这里说几个没提过的东西:
feat, adj, cid, h1id, node_list, proposals, output = datas
这个叫拆包,这样可以把一个变量的信息按照顺序分配给多个变量,是一种很常见的操作
pred = F.softmax(gcn_pred, dim=1)
这就是softmax计算,最计算机视觉的应该很熟悉。
node_list = node_list.long().squeeze().cpu().numpy()
这几个函数连在一起其实就是把一个多维向量拉直成为成为一长串,最后做个数据类型转换。
最后就是edges, scores的计算,大家了解一下就行,原文没想详细提,我也说不太清楚,如果以后我了解了会回来补充。
最后
textnet就写这么多,因为有的东西,要联系别的地方解释,有的东西本人也不是特别(本人也是个菜鸟)清楚所以很抱歉,先写这些后期还会有补充,也会推出其他部分注解,希望能帮到大家。