pytorch-git模型保存与使用

JIT(JustInTimeCompilation)是即时编译的概念,在深度学习中用于优化程序。PyTorch的JIT允许动态图模型转化为静态结构,便于部署和性能优化。JIT提供了模型的C++接口,提升了推断速度,并支持模型可视化。文章通过示例介绍了如何保存和使用JIT模型。
摘要由CSDN通过智能技术生成

1、什么是jit模型

JIT 是一种概念,全称是 Just In Time Compilation,中文译为「即时编译」,是一种程序优化的方法。

在深度学习中 JIT 的思想随处可见,最明显的例子就是 Keras 框架的 model.compile,TensorFlow 中的 Graph 也是一种 JIT,虽然他没有显示调用编译方法。但是 TensorFlow在调试方面很差强人意,需要使用 tf.cond 等 TensorFlow 自己开发的流程控制,增加了学习难度。

而pytorch 由于其动态图结构的特性,使得深受开发者喜爱,我们可以在 PyTorch 的模型前向中加任何 Python 的流程控制语句,甚至是下断点单步跟进都不会有任何问题。

2、为什么要用jit模型

但是这也引出了一个问题,由于tf的静态图设计,所以生成的模型结构很容易查看,而pytorch的模型是动态图,保存出来如果没有模型结构则完全无法使用,所以就需要jit方法了。那么JIT 到底带来了哪些好处。

(1)模型部署
PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等…

(2) 性能提升

既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module)转换为 TorchScript Module,再进行推断。

(3) 模型可视化

TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 forward 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式,这里只介绍trace方式。

3、jit模型的保存

废话少说直接上代码

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(3,16,3,1,1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16,64,3,1,1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,32,3,2,1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32,16,3,2,1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            )

        self.toRgb = nn.Conv2d(16,3,1,1)
        

    def forward(self, x):
        x = self.block(x)
        return self.toRgb(x)

n = Net()

example_forward_input = torch.rand(1, 3, 256, 256)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

module.save('a.pth')

4、jit模型的使用

import torch

device = torch.device('cpu')
module  = torch.jit.load('a.pth')
print(module)

module.to(device)


a = torch.randn(3,3,256,256)
y = module(a)
print(y.size())

5、今天就先介绍到这里,什么?为啥只介绍trace方式?,,,因为那一个我还没查。。。西柚!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
你可以使用 PyTorch-CycleGAN-and-pix2pix 库来使用预训练好的模型。下面是一个简单的步骤示例: 1. 首先,确保你已经安装了 PyTorch-CycleGAN-and-pix2pix 库。你可以使用以下命令安装: ``` pip install git+https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix ``` 2. 下载预训练模型。你可以在 CycleGAN 和 pix2pix 的模型网页(https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#model-checkpoints)上找到预训练模型的链接。下载并解压缩模型文件夹。 3. 创建一个配置文件。在模型文件夹中,复制并重命名 `test_opt.txt.example` 文件为 `test_opt.txt`。该文件用于配置测试参数。 4. 配置测试参数。打开 `test_opt.txt` 文件,并根据你的需求修改参数。重要的参数包括 `dataroot`(数据集的路径)和 `name`(模型名称)。 5. 运行测试脚本。使用以下命令运行测试脚本: ``` python test.py --dataroot ./path/to/dataset --name pretrained_model_name --model test_model_name ``` 确保将 `./path/to/dataset` 替换为你的数据集路径,`pretrained_model_name` 替换为你下载的预训练模型文件夹的名称,`test_model_name` 替换为你想要使用的测试模型的名称(如 `cycle_gan` 或 `pix2pix`)。 6. 查看结果。测试完成后,生成的结果将保存模型文件夹中的 `results` 子文件夹中。 请注意,这只是一个基本的示例,你可能需要根据你的具体情况进行更多的配置和调整。你可以参考 PyTorch-CycleGAN-and-pix2pix 库的文档以获取更多详细信息和用法示例。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值