替换骨干网络之后使用预训练模型进行训练

前言

最近看了几篇使用transformer的文章,于是想用其中的一个transformer模块来替换另一个方法的骨干网络(backbone),替换完之后跑起来感觉没有什么效果,想着可能是transformer模型要用预训练会好一些。但是。由于是自己把原来方法的backbone替换掉,因此没有现成的直接可以使用的预训练模型来使用,只能从两个方法中提取相应模块的权重然后整合起来当做预训练模型使用。

原理

不同的预训练模型之所以能够东拼一块、西拼一块成为一个可以用的预训练模型,是因为在预训练模型中有相应的键值对key-value),只要把预训练模型中的对应到自己使用的网络中的键值对进行更新就好了。简单的举个例子:

# 使用的网络中有这样的减值对:
{
   'conv1': [1, 1, 1, 0, 0]}
# 加载的预训练模型中也有这样的键值对,但是值不同,这样你就可以通过字典更新的方式获得到预训练模型的训练权重了。
{
   'conv1': [1, 1, 1, 0, 0]} -> {
   'conv1': [0, 1, 0, 0, 1]}

但是值得注意的一点是,一定要根据自己构建的网络中的键值对来对预训练模型的键值对进行提取,否则将会更新失败。

更新backbone的预训练权重过程

1、查看网络的键值对和预训练模型的键值对

首先查看需要替换的backbone在原始的预训练模型中的键值对,因为这个是作为backbone使用,所以一般打印出来的信息会有‘backbone'几个字,挺好辨认的。代码和效果如下:

def extract_backbone():
	# 查看backbone部分的预训练模型键值对
    backbone_model_path = "backbone.pth"
    backbone_train_model = torch.load(backbone_model_path)
    print(backbone_train_model .keys())

打印结果如下:

