ConvNeXtv2 pytorch预训练权重转paddle

1 篇文章 0 订阅
1 篇文章 0 订阅

ConvNeXtv2 pytorch预训练权重转paddle

直接上代码

import torch
import paddle.fluid as fluid
from collections import OrderedDict
import paddle
import argparse
import os


parser = argparse.ArgumentParser()

parser.add_argument('--torch_weight', type=str,  default='convnextv2_atto_1k_224_ema.pt', help='torch_weight path')
parser.add_argument('--paddle_model', type=str,  default='convnextv2_atto', help='paddle_model')
parser.add_argument('--paddle_weight_dir', type=str,  default='paddle_model', help='paddle_model')

args = parser.parse_args()

if not os.path.exists(args.paddle_weight_dir):
    os.makedirs(args.paddle_weight_dir)

def main(args):
    # 读取torch的权重文件,我们这里默认读取的是.pt文件
    torch_weight = torch.load(args.torch_weight, map_location=torch.device('cpu'))
    
    weight = []
    # 对于.pt文件就是这么读取权重
    for torch_key in torch_weight['model'].keys():
        weight.append([torch_key,torch_weight['model'][torch_key].detach().numpy()])
    #     print(torch_key)
    # print(weight[0])


    with fluid.dygraph.guard():
        # 加载网络结构
        if args.paddle_model == 'convnextv2_atto':
            from convnextv2_paddle import convnextv2_atto
            paddle_model = convnextv2_atto()
            
        if args.paddle_model == 'convnextv2_femto':
            from convnextv2_paddle import convnextv2_femto
            paddle_model = convnextv2_femto()
            
        if args.paddle_model == 'convnext_pico':
            from convnextv2_paddle import convnext_pico
            paddle_model = convnext_pico()
            
        if args.paddle_model == 'convnextv2_nano':
            from convnextv2_paddle import convnextv2_nano
            paddle_model = convnextv2_nano()
            
        if args.paddle_model == 'convnextv2_tiny':
            from convnextv2_paddle import convnextv2_tiny
            paddle_model = convnextv2_tiny()
            
        if args.paddle_model == 'convnextv2_base':
            from convnextv2_paddle import convnextv2_base
            paddle_model = convnextv2_base()
            
        if args.paddle_model == 'convnextv2_large':
            from convnextv2_paddle import convnextv2_large
            paddle_model = convnextv2_large()   
            
        if args.paddle_model == 'convnextv2_huge':
            from convnextv2_paddle import convnextv2_huge
            paddle_model = convnextv2_huge()              
            
        # print(paddle_model)

        # 读取paddle网络结构的参数列表
        paddle_weight = paddle_model.state_dict()
        
        # # 检查是否paddle中的key在torch的dict中能找到
        # for paddle_key in paddle_weight:
        #     if paddle_key in torch_weight['model'].keys():
        #         print("Oh Yeah")
        #     else:
        #         print("No!!!")

        # 进行模型参数转换
        new_weight_dict = OrderedDict()

        # i = 0
        for paddle_key in paddle_weight.keys():
            # 首先要确保torch的权重里面有这个key,这样就可以避免DIY模型中一些小模块影响权重转换
            if paddle_key in torch_weight['model'].keys():
                # pytorch权重和paddle模型的权重为2维时需要转置,其余情况不需要
                if len(torch_weight['model'][paddle_key].detach().numpy().shape) == 2:
                    # print(paddle_key)
                    new_weight_dict[paddle_key] = torch_weight['model'][paddle_key].detach().numpy().T
                else:
                    new_weight_dict[paddle_key] = torch_weight['model'][paddle_key].detach().numpy()
        #     i += 1

        paddle_model.set_dict(new_weight_dict)
        fluid.dygraph.save_dygraph(paddle_model.state_dict(),os.path.join(args.paddle_weight_dir,args.paddle_model))
        
        print('Paddle version: ',paddle.__version__)
        print('Torch version: ',torch.__version__)
        print(f"You have converted {args.torch_weight} to {os.path.join(args.paddle_weight_dir,args.paddle_model)}.pdparams")
        
if __name__ == '__main__':
    main(args)

# 验证载入的权重    
# paddle_weight = paddle.load('conv_ne_xt_v2_0pdparams.pdparams')
# paddle_model2 = convnextv2_atto()
# paddle_model2.set_dict(paddle_weight)
# for i in range(100):
#     print(paddle_model.parameters()[i]==paddle_model2.parameters()[i])

使用方法

本代码运行的环境为:
Paddle version: 2.2.2
Torch version: 2.0.0+cpu

考虑到paddle社区一些最新的模型没有预训练权重,结合网上查的资料自己动手写了个pytorch预训练权重转paddle权重的代码。目前在ConvNextv2上进行了实践,使用还是很方便的,只需要导入paddle模型和对应的torch模型的预训练权重。这里convnextv2的paddle代码是根据pytorch代码对齐转换的,所以权重文件中的key都是相同的。之所以要保证key相同,是因为实践过程中,本人发现paddle模型的权重dict中元素的顺序与torch的权重文件并不一致,直接使用有序字典导入会一个都导不进去。然后paddle模型权重中的2维数据,即shape像<160,60>这种,在torch的权重文件中对应的shape应该是<60,160>。因此,代码中加入了判断,来确保权重shape的匹配。

参考链接

paddle复现pytorch踩坑(十一):转换pytorch预训练模型

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值