最近在看目标检测的代码时,经常遇到 `__setattr__` 和 `__getattr__` 这两个魔法方法,它们主要用在构建最后的检测头(head)时。
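查了一下资料:
对普通的 Python 对象来说,`obj.__setattr__(key, value)` 与 `obj.key = value` 等价,默认就是往实例的 `__dict__` 里插入 key 和 value;
而 `__getattr__` 只有在常规属性查找(实例的 `__dict__`、类及其父类)都找不到该名字时才会被调用。那为什么在下面的 PyTorch 代码里,直接调用 `__getattr__('lay')` 能取到子模块,`__getattr__('bb')` 却会报错?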
import torch
import torch.nn as nn
class Animal(nn.Module):
    """Toy module showing where ``nn.Module`` stores assigned attributes.

    ``nn.Module.__setattr__`` routes sub-modules into ``self._modules``,
    while plain values such as the ``str``/``int`` fields below end up in
    the instance ``__dict__``.  ``nn.Module.__getattr__`` only searches the
    module's internal registries, so it must not be called directly to read
    plain attributes -- use ``self.attr`` or ``getattr(self, 'attr')``.

    Args:
        name: arbitrary value, stored as a plain attribute.
        age: arbitrary value, stored as a plain attribute.
    """

    def __init__(self, name, age) -> None:
        super().__init__()
        self.name = name  # plain value -> instance __dict__
        self.age = age  # plain value -> instance __dict__
        self.layer = nn.Conv2d(3, 3, 1, 1)  # nn.Module -> self._modules['layer']
        # Equivalent to ``self.lay = self.layer``: the SAME Conv2d object is
        # registered a second time under 'lay', so its tensors appear under
        # both 'layer.*' and 'lay.*' in state_dict().
        self.__setattr__("lay", self.layer)
        self.__setattr__("bb", 'cc')  # plain str -> instance __dict__

    def forward(self, x):
        """Apply the 1x1 convolution registered as ``lay`` to ``x``."""
        # Normal attribute access: 'lay' is not in __dict__, so Python falls
        # back to nn.Module.__getattr__, which finds it in self._modules.
        return self.lay(x)


if __name__ == "__main__":
    cc = Animal('123', 12)
    # 'layer'/'lay' sit inside the '_modules' OrderedDict; 'name', 'age'
    # and 'bb' are ordinary __dict__ entries.
    print(cc.__dict__)
    print(cc.__getattr__('lay'))  # works: nn.Module.__getattr__ searches _modules
    print(cc.bb)                  # works: normal lookup finds 'cc' in __dict__
    try:
        # Calling the dunder directly skips the normal __dict__ lookup, and
        # nn.Module.__getattr__ never looks in __dict__, so this raises.
        cc.__getattr__('bb')
    except AttributeError as err:
        print(f"AttributeError: {err}")  # 'Animal' object has no attribute 'bb'
    x = torch.rand(1, 3, 3, 3)
    print(cc(x))
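为什么 `lay` 能取到,`bb` 却不行?查了一下才发现,PyTorch 的 `nn.Module` 重写了这两个方法:
- `__setattr__` 会把 `nn.Module` 类型的值存进 `self._modules`,把 `nn.Parameter` 存进 `self._parameters`。只有其他普通值(比如字符串 `'cc'`)才会写进实例的 `__dict__`。
- `__getattr__` 只会在 `_parameters`、`_buffers`、`_modules` 这几个字典里查找。

`self.bb` 或 `getattr(self, 'bb')` 走的是常规查找,会先在 `__dict__` 里找到 `'cc'`。而直接调用 `self.__getattr__('bb')` 跳过了这一步,自然就抛出 AttributeError。

打开 `nn.Module` 的源码,可以看到以下方法: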
def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
.....
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
def remove_from(*dicts_or_sets):
for d in dicts_or_sets:
if name in d:
if isinstance(d, dict):
del d[name]
else:
d.discard(name)
另一种常见写法是用内置函数 `setattr`/`getattr`,按字符串名动态注册和取出子模块,例如:
class IDAUp(nn.Module):
    """Iterative up-sampling/aggregation stage built from dynamically named layers.

    For each input level ``i >= 1`` three sub-modules are created and
    registered with ``setattr`` under string names, so they appear in
    ``state_dict()`` as ``proj_<i>.*``, ``up_<i>.*`` and ``node_<i>.*``:

    * ``proj_<i>``: ``DeformConv(channels[i], o)``;
    * ``up_<i>``: grouped (``groups=o``) ``ConvTranspose2d`` with stride
      ``up_f[i]``, initialised by ``fill_up_weights``;
    * ``node_<i>``: ``DeformConv(o, o)`` applied after adding the previous
      level.

    Usage notes (from the original)::

        IDAUp(channels[j], in_channels[j:], scales[j:] // scales[j])
        ida(layers, len(layers) - i - 2, len(layers))

    Args:
        o: number of output channels of every stage.
        channels: per-level input channel counts; index 0 is not used to
            create any module.
        up_f: per-level up-sampling factors (each converted with ``int``).
    """

    def __init__(self, o, channels, up_f):
        super().__init__()
        for i in range(1, len(channels)):
            c = channels[i]
            f = int(up_f[i])
            proj = DeformConv(c, o)
            node = DeformConv(o, o)
            # NOTE(review): with kernel 2f, stride f and padding f // 2 the
            # output is exactly f times the input size only for even f; an
            # odd f (including 1) adds one extra pixel -- confirm up_f[i] is
            # always even.
            up = nn.ConvTranspose2d(o, o, f * 2, stride=f,
                                    padding=f // 2, output_padding=0,
                                    groups=o, bias=False)
            fill_up_weights(up)
            # setattr goes through nn.Module.__setattr__, which registers each
            # module in self._modules under the dynamic name.
            setattr(self, f'proj_{i}', proj)
            setattr(self, f'up_{i}', up)
            setattr(self, f'node_{i}', node)

    def forward(self, layers, startp, endp):
        """Aggregate ``layers[startp + 1:endp]`` into ``layers`` in place.

        Each level ``i`` is projected and up-sampled, then added to the
        already-updated ``layers[i - 1]`` and passed through its node.  Level
        ``i`` uses the modules created for index ``i - startp``, so this
        presumably requires ``endp - startp <= len(channels)`` -- TODO confirm.

        Returns:
            None; ``layers`` (a mutable sequence) is modified in place.
        """
        for i in range(startp + 1, endp):
            k = i - startp
            upsample = getattr(self, f'up_{k}')
            project = getattr(self, f'proj_{k}')
            layers[i] = upsample(project(layers[i]))
            node = getattr(self, f'node_{k}')
            layers[i] = node(layers[i] + layers[i - 1])