Poetry上传一个属于自己的库

15 篇文章 1 订阅
13 篇文章 0 订阅
文章讲述了作者如何将torch.fx的代码封装成一个三方库,该库能自动解析Python文件中的nn.Module并进行可视化。作者通过Poetry进行了打包和上传到PyPI的过程,并提到工具尚不支持ONNX模型的绘制,因为torch.fx不支持动态流。最后,作者提到了未来可能涉及Rust与深度学习的内容。
摘要由CSDN通过智能技术生成

前言

其实这是一个拖了很久很久的坑,不知道多少人看过我之前的一篇博客关于torch.fx的使用,在这里面我用torch.fx实现了一些很有趣的功能比如模型可视化.所以当时就有一个想法,把代码封装一下写成一个属于自己的三方库,正好今天有点时间就把这个坑给填上.

这个工具的主要功能很简单,直接指定某个py文件工具会自动寻找文件中所有的nn.Module并进行解析可视化.

开始

关于模型的trace以及算子的解析在之前的博客中已经写的比较清楚了,这里就不过多赘述.今天的主要内容是封装+Poetry上传.

1 封装

将上次的代码直接拿来用,然后写个函数调用一下

def draw(model: torch.nn.Module, inputs: torch.Tensor, save_dir: str = './Save', save_name: str = 'model'):
    graph = model_graph(model, inputs)
    graph.render(outfile=save_dir + '/' + save_name, view=False)

然后解析py文件中所有的nn.Module

def parse_py(py_path: str) -> list[nn.Module]:
    spec = importlib.util.spec_from_file_location('module_name', py_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    module_classes = [m for _, m in module.__dict__.items() if
                      isinstance(m, (nn.Module, type)) or issubclass(type(m),nn.Module)]
    if len(module_classes) == 0:
        raise ValueError('No nn.Module class found in the file.')
    return module_classes
    
for index, model_name in enumerate(model_list):
    try:
        if not isinstance(model_name, nn.Module):
            model = model_name()
        else:
            model = model_name
        args.name=type(model).__name__+'.svg'
        draw(model, inputs, save_dir=args.dir, save_name=args.name)
    except:
        print(f"{model_name} draw failed")
        pass

这里为了防止某些子类算子不能实现所有输入需求,其实也是目前设计的不太灵活,所以直接用try简化这部分处理.

2 测试一下

输入的测试文件如下test.py

import torch
from torch import nn
from torchvision.models import resnet18,regnet_x_8gf
​
​
class TestConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(TestConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(True)
​
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
​
​
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = TestConv2d(3, 32, kernel_size=3)
        self.conv2 = TestConv2d(32, 64, kernel_size=3)
        self.dropout = nn.Dropout(0.3)
​
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.dropout(x)
        return x
​
​
model= resnet18()
model2=regnet_x_8gf()

image-20230627201236545
image-20230627201309469
image-20230627201508936

貌似还可以,接下来就是Poetry打包上传了

3 打包&&上传

首先下载Poetry,根据官网推荐利用curl -sSL https://install.python-poetry.org | python3 -进行安装,命令细节根据自己情况更改

使用poetry new xxx新建一个项目,这里项目的结构已经生成好了,只需要把之前的文件复制到对应位置.接下来就是比较重要的一步了,修改.toml文件
image-20230627201925328

这里对这本地环境pip list把包的相关依赖加进去,关于^与~的区别可以自行google一下版本命名规则相关的.接下来就去pypi上注册一个帐号,得到用户名和密码

最后进行上传

poetry publish --build -u username -p password

这里我没有去细讲关于poetry lock/poetry build/poetry show --tree等相关的指令,大家有兴趣可以去看相关内容.

image-20230627202529168

经过一系列操作终于在pypi上看到了我们自己的库,仅仅写了简单的readme没有去写license.

4 下载测试

我们从pypi上把这个库给install下来,并且在本地重新写个测试脚本调用库进行测试

from pytorch_show import main
​
if __name__ == '__main__':
    main()

残暴如此,直接调用main就好,再来写个model.py

from torchvision.models import vgg16
​
model= vgg16()

直接调用测试python test.py -f model.py

image-20230627203037295

image-20230627203110682

直筒形VGG一切正常

总结

今天也算是小小填了一下去年留下的坑,不过还是有很多遗憾的地方.本来这个工具的初衷除了能解析py文件之外还能对onnx进行绘制,但是今天实践下来有很多坑.一开始的想法是将onnx反转回pytorch模型,但是torch.fx的trace并不能支持动态流,因此很多算子包含if或者for loop的地方都会报错.然后尝试了直接利用onnx.tools.net_drawer进行绘制,成功得到了图片但是太过于复杂,很多多余的输入都被展示出来严重干扰了核心算子的展示,同时必须借助运行时才能得到每一步的shape,所以这部分想优化还是要好好想想办法,这也算是为下一阶段改进再留一个坑吧.

另外留个彩蛋,下一期会讲讲最近看到Rust与DL相关的内容.

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

shelgi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值