ML-GCN(二)模型结构更改

背景:我们需要更改相应的模型。

(另,训练测试的时候,训练集上的mAP居然比验证集上的差,比如epoch =1的时候差了 30个百分点,可能程序编写有误,这点我们需要搞懂,查清代码原因。)

目录

一、模型定义

1.1 定义位置

1.2 输入参数

1.3 初始化的网络

1.4 结构关系

二、heads网络

2.1 输入参数

2.2 初始化网络结构

2.3 前馈结构

Global max pooling

两层fc

通过fc分组

GALayer1

group内提取特征

所有组及组内合并

三、GAT层替换为fc

3.1 GALayer1

GALayer的输入输出

FC_layer的参数设置

等价后的fc_1

四、本地实验汇总

4.1 exp1

4.2 exp2


一、模型定义

1.1 定义位置

in general_train.py

    # fixme=============begin=========
    if Config.MODEL == 'hgat_fc':
        import mymodels.hgat_fc as hgat_fc
        model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                                nclasses_per_group=Config.NCLASSES_PER_GROUP,
                                group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)

in hgat_fc.py

class BGATLayer

class BGATLayer(nn.Module):
    """
    Batch GATLayer, modified from:
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha):
        super(BGATLayer, self).__init__()

。。。

class HGAT_FC(nn.Module):
    def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):

1.2 输入参数

class HGAT_FC(nn.Module):
    def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
        super(HGAT_FC, self).__init__()

super() 函数是用于调用父类(超类)的一个方法。

http://www.runoob.com/python/python-func-super.html

self.groups = groups  组数
self.nclasses = nclasses  类别数
self.nclasses_per_group = nclasses_per_group  每组的类别数
self.group_channels = group_channels  组的channel
self.class_channels = class_channels   通道的channel
backbone 根据backbone确定model=resnet101或者resnet50
    model = HGAT_FC(backbone='resnet101', groups=12, nclasses=80,
                 nclasses_per_group=[1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7], group_channels=512, class_channels=256)

1.3 初始化的网络

初始化了三个网络,一个model,一个features,一个heads

model

        if backbone == 'resnet101':
            model = models.resnet101(pretrained=True)
        elif backbone == 'resnet50':
            model = models.resnet50(pretrained=True)
        else:
            raise Exception()

features (前馈网络),resnet101,或者resnet50,提取特征用

        self.features = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool,
            model.layer1,
            model.layer2,
            model.layer3,
            model.layer4, )

heads(带GAT的网络,很重要,我们在二中讲

        self.heads = Head(self.groups, self.nclasses, 
                self.nclasses_per_group, group_channels=self.group_channels,
                          class_channels=self.class_channels)

1.4 结构关系

先通过

    def forward(self, x, inp):
        x = self.features(x)  # [B,2048,H,W]
        x = self.heads(x)
        return x

先通过features,再通过heads

先通过features获得feature-map,纬度[ B , 2048 , H , W ],batchSize,2048纬度的特征向量,高,宽

然后将输出通过heads

二、heads网络

2.1 输入参数

class Head(nn.Module):
    def __init__(self, groups, nclasses, nclasses_per_group, group_channels, class_channels):

2.2 初始化网络结构

给定一个网络定义了尺寸

#global pooling层
self.gmp = nn.AdaptiveMaxPool2d(1)

#fc层
self.reduce_fc = nn.Sequential(utils.BasicLinear(in_channels=2048, out_channels=1024),
                               utils.BasicLinear(in_channels=1024, out_channels=group_channels), )

#分组fc
self.group_fcs = nn.ModuleList(
    [utils.ResidualLinearBlock(in_channels=group_channels, reduction=2, out_channels=group_channels)
     for _ in range(groups)])

#类别fc
self.class_fcs = nn.ModuleList(
    [utils.BasicLinear(in_channels=group_channels, out_channels=class_channels) for _ in range(nclasses)])

#gat1,Graph attention层
self.gat1 = BGATLayer(in_features=group_channels, out_features=group_channels, dropout=0, alpha=0.2)

self.gat2s = nn.ModuleList(
    [BGATLayer(in_features=class_channels, out_features=class_channels, dropout=0, alpha=0.2) for _ in

2.3 前馈结构

先将上一步的进行global max pooling,

Global max pooling

输入纬度[ B , 2048 , H , W ]

作用Global max pooling

输出纬度[ B , 2048 ]

x = self.gmp(x).view(x.size(0), x.size(1))  # [B,2048]

两层fc

输入纬度[ B , 2048 ]

输出纬度[B, Group channels]

经过fc,得到,纬度[B,  Group_channels],B为batchsize,通过前面提取出来的2048纬度的向量,到1024神经元的隐层,再输出到group channels个输出。

C为group channels,定义为512.

        self.reduce_fc = nn.Sequential(utils.BasicLinear(in_channels=2048, out_channels=1024),
                                       utils.BasicLinear(in_channels=1024, out_channels=group_channels), )
x = self.reduce_fc(x)  # [ B , group channels]

通过fc分组

输入 [B ,C] 

C为group channels,512,然后经过 groups=12个fc层,输出到[ B,  groups ,  group_channels ]

x = torch.stack([self.group_fcs[i](x) for i in range(self.groups)], dim=1)  
# [ B,  groups ,  group_channels ]

定义在下面,groups数为12,

        self.group_fcs = nn.ModuleList(
            [utils.ResidualLinearBlock(in_channels=group_channels, reduction=2, out_channels=group_channels)
             for _ in range(groups)])

GALayer1

x = self.gat1(x)  # [B,groups,group_channels]

输入为[ B,  groups ,  group_channels ], 经过注意力网络,输出为[ B,  groups ,  group_channels ]纬度不变,但是每个groups有关联

self.gat1 = BGATLayer(in_features=group_channels, out_features=group_channels, dropout=0, alpha=0.2)

group内提取特征

for i in range(self.groups):
    inside = []
    for j in range(self.nclasses_per_group[i]):
        inside.append(self.class_fcs[count](x[:, i, :]))  # [B,Group_channels]
        count += 1
    inside = torch.stack(inside, dim=1)  # [B, nclasses_per_group ,Group_channels]
    inside = self.gat2s[i](inside)  # [B, nclasses_per_group, Group_channels]
    outside.append(inside)

遍历group,每个group内针对组内的每个类进行一次卷积,每类输入 [ B , Group_channels ]

通过nclasses_per_group 个fc输出  [B, nclasses_per_group ,Group_channels]

所有组内的类之间通过GAT,输出 [B, nclasses_per_group ,Group_channels]

所有组及组内合并

x = torch.cat(outside, dim=1)  # [B,nclasses, Group_channels]

前面遍历了组,现在把所有组内的类拼接在一起。

对所有组内做卷积

x = torch.cat([self.fcs[i](x[:, i, :]) for i in range(self.nclasses)], dim=1)  # [B,nclasses]
        self.fcs = nn.ModuleList(
            [nn.Sequential(
                utils.ResidualLinearBlock(in_channels=class_channels, reduction=2, out_channels=class_channels),
                nn.Linear(in_features=class_channels, out_features=1)
            ) for _ in range(nclasses)])

三、GAT层替换为fc

3.1 GALayer1

GALayer的输入输出

我们现在希望将GATLayer替换为fc

第一个GALayer

输入输出尺寸均为 [B,groups,group_channels],所以这个GALayer将groups当做节点,此节点可变

self.gat1 = BGATLayer(in_features=group_channels, out_features=group_channels, dropout=0, alpha=0.2)

替换后,需要FC_layer也是输入输出均为 [B,groups,group_channels]

FC_layer的参数设置

输入[ B ,2048] ,输出 [ B, group_channels]可见可以将B忽略。

self.reduce_fc = nn.Sequential(utils.BasicLinear(in_channels=2048, out_channels=1024),
                               utils.BasicLinear(in_channels=1024, out_channels=group_channels), )

等价于GALayer,为输入 [ B, groups * group_channels ] ,输出同样的groups * group_channels

等价后的fc_1

定义:

self.fc_replace_GALayer1=utils.BasicLinear(in_channels=groups*group_channels,out_channels=groups*group_channels)

前馈运算:

batch_size = x.size(2)
group_channels = x.size(0)
x=x.reshape([batch_size, self.groups*group_channels ])

x=fc_replace_GALayer1(x)

x=x.reshape([batch_size, self.groups , group_channels])

按照要求,

输入输出均为 [B,groups,group_channels]

先做尺寸转换,转换成 [ B, groups * group_channels ]再输入网络,然后再转换回来

[B,groups,group_channels]

四、本地实验汇总

运行前需要更改general_train.py之中的,

EXP_NAME,表示实验编号
RESUME,之前训练的结果,可以继续训练

4.1 exp1

代码中,用的正常的gat layer,没有变化

模型存储的位置

checkpoint/coco/exp_1/
checkpoint.pth.tar          model_best_65.4549.pth.tar  model_best_81.2804.pth.tar  model_best.pth.tar

4.2 exp2

首先,hgat_fc之中,把gat layer的相关直接屏蔽掉,只留其他结构。此实验在于对比说明idea 是work的。作为baseline

general_train.py之中,

EXP_NAME = 'exp_2'
RESUME = './checkpoint/coco/exp_2/model_best_79.8707.pth.tar'
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

祥瑞Coding

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值