科研中的python细节(一)

最近在看PTQ4ViT的代码,也学到了很多做量化过程中的python细节,这里做一个总结,防止自己遗忘,加强一下记忆。

1.timm库中有许多已经预训练好的ViT模型(vision_transformer、Swin、DeiT),我们自己拿来用的时候直接使用net = timm.create_model(name, pretained=True),name是你需要的网络的名称,pretained决定是否需要预训练的参数,如果为True,那么权重都是预训练好的权重。

2.importlib库中有一个import_module函数,它可以将一个文件转化为一个实例对象,文件中的变量和方法可以通过该实例对象来调用,十分的方便。

3.如果调用函数的过程中,我们传入了一个字典cfg={"name":"111", "module":"222"}(字典中是键值对的映射关系),假设我们调用的函数是func(name, module),如果我们这样调用:func(**cfg),那么字典cfg会被解包,也就是字典中的键值对会变成关键字参数(name="111",module="222"),传递给函数,该用法也十分的方便

4.关于修改网络结构:在某一个模块中添加模块。首先,我们应该知道一个函数net.named_modules(),该函数可以以深度优先遍历的方式遍历该网络中的所有大大小小的模块,我们可以用元组(name,module)去接收该函数的返回值,如果我们想在比如Attention模块中添加矩阵乘模块,我们可以这样来写:

for name, module in net.named_modules():
    if isinstance(module, Attention):
        setattr(module, "matmul1", MatMul())
        setattr(module, "matmul2", MatMul())

5.关于修改网络结构:替换某父模块中的子模块。比如我想要替换ViT中的nn.Conv2d模块,用我自己写的模块MatMul()来替换,代码如下:

net = timm.create_model("vit_tiny_patch16_224", pretrained=True)
print(net.patch_embed)

'''
PatchEmbed(
  (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  (norm): Identity()
)
'''

setattr(net.patch_embed, "proj", MatMul())
print(net.patch_embed)

'''
PatchEmbed(
  (proj): MatMul()
  (norm): Identity()
)
'''

6.如果你的手头有一个网络对象net,你想看一看这个net的结构,直接print(net)即可,这样你就可以通过对网络结构的宏观了解来执行你对网络的一系列操作。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值