The relationship between torchvision.models.__dict__[arch](pretrained=False) and Python's __dict__
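I had been writing procedural C code for a long time, so switching to object-oriented Python took some getting used to. A few days ago, while learning PyTorch, I used the torchvision.models module to load a pretrained model for training and memorized the following way of calling it:
model = torchvision.models.__dict__[arch](pretrained=True)  # arch is the name of the pretrained model to load, e.g. 'resnet18'
Because I was unfamiliar with Python, I did not know that __dict__ is an attribute provided by the Python language itself, and I mistakenly assumed it was something the torchvision models module defined or inherited. Then today, while reading source code written by an expert, I came across the same call statement: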
model = models.__dict__[arch](pretrained=True)
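Here, however, models is just a module inside that source tree and has nothing to do with torchvision.models. The source layout is shown in the figure below; models is a module the author wrote and imports:
Only then did I understand that this is simply a feature of Python itself. So let's look at how Python makes this kind of call work. First I checked the source files and found that the __init__.py file in models is not empty. Its contents are: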
from .resnext import resnext29_8_64, resnext29_16_64
from .resnet import resnet20, resnet32, resnet44, resnet56, resnet110
from .resnet_mod import resnet_mod20, resnet_mod32, resnet_mod44, resnet_mod56, resnet_mod110
from .preresnet import preresnet20, preresnet32, preresnet44, preresnet56, preresnet110
from .caffe_cifar import caffe_cifar
from .densenet import densenet100_12
# imagenet based resnet
from .imagenet_resnet import resnet18, resnet34, resnet50, resnet101, resnet152
# cifar based resnet
from .resnet import CifarResNet, ResNetBasicblock
# cifar based resnet pruned
from .resnet_small import resnet20_small, resnet32_small, resnet44_small, resnet56_small, resnet110_small
# imagenet based resnet pruned
# from .imagenet_resnet_small import resnet18_small, resnet34_small, resnet50_small, resnet101_small, resnet152_small
from .imagenet_resnet_small import resnet18_small, resnet34_small, resnet50_small, resnet101_small, resnet152_small
from .vgg_cifar10 import *
from .vgg import *
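Since __dict__ belongs to every Python module object rather than to torchvision, the same lookup pattern can be tried on any module. Below is a minimal sketch that uses the standard-library math module purely as a stand-in (it is not part of the source tree discussed here); the variable name arch is reused only to mirror the call above.

```python
import math

arch = "sqrt"  # a function name held in a string, much like arch = "resnet18"

# A module's __dict__ is its namespace: a plain dict mapping names to objects.
fn = math.__dict__[arch]

# vars(module) returns that same dict, and getattr() performs an equivalent lookup.
assert vars(math) is math.__dict__
assert fn is getattr(math, arch)

result = fn(16.0)
print(result)  # 4.0
```

In everyday code getattr(module, name) is probably the more common spelling of this lookup; indexing __dict__ directly does the same thing for names defined in or imported into the module.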
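Then I noticed that the arch in models.__dict__[arch] is exactly one of the function names imported here. So I ran an experiment: for a package without an __init__.py, I printed its __dict__ with the code a = functions.__dict__, where functions is the functions package from the directory screenshot above. The value of a was as follows: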
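As the screenshot above shows, this dict holds some information about the functions package itself, but no record of the .py files in it or of the functions defined in those files. Recalling the __init__.py of the models package, I added a new __init__.py to the functions package and, following the models example, wrote from .cal_flop import * in it (cal_flop is the file cal_flop.py in the functions folder of the source). Inspecting a in the debugger again gave the following output: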
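Now the functions imported in __init__.py appear in the dictionary a. Here cal_flop is the name of the original .py file, and 'basic', 'bottle', 'imagenet_flop' and so on are functions defined in that file. To call one of them, say imagenet_flop, suppose the function is defined as: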
def imagenet_flop(layer=18, prune_rate=1):
    flop = 0
    if layer == 18:
        block = [2, 2, 2, 2]
        conv_in_blcok = 2
    elif layer == 34:
        block = [3, 4, 6, 3]
        conv_in_blcok = 2
    elif layer == 50:
        block = [3, 4, 6, 3]
        conv_in_blcok = 3
    elif layer == 101:
        block = [3, 4, 23, 3]
        conv_in_blcok = 3
    elif layer == 152:
        block = [3, 8, 36, 3]
        conv_in_blcok = 3
    else:
        print("wrong layer")
    channel = [64, 128, 256, 512]
    width = [56, 28, 14, 7]
    layer_interval = [conv_in_blcok * i for i in block]
    layer_index = [sum(layer_interval[:k + 1]) for k in range(0, len(layer_interval))]
    print(layer_index)
    if layer in [18, 34]:
        flop = basic(layer, layer_index, channel, width, prune_rate)
    elif layer in [50, 101, 152]:
        flop = bottle(layer, block, channel, width, prune_rate)
        print('bottle structure')
    print(flop)
    return flop
then the following statement does the job:
b = functions.__dict__['imagenet_flop'](18, 1)  # note: the parentheses are required to actually call the function
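The whole experiment can be approximated with a self-contained sketch. It builds two throwaway packages in a temporary directory: one with no __init__.py and one whose __init__.py contains from .cal_flop import *. The package names demo_no_init and demo_with_init are made up for illustration, and imagenet_flop here is a dummy that only returns layer * prune_rate, not the real FLOP calculation.

```python
import importlib
import os
import sys
import tempfile

# Dummy stand-in for cal_flop.py; the real file computes FLOPs.
SOURCE = (
    "def imagenet_flop(layer=18, prune_rate=1):\n"
    "    return layer * prune_rate\n"
)


def make_package(root, name, with_init):
    """Create root/name/cal_flop.py and, optionally, an __init__.py importing from it."""
    pkg_dir = os.path.join(root, name)
    os.makedirs(pkg_dir)
    with open(os.path.join(pkg_dir, "cal_flop.py"), "w") as f:
        f.write(SOURCE)
    if with_init:
        with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
            f.write("from .cal_flop import *\n")


root = tempfile.mkdtemp()
make_package(root, "demo_no_init", with_init=False)    # hypothetical name
make_package(root, "demo_with_init", with_init=True)   # hypothetical name
sys.path.insert(0, root)
importlib.invalidate_caches()

no_init = importlib.import_module("demo_no_init")
with_init = importlib.import_module("demo_with_init")

# Without __init__.py imports, the package namespace knows nothing about the function.
print("imagenet_flop" in no_init.__dict__)    # False
# With the import in __init__.py, both the submodule and its function are present.
print("cal_flop" in with_init.__dict__)       # True
print("imagenet_flop" in with_init.__dict__)  # True

# Look the function up by name, then call it with parentheses.
value = with_init.__dict__["imagenet_flop"](18, 1)
print(value)  # 18
```

The names show up in the second package only because its __init__.py runs when the package is imported, and every name that file binds lands in the package's __dict__.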
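Now look at the form without parentheses, b = functions.__dict__['imagenet_flop']. In this case b is the function object itself rather than its return value. Printing b shows the function's name together with a hexadecimal number, which in CPython is the memory address of the function object. The output of b is: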
print(b)
<function imagenet_flop at 0x000002D008169400>
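The difference between looking a function up and calling it can be seen in a small sketch. The dictionary namespace below merely plays the role of functions.__dict__, and imagenet_flop is again a dummy stand-in; the remark about the hexadecimal number applies to CPython and may not hold in other interpreters.

```python
def imagenet_flop(layer=18, prune_rate=1):
    # Dummy stand-in body; the real function computes FLOPs.
    return layer * prune_rate


# Plays the role of functions.__dict__: a mapping from names to objects.
namespace = {"imagenet_flop": imagenet_flop}

b = namespace["imagenet_flop"]  # no parentheses: b is the function object itself
is_same_object = b is imagenet_flop
text = repr(b)

print(text)         # something like <function imagenet_flop at 0x...>
print(hex(id(b)))   # in CPython this corresponds to the number shown in the repr

called = b(18, 1)   # adding parentheses performs the actual call
print(called)       # 18
```

Because functions are ordinary objects, b can be stored, passed around, and called later, which is what makes the name-to-function lookup through __dict__ possible.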
PS: I'm still learning, so please bear with me if I have misunderstood anything!