Python中的setattr()和getattr()

【使用框架 PyTorch 1.2】
这是笔者在复现Alphago Zero的双头残差网络时遇到的情况。笔者准备使用具有19个相同残差块的残差网络,在国外博客中发现了这种写法:
在这里插入图片描述
当然,你用以下这种方式也可以达到相同的效果,但是显得太low了:

class ResNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResNet, self).__init__()
        self.conv_head = conv3x3(in_channels, out_channels)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.block=nn.Sequential(
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels),
            ResidualBlock(out_channels,out_channels), 
        )

    def forward(self, x):
        x = self.relu(self.bn(self.conv_head(x)))
        x=self.block(x)
        return x

作为一个“要求甚解”的菜鸟,便开始学习setattr()和getattr()的用法。先看看官方文档里的说法。

setattr(object, name, value)
此函数与 getattr() 两相对应。 其参数为一个对象、一个字符串和一个任意值。 字符串指定一个现有属性或者新增属性。 函数会将值赋给该属性,只要对象允许这种操作。 例如,setattr(x, ‘foobar’, 123) 等价于 x.foobar = 123。

getattr(object, name[, default])
返回对象命名属性的值。name 必须是字符串。如果该字符串是对象的属性之一,则返回该属性的值。例如, getattr(x, ‘foobar’) 等同于 x.foobar。如果指定的属性不存在,且提供了 default 值,则返回它,否则触发 AttributeError。

这时候回看图片上的代码:

for block in range(BLOCKS):
            setattr(self, "res{}".format(block), BasicBlock(outplanes, outplanes))

含义:循环BLOCKS遍(比如BLOCKS=19),“res0”到“res18”都存储“BasicBlock(outplanes, outplanes)”。

for block in range(BLOCKS - 1):
            x = getattr(self, "res{}".format(block))(x)

含义:把x送进“res0”到“res18”存储的“BasicBlock(outplanes, outplanes)”中计算。

以下是笔者自己实现的一个小demo:

class Add():
    def __init__(self):
        super(Add, self).__init__()

    def add(self, a, b):
        out = a+b
        return out


class Result():
    def __init__(self):
        super(Result, self).__init__()

        for i in range(5):
            setattr(self, "res{}".format(i),
                    Add().add)

    def output(self):
        for i in range(5):
            print(getattr(self, "res{}".format(i))(i, i))


Result().output()

以上代码和以下代码等价:

def add(a, b):
    return a+b


for i in range(5):
    print(add(i, i))

Reference
[1] https://dylandjian.github.io/alphago-zero/?tdsourcetag=s_pctim_aiomsg
[2] https://www.cnblogs.com/zanjiahaoge666/p/7475225.html
[3] https://docs.python.org/zh-cn/3/index.html

  • 6
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值