MindSpore TOOD Implementation: A Record of Model Weight Migration and Inference Alignment

Preparation

Environment:
WSL2 Ubuntu 20.04
MindSpore 2.0.0
Python 3.8
PyTorch 2.0.1

The migration is based on my own MindSpore TOOD project and the PyTorch weights from the MMDetection implementation.

Implementing the TOOD forward structure in MindSpore

Build the model first: the structure is ResNet-50 + FPN + TOOD head. Beyond the structure itself, the weight initialization of the head and the FPN must be aligned with the MMDetection implementation, since this matters for later training.

  • Note the difference in padding behavior between the two frameworks; the details are in MindSpore's official migration guide. I use explicit padding wherever possible to avoid mistakes.
  • The resnet50 backbone is initialized from pretrained weights for training.
  • The FPN in mmdetection uses Xavier initialization; in MindSpore I use Kaiming initialization instead, which I consider the better choice.
  • Head convolutions use normal initialization for the weights and zeros for the ordinary biases.
  • The classification-branch bias in the head uses the prior-probability (prob) initialization, as shown in the sketch after this list.
  • The remaining parts (BN, GN) are initialized identically in the two frameworks.
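For reference, a minimal sketch of what the classification-branch bias initialization could look like in MindSpore; the helper name, the 80-class conv, and the formula bias = -log((1 - p) / p) (mmdetection's bias_init_with_prob convention with p = 0.01) are illustrative assumptions, not the exact code of this project:

import math
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, Normal

def init_cls_conv(cls_conv: nn.Conv2d, prior_prob: float = 0.01):
    """Normal(0, 0.01) weights and a prior-probability bias, as mmdetection does."""
    bias_init = -math.log((1 - prior_prob) / prior_prob)  # about -4.595 for p = 0.01
    cls_conv.weight.set_data(
        initializer(Normal(sigma=0.01), cls_conv.weight.shape, cls_conv.weight.dtype))
    cls_conv.bias.set_data(
        ms.Tensor(np.full(cls_conv.bias.shape, bias_init, dtype=np.float32)))

# hypothetical usage on the head's classification conv (80 COCO classes)
cls_conv = nn.Conv2d(256, 80, kernel_size=3, stride=1, padding=1,
                     pad_mode='pad', has_bias=True)
init_cls_conv(cls_conv)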

Weight conversion

Migration is essentially aligning the key-to-key mapping of the weights. With the experience from the FCOS migration, and with some naming cleanup in the model code, this goes much faster.

Tips that may be useful:

Print the parameter names and shapes from both frameworks and compare them, for example with an online text-diff site:
[Screenshot: text diff of parameter names and shapes from the two frameworks]
The shapes show that the order lines up completely. Note that scale is a Python float in pt but a 1x1 tensor in ms. The computation order of the FPN implementation was also specifically debugged in the code, so all that remains is the name conversion.
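The dump itself can be as simple as the following sketch (output file names are arbitrary):

import torch
import mindspore as ms

# Write "name shape" lines for both checkpoints so they can be diffed as plain text.
with open('pt_params.txt', 'w') as f:
    pth = torch.load('/mnt/f/pretrain_weight/tood_r50_fpn_1x_coco.pth', map_location='cpu')
    for name, value in pth['state_dict'].items():
        f.write(f'{name} {tuple(value.shape)}\n')

with open('ms_params.txt', 'w') as f:
    for name, param in ms.load_checkpoint('tood_ms.ckpt').items():
        f.write(f'{name} {tuple(param.shape)}\n')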

Although the weights could be converted purely by position, for robustness I still use a dictionary mapping. The name-conversion rules I ended up with are summarized below (PyTorch names are rewritten to MindSpore names):

import copy
import numpy as np
import torch
import mindspore as ms


def tood_pth2ckpt():
    ms_ckpt = ms.load_checkpoint('tood_ms.ckpt')  # randomly initialized MindSpore TOOD weights saved earlier
    pth = torch.load("/mnt/f/pretrain_weight/tood_r50_fpn_1x_coco.pth", map_location=torch.device('cpu'))  # PyTorch TOOD weights from mmdetection
    match_pt_kv = {}  # matched pt weights: name -> value
    match_pt_kv_mslist = []  # matched pt weights in the list-of-dicts format MindSpore needs for loading
    not_match_pt_kv = {}  # unmatched pt weights: name -> value
    matched_ms_k = []  # names of ms weights that were matched

    '''General conversion rules'''
    pt2ms = {'backbone': 'tood_body.backbone',  # backbone part
             'neck': 'tood_body.fpn',
             'bbox_head': 'tood_body.head',
             'downsample': 'down_sample_layer',
             }

    '''Conversion rules for conv layers: identical, nothing to change'''
    pt2ms_conv = {
        "weight": "weight",
        "bias": "bias",
    }

    '''Conversion rules for the downsample layer: it has a conv layer and a BN layer, named 0 and 1; in torch the name "weight" appears in both'''
    pt2ms_down = {
        "0.weight": "0.weight",
        "1.weight": "1.gamma",

        "1.bias": "1.beta",
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
    }

    '''Conversion rules for BN layers'''
    pt2ms_bn = {
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
        "weight": "gamma",
        "bias": "beta",
    }

    '''Conversion rules for GN layers'''
    pt2ms_gn = {
        "weight": "gamma",
        "bias": "beta",
    }

    for i, v in pth['state_dict'].items():
        pt_name = copy.deepcopy(i)
        pt_value = copy.deepcopy(v)
        '''General handling'''
        for k, v in pt2ms.items():
            if k in pt_name:
                pt_name = pt_name.replace(k, v)

        '''Conv layer rules are identical, nothing to do'''

        '''Special handling for the FPN part'''
        if 'fpn' in pt_name:
            pt_name = pt_name.replace('.conv', '')

        '''Special handling for downsample layers'''
        if 'down' in pt_name:
            for k, v in pt2ms_down.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''BN layer handling'''
        if 'bn' in pt_name:
            for k, v in pt2ms_bn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''GN layer handling'''
        if 'gn' in pt_name:
            for k, v in pt2ms_gn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''Renaming succeeded: the name now matches an ms weight, so record it'''
        if pt_name in ms_ckpt.keys():
            if 'scale' in pt_name:
                pt_value = torch.tensor([pt_value])
            assert pt_value.shape == ms_ckpt[pt_name].shape
            match_pt_kv[pt_name] = pt_value
            match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
            matched_ms_k.append(pt_name)
        else:
            not_match_pt_kv[i + '   ' + pt_name] = pt_value

    '''Print the pt weight names that were not matched'''
    print('\n\n-----------------------------unmatched pt weight names----------------------------')
    print('----------original name--------                        ----------converted name---------')
    for j, v in not_match_pt_kv.items():
        print(j, np.array(v.shape))

    '''Print the ms weight names that were never matched'''
    print('\n\n---------------------------unmatched ms weight names----------------------------')
    for j, v in ms_ckpt.items():
        if j not in matched_ms_k:
            print(j, np.array(v.shape))
    print('end')
    return match_pt_kv_mslist

Output:

-----------------------------unmatched pt weight names----------------------------
----------original name--------                        ----------converted name---------
backbone.layer4.1.bn3.num_batches_tracked   tood_body.backbone.layer4.1.bn3.num_batches_tracked []
backbone.layer4.2.bn1.num_batches_tracked   tood_body.backbone.layer4.2.bn1.num_batches_tracked []
backbone.layer4.2.bn2.num_batches_tracked   tood_body.backbone.layer4.2.bn2.num_batches_tracked []
backbone.layer4.2.bn3.num_batches_tracked   tood_body.backbone.layer4.2.bn3.num_batches_tracked []
......

---------------------------unmatched ms weight names----------------------------
end

What remains are the num_batches_tracked buffers of some BN layers, which can safely be ignored.
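For completeness, a minimal sketch of how the returned list can then be used, assuming net is the MindSpore TOOD network instance (the checkpoint name is arbitrary):

import mindspore as ms

param_list = tood_pth2ckpt()
ms.save_checkpoint(param_list, 'tood_from_pt.ckpt')  # persist the converted weights

# Load them into the network; parameters that could not be loaded are returned for inspection.
param_not_load = ms.load_param_into_net(net, ms.load_checkpoint('tood_from_pt.ckpt'))
print(param_not_load)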

Next comes output alignment. When inference reached a convolution that needs padding, a problem showed up.
In MindSpore,

nn.Conv2d(64, 64, kernel_size=3, stride=1,
                     padding=1, pad_mode='pad', has_bias=False)

is not equivalent to PyTorch's

nn.Conv2d(64, 64, kernel_size=3, stride=1,
                     padding=1)

According to the documentation they should be equivalent, but in practice they were not.
What turned out to match is the following in ms: pad explicitly first, then a valid convolution:

pad1 = ms.nn.Pad(((0,0),(0,0),(1,1),(1,1)))
conv2 = ms.nn.Conv2d(64, 64, kernel_size=3, stride=1,
                      pad_mode='valid')

Puzzling...

Manual padding brings a large performance drop, so I only pad manually at test time; during training I let the model correct this small discrepancy on its own.
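A sketch of the kind of single-layer check behind that conclusion, comparing one 3x3 conv with identical weights in both frameworks (purely illustrative):

import numpy as np
import torch
import mindspore as ms

x = np.random.randn(1, 64, 32, 32).astype(np.float32)
w = np.random.randn(64, 64, 3, 3).astype(np.float32)

# PyTorch reference: implicit zero padding of 1.
pt_conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
with torch.no_grad():
    pt_conv.weight.copy_(torch.from_numpy(w))
    pt_out = pt_conv(torch.from_numpy(x)).numpy()

# MindSpore: explicit Pad followed by a 'valid' convolution.
pad = ms.nn.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
ms_conv = ms.nn.Conv2d(64, 64, kernel_size=3, stride=1, pad_mode='valid',
                       weight_init=ms.Tensor(w), has_bias=False)
ms_out = ms_conv(pad(ms.Tensor(x))).asnumpy()

print(np.abs(pt_out - ms_out).max())  # close to 0 up to float error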

Update 2023/07/05

Some parts of the model as originally written did not support static-graph compilation, so it had to be rewritten... and the weight names changed with it (SequentialCell does not support indexing, so the blocks are built with CellList, and conv-gn-relu is no longer built from a dict; the .conv in the names therefore becomes .0 and so on. With manual padding the conv layer gets a .1 suffix, because .0 is taken by the pad op). The weight migration therefore has to be redone, following the same idea as before. The new migration code:

def tood_pth2ckpt(ms_model, pth_path):
    # ms_ckpt = ms.load_checkpoint('tood_ms_zerospad.ckpt')  # randomly initialized MindSpore TOOD weights
    ms_ckpt = ms_model  # randomly initialized MindSpore TOOD weights (a name -> Parameter dict)
    pth = torch.load(pth_path, map_location=torch.device('cpu'))  # PyTorch TOOD weights from mmdetection
    match_pt_kv = {}  # matched pt weights: name -> value
    match_pt_kv_mslist = []  # matched pt weights in the list-of-dicts format MindSpore needs for loading
    not_match_pt_kv = {}  # unmatched pt weights: name -> value
    matched_ms_k = []  # names of ms weights that were matched

    '''General conversion rules'''
    pt2ms = {'backbone': 'tood_body.backbone',  # backbone part
             'neck': 'tood_body.fpn',
             'bbox_head': 'tood_body.head',
             'downsample': 'down_sample_layer',
             }

    '''Conversion rules for conv layers: identical, nothing to change'''
    pt2ms_conv = {
        "weight": "weight",
        "bias": "bias",
    }

    '''Conversion rules for the downsample layer: it has a conv layer and a BN layer, named 0 and 1; in torch the name "weight" appears in both'''
    pt2ms_down = {
        "0.weight": "0.weight",
        "1.weight": "1.gamma",

        "1.bias": "1.beta",
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
    }

    '''Conversion rules for BN layers'''
    pt2ms_bn = {
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
        "weight": "gamma",
        "bias": "beta",
    }

    '''Conversion rules for GN layers'''
    pt2ms_gn = {
        "weight": "gamma",
        "bias": "beta",
    }

    for i, v in pth['state_dict'].items():
        pt_name = copy.deepcopy(i)
        pt_value = copy.deepcopy(v)
        '''General handling'''
        for k, v in pt2ms.items():
            if k in pt_name:
                pt_name = pt_name.replace(k, v)

        '''Conv layer rules are identical, nothing to do'''

        '''Special handling for the FPN part'''
        if 'fpn' in pt_name:
            pt_name = pt_name.replace('.conv', '')

        '''Special handling for downsample layers'''
        if 'down' in pt_name:
            for k, v in pt2ms_down.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''BN layer handling'''
        if 'bn' in pt_name:
            for k, v in pt2ms_bn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''GN layer handling'''
        if 'gn' in pt_name:
            for k, v in pt2ms_gn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''reduction_conv and inter_convs need special handling because the model was rebuilt for static graph'''
        if 'reduction_conv' in pt_name or 'inter_convs' in pt_name:
            if '.conv.' in pt_name:
                pt_name = pt_name.replace('.conv.', '.0.')
            elif '.gn.' in pt_name:
                pt_name = pt_name.replace('.gn.', '.1.')

        '''Renaming succeeded: the name now matches an ms weight, so record it'''
        if pt_name in ms_ckpt.keys():
            if not 'scale' in pt_name:
                assert pt_value.shape == ms_ckpt[pt_name].shape
            match_pt_kv[pt_name] = pt_value
            match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
            matched_ms_k.append(pt_name)

        # Because of the hand-written zeros-pad alignment, the MindSpore conv layers that use padding were renamed; without manual padding the two elif branches below are unnecessary
        elif '.weight' in pt_name:
            if pt_name.replace('.weight', '.1.weight') in ms_ckpt.keys():
                pt_name = pt_name.replace('.weight', '.1.weight')
                assert pt_value.shape == ms_ckpt[pt_name].shape
                match_pt_kv[pt_name] = pt_value
                match_pt_kv_mslist.append(
                    {'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
                matched_ms_k.append(pt_name)
            else:
                not_match_pt_kv[i + '   ' + pt_name] = pt_value  # record as unmatched instead of dropping silently
        elif '.bias' in pt_name:
            if pt_name.replace('.bias', '.1.bias') in ms_ckpt.keys():
                pt_name = pt_name.replace('.bias', '.1.bias')
                assert pt_value.shape == ms_ckpt[pt_name].shape
                match_pt_kv[pt_name] = pt_value
                match_pt_kv_mslist.append(
                    {'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
                matched_ms_k.append(pt_name)
            else:
                not_match_pt_kv[i + '   ' + pt_name] = pt_value  # record as unmatched instead of dropping silently
        else:
            not_match_pt_kv[i + '   ' + pt_name] = pt_value

    '''Print the pt weight names that were not matched'''
    print('\n\n-----------------------------unmatched pt weight names----------------------------')
    print('----------original name--------                        ----------converted name---------')
    for j, v in not_match_pt_kv.items():
        # if not 'num_batches_tracked' in j:
            print(j, np.array(v.shape))

    '''Print the ms weight names that were never matched'''
    print('\n\n---------------------------unmatched ms weight names----------------------------')
    for j, v in ms_ckpt.items():
        if j not in matched_ms_k:
            print(j, np.array(v.shape))
    print('end')
    return match_pt_kv_mslist
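Continuing from the function above, a possible way to call the new conversion, assuming net is the rebuilt static-graph TOOD network (the loading step shown is only one option):

param_list = tood_pth2ckpt(net.parameters_dict(), '/mnt/f/pretrain_weight/tood_r50_fpn_1x_coco.pth')
param_dict = {p['name']: ms.Parameter(p['data'], name=p['name']) for p in param_list}
ms.load_param_into_net(net, param_dict)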

Test-set accuracy alignment of the model (model + decode) in static-graph mode, with manual padding:
MindSpore static graph:
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.401
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.570
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.435
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.228
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.433
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.521
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.326
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.542
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.584
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.391
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.635
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.726

Pytorch mmdet:
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.4235
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.5951
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.4611
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.2512
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.4554
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.5551
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.6140
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.6140
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.6140
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.4161
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.6578
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.7789

The gap is about 2%. There also seem to be some differences in the BN computation,
even though, as far as I can tell, the BN layers behave the same way in both frameworks at inference time; I don't yet know where this comes from.
The accuracy feels good enough for now, so I'll stop chasing it, finish the TAL part next, and start training.
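A minimal sketch of a single-layer check that could help isolate the BN difference (random weights and statistics, purely illustrative):

import numpy as np
import torch
import mindspore as ms

x = np.random.randn(1, 64, 8, 8).astype(np.float32)
gamma = np.random.randn(64).astype(np.float32)
beta = np.random.randn(64).astype(np.float32)
mean = np.random.randn(64).astype(np.float32)
var = np.random.rand(64).astype(np.float32)

# PyTorch BN in eval mode with the given statistics.
pt_bn = torch.nn.BatchNorm2d(64, eps=1e-5).eval()
with torch.no_grad():
    pt_bn.weight.copy_(torch.from_numpy(gamma))
    pt_bn.bias.copy_(torch.from_numpy(beta))
    pt_bn.running_mean.copy_(torch.from_numpy(mean))
    pt_bn.running_var.copy_(torch.from_numpy(var))
    pt_out = pt_bn(torch.from_numpy(x)).numpy()

# MindSpore BN in inference mode with the same statistics.
ms_bn = ms.nn.BatchNorm2d(64, eps=1e-5,
                          gamma_init=ms.Tensor(gamma), beta_init=ms.Tensor(beta),
                          moving_mean_init=ms.Tensor(mean), moving_var_init=ms.Tensor(var))
ms_bn.set_train(False)
ms_out = ms_bn(ms.Tensor(x)).asnumpy()

print(np.abs(pt_out - ms_out).max())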