odict_keys(['backbone.SA_modules.0.local_chunk.pe.0.conv.weight', 
'backbone.SA_modules.0.local_chunk.pe.0.bn.weight', ...., 
'bbox_head.vote_module.vote_conv.0.bn.bias']

可以看到在打印中的信息就能看到’backbone'几个字,也就能知道我们要提取的数据范围。
接下来查看网络的键值对:

    model = net()
    model_stat_dict = model.state_dict()
    print(model_stat_dict.keys())

打印结果:

odict_keys(['SA_modules.0.local_chunk.pe.0.conv.weight',
 'SA_modules.0.local_chunk.pe.0.bn.weight', ...]

可以看到里面的键的名字虽然不是完全一样'backbone.SA_modules.0.local_chunk.pe.0.conv.weight'对比'SA_modules.0.local_chunk.pe.0.conv.weight',但是能够知道之间的对应关系,也就能知道该怎么样更新字典的键值对。

2、提取键值对并更新

方法:观察法,观察对应的键值对相差什么样的字符串,然后把多余的字符串去掉,如,把'backbone.SA_modules.0.local_chunk.pe.0.conv.weight'换为'SA_modules.0.local_chunk.pe.0.conv.weight'

    # 提取backbone部分的权重
    backbone_stat_dict = {
   }
    for i in backbone_train_model.keys():
        if 'backbone' in i and 'FP_modules.1.mlps' not in i:
            backbone_stat_dict[i.replace('backbone.', '')] = backbone_train_model[i]

这里的方法是判断字符串是否在另一个字符串中来定位自己想要的键值对,另外一个就是,因为我把网络的最后一层的输出改变了,例如原始输出是256,现在改为288,那么最后一层的训练权重就不能用了,只能使用默认的值。这一步是需要自己debug或者慢慢观察的出来的,遇到什么bug就解决什么bug就好了。
接下来对自己新建的网络进行权值更新:

    # 更新网络权重
    model_stat_dict.update(backbone_stat_dict)
    model.load_state_dict(model_stat_dict)

更新非backbone部分的预训练权重过程

1、查看键值对

def extract_others():
    other_model_path = "xxx.pth"
    other_train_model = torch.load(other_model_path )
    print(other_train_model['model'].keys())

打印结果:

odict_keys(['module.backbone_net.sa1.mlp_module.layer0.conv.weight',...]

这里的代码跟提取backbone的差不多,多了一个[‘model’]是因为这个预训练模型的所有键值对放在一个叫model的字典中,保存的层次不一样而已,可以看到里面也有’backbone’几个字,但是这一次,提取的就不是‘backbone'部分的预训练权重了。这一次是要把‘backbone部分的训练权重丢掉,保留其他部分
具体代码:

    other_state_dict = {
   }
    for i in other_train_model['model'].keys():
        if 'backbone_net' not in i and i.replace('module.', '') in model_stat_dict .keys():
            other_state_dict [i.replace('module.', '')] = other_train_model['model'][i]

这里提取的是除backbone部分和在新建的网络中的键值对

2、更新键值对

    model_stat_dict.update(other_state_dict)
    model.load_state_dict(model_stat_dict)

结果对比

新建网络的默认权重:

('prediction_heads.5.bn1.weight', tensor([1., 1., 1.
  • 5
    点赞
  • 53
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
要将YOLOv5中的骨干网络替换为ShuffleNet,需要进行以下步骤: 1. 安装PyTorch和YOLOv5库,并下载ShuffleNet模型的权重。 2. 打开`models/yolo.py`文件,在`__init__`函数中找到骨干网络部分的代码,将其替换为ShuffleNet的代码。 3. 在`models/common.py`文件中,定义ShuffleNet的网络结构。 4. 加载ShuffleNet模型的权重。 以下是可能需要修改的`__init__`函数的示例代码: ```python class YOLOv5(nn.Module): def __init__(self, nc=80, anchors=(), ch=(), inference=False): # inference时只使用detect部分 super(YOLOv5, self).__init__() self.inference = inference self.stride = None # strides computed during build self.grid = None # exported onnx grid self.names = [''] * (nc if nc else 1) self.nc = nc # number of classes self.no = nc + 5 # number of outputs per anchor self.nl = len(anchors) # number of detection layers self.na = len(anchors[0]) // 2 # number of anchors per layer self.anchor_grid = torch.tensor(anchors).view(self.nl, 1, -1, 1, 1, 2).to(next(self.parameters()).device) # normalized anchor grid self.register_buffer('anchors', self.anchor_grid.clone().view(self.nl, -1, 2)) # absolute anchors self.register_buffer('anchor_vec', self.anchor_grid.clone().view(self.nl, -1, 2).repeat(1, nc, 1)) # absolute anchor vector self.m = nn.ModuleList() self.save = [] self.ch = ch # input channels self.__construct() def __construct(self): # replace backbone with shufflenet backbone = shufflenet_v2_x1_0(pretrained=True) # remove last 2 layers (fc and avgpool) backbone.layers = nn.Sequential(*list(backbone.children())[:-2]) self.m.append(backbone) self.m.append(Conv(self.ch[-1], 512, 3, 2)) # 40 self.m.append(Bottleneck(512, 512)) self.m.append(Conv(512, 256, 3, 2)) # 80 self.m.append(Bottleneck(256, 256)) self.m.append(Conv(256, 256, 3, 2)) # 160 self.m.append(Bottleneck(256, 256)) self.m.append(Conv(256, 256, 3, 2)) # 320 self.m.append(Bottleneck(256, 256)) self.m.append(SPP(256, 256, [5, 9, 13])) self.m.append(Conv(512, 256, 1)) self.m.append(UpSample(2)) self.m.append(Conv(256 + 256, 256, 3, 1)) self.m.append(Bottleneck(256, 256, shortcut=False)) self.m.append(Conv(256, 128, 1)) self.m.append(UpSample(2)) self.m.append(Conv(128 + 256, 256, 3, 1)) self.m.append(Bottleneck(256, 256, shortcut=False)) self.m.append(Conv(256, 128, 1)) self.m.append(UpSample(2)) self.m.append(Conv(128 + 128, 256, 3, 1)) self.m.append(Bottleneck(256, 256, shortcut=False)) self.m.append(nn.Conv2d(256, self.no * self.na, 1)) self.export = [self.nl - 1] # detection layers self.freeze() ``` 这里我们使用预训练的ShuffleNet V2模型。需要安装shufflenet_v2模块,可以通过以下命令进行安装: ```python pip install shufflenet_v2_pytorch ``` 在上面的代码中,我们移除了ShuffleNet V2模型的最后两层(全连接层和平均池化层),并将其作为YOLOv5的骨干网络。然后,我们添加了YOLOv5的检测头部,用于检测目标。 最后,我们需要加载ShuffleNet V2模型的权重。可以使用以下代码加载ShuffleNet V2模型的权重: ```python model = shufflenet_v2_x1_0(pretrained=True) state_dict = torch.load('shufflenet_v2_x1_0.pth') model.load_state_dict(state_dict) ``` 请确保下载了ShuffleNet V2的预训练权重文件,并将其命名为`shufflenet_v2_x1_0.pth`。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值