PyTorch2.0向后兼容性和加速效果浅探

前言

在PyTorch2022开发者大会上,PyTorch团队发布了一个新特性——torch.compile,将PyTorch的性能推向了新的高度,称这个新版本为PyTorch2.0。torch.compile的引入不影响之前的功能,其是一个完全附加和可选的功能,因此PyTorch2.0完全向后兼容,基于之前1.x版本开发的项目可以直接迁移到PyTorch2.0使用。

环境升级

比较简单,按照官方说明安装即可。
在这里插入图片描述
先建一个新环境torch2.0.1,python版本使用3.8+,在新环境中安装PyTorch2.0:

conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

测试向后兼容性

待所有依赖包安装好之后,切换到新环境。

conda activate torch2.0.1

运行之前torch1.x下能正常运行的网络训练代码,可以看到能够正常运行。此时速度没什么明显差别。

需要注意的是,如果是使用DDP模式训练的话,可能会报“local_rank”相关的错。将代码中的相关配置参数修改一下:

__author__ = 'TracelessLe'

import argparse
import torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    if (torch.__version__).startswith('2.0'):
        parser.add_argument("--local-rank", type=int, required=True)
    else:
        parser.add_argument("--local_rank", type=int, required=True)
    
    main()

测试加速效果

根据PyTorch官方博客中的内容,使用torch.compile后模型训练和推理的加速效果很明显。
在这里插入图片描述
这里快速上手,直接根据新手教程中的操作来修改相应代码:

import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, backend="inductor")
model(torch.randn(1,3,64,64))

在实验中发现自己的某个简单的网络训练速度由~0.8s/step加速到~0.6s/step,加速比达到25%。实践说明该新功能确实能够加速训练速度。

本次不深入测试更多的功能,包括不同的backend,以及纯推理过程的加速比。
在这里插入图片描述

其他说明

使用torch.compile功能时如果同时需要加载预训练模型,根据预训练模型保存的版本和正在使用的PyTorch版本的区别分情况进行处理:

1、预训练好的模型由PyTorch1.x保存,需要使用PyTorch2.0的torch.compile加速功能。则需要网络先加载模型参数,再使用torch.compile进行编译。

__author__ = 'TracelessLe'

import torch


device = 'cuda:0'
model_pth = 'pretrained_model.pth'
model = TrainNet()
model_state_dict = torch.load(model_pth, map_location=device)
model.load_state_dict(model_state_dict, strict=False)

if (torch.__version__).startswith('2.0'):
    model = torch.compile(model, backend="inductor")

2、预训练好的模型由PyTorch2.x保存,需要使用PyTorch2.0的torch.compile加速功能。则需要网络先编译再加载模型参数。

__author__ = 'TracelessLe'

import torch


device = 'cuda:0'
model_pth = 'pretrained_model.pth'
model = TrainNet()

if (torch.__version__).startswith('2.0'):
    model = torch.compile(model, backend="inductor")

model_state_dict = torch.load(model_pth, map_location=device)
model.load_state_dict(model_state_dict, strict=False)

当然,PyTorch2.0保存的模型PyTorch1.x也是可以正常加载的,只是需要注意的是模型中存的key有一定差异需要特殊处理一下。

其中,PyTorch2.0模型的key前缀是“_orig_mod.module.”,而PyTorch1.x模型的key前缀是“module.”。根据这个差异对模型加载过程特殊处理即可。

__author__ = 'TracelessLe'

import torch
import collections


def load_model_compile(model, model_pth, device, strict=False, backend="inductor"):
    # 兼容torch1/2大版本之间的模型加载
    origin_dict = torch.load(model_pth, map_location=device)
    state_dict = collections.OrderedDict()
    # torch1_model_prefix = 'module.'
    # offset1 = len(torch1_model_prefix)
    torch2_model_prefix = '_orig_mod.'
    offset2 = len(torch2_model_prefix)
    for key, value in origin_dict.items():
        if key.startswith(torch2_model_prefix):
            if (torch.__version__).startswith('2.0'):
                model = torch.compile(model, backend=backend)
                model.load_state_dict(origin_dict, strict=strict)
            else:
                for key, value in origin_dict.items():
                    state_dict[key[offset2: len(key)]] = value
                model.load_state_dict(state_dict, strict=strict)
        else:
            if (torch.__version__).startswith('2.0'):
                model.load_state_dict(origin_dict, strict=strict)
                model = torch.compile(model, backend=backend)
            else:
                model.load_state_dict(origin_dict, strict=strict) 
        break
    return model

当然,也可以直接改模型中参数的key以适配不同版本,此处不再展开。

针对PyTorch2.0的变化在官方博客中讲的很详细,需要深入应用的同学可以进一步查阅相关资料。

版权说明

本文为原创文章,独家发布在blog.csdn.net/TracelessLe。未经个人允许不得转载。如需帮助请email至tracelessle@163.com或扫描个人介绍栏二维码咨询。
在这里插入图片描述

参考资料

[1] PyTorch 2.0 重磅发布:一行代码提速 30% - 知乎
[2] Getting Started — PyTorch 2.0 documentation
[3] torch.compile — PyTorch 2.0 documentation
[4] 解决报错:train.py: error: unrecognized arguments: --local-rank=1 ERROR:torch.distributed.elastic.multipr_WTIAW.TIAW的博客-CSDN博客
[5] torch.compile — PyTorch 2.0 documentation
[6] Accelerated Image Segmentation using PyTorch | PyTorch
[7] Accelerated Generative Diffusion Models with PyTorch 2 | PyTorch
[8] PyTorch 2.0 | PyTorch
[9] torch.compile Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation
[10] Training Compiled PyTorch 2.0 with PyTorch Lightning

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TracelessLe

❀点个赞加个关注再走吧❀

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值