python基础——判断类的父类和实例的父类

python基础——判断类的父类和实例的父类

在做AutoML的过程中,用户使用训练器时会有两种情况,一种是将在NAS搜索到的模型文件传入Tuning模块,做完整的训练。另一种是在自建代码中独立调用Tuning模块。因此需要判断传入的是一个文件还是自建的模型。

这里我们简单创建一个torch模型来看一下

import torch
from torch import nn
from torch.nn import functional as F

class Models(nn.Module):
    def __init__(self):
        super(Models, self).__init__()
        self.conv1 = nn.Conv2d(3,12,3,2)
        self.bn = nn.BatchNorm2d(12)
        self.fc1 = nn.Linear(12*5*5, 10)
    
    def forward(self, x):
        x = F.relu(self.bn(self.conv1(x)))
        x = self.fc1(x)
        return x

model = Models()

对于任何一个类,python都内置了__base____bases__两种内置函数来显示其直接父类,如果有多余一个的直接父类,用__bases__显示所有直接父类。

pirnt(Models.__base__)
"""
output:<class 'torch.nn.modules.module.Module'>
"""

但是,一个类的实例是不包含__base__属性的,直接调用会报错,而采用type()函数返回的是我们自定义的这个实例的类名,无法作为判断依据。

print(model.__base__)
"""
output: Traceback (most recent call last):
  File "d:/LF/工作/test/torchtest.py", line 19, in <module>
    print(model.__bases__)
  File "D:\python\lib\site-packages\torch\nn\modules\module.py", line 585, in __getattr__
    type(self).__name__, name))
AttributeError: 'Models' object has no attribute '__bases__'
"""
print(type(model))
<class '__main__.Models'>

可以看到,这两种方法都无法让我们判断是否传入的模型是用户通过torch构建的模型。这里对于实例应采用__class__属性,返回实例所在的类,然后再调用__base__

print(model.__class__.__base__)
"""
output:<class 'torch.nn.modules.module.Module'>
"""

这样返回的就是所在类的父类了,通过assert或者if语句就可以判断用户的输入是否合法了。

发布了3 篇原创文章 · 获赞 0 · 访问量 107
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 深蓝海洋 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览