setter方法和getter方法(python库和pytorch库)

python库中的setter方法和getter方法

setattr方法

setattr 是 Python 中的一个内置函数,用于设置对象的属性值。它的作用是给对象的指定属性设置一个新的值,如果该属性不存在则创建该属性。

setattr(object, name, value) 函数接受三个参数:

  • object:要设置属性值的对象。

  • name:要设置的属性名。

  • value:要设置的属性值。

setattr 函数会将 object 对象中名为 name 的属性的值设置为 value,如果 name 不存在,它会在 object 中创建一个名为 name 的属性,并将其值设置为 value

举个例子,假设有一个类 Person,它有一个名为 name 的属性,可以通过 setattr 函数来设置 name 属性的值:


class Person:
    def __init__(self, name):
        self.name = name

person = Person('Tom')
setattr(person, 'name', 'Jerry')
print(person.name)  # 输出:'Jerry'

结果


Jerry

上述代码中,首先创建了一个 Person 对象 person,然后使用 setattr 函数将 person 对象的 name 属性的值设置为 'Jerry',最后输出 person 对象的 name 属性的值,结果为 'Jerry'

另外,也可以使用 setattr 函数动态地为对象添加新的属性,例如:


class Person:
    pass

person = Person()
setattr(person, 'name', 'Tom')
print(person.name)  # 输出:'Tom'

结果


Tom

上述代码中,首先创建了一个空的 Person 对象 person,然后使用 setattr 函数为 person 对象添加一个名为 name 的属性,并将其值设置为 'Tom',最后输出 person 对象的 name 属性的值,结果为 'Tom'

getter方法

getattr 是 Python 中的一个内置函数,用于获取对象的属性值。它的作用是从一个对象中获取指定名称的属性值,如果该属性不存在,则会抛出一个 AttributeError 异常。

getattr(object, name[, default]) 函数接受三个参数:

  • object:要获取属性值的对象。

  • name:要获取的属性名。

  • default:可选参数,如果指定的属性不存在,则返回 default 指定的默认值,否则抛出 AttributeError 异常。如果不指定 default 参数,则抛出 AttributeError 异常。

getattr 函数会从 object 对象中获取名为 name 的属性的值,如果 name 不存在,它会根据 default 参数的值决定是否抛出异常或返回默认值。

举个例子,假设有一个类 Person,它有一个名为 name 的属性,可以通过 getattr 函数来获取 name 属性的值:


class Person:
    def __init__(self, name):
        self.name = name

person = Person('Tom')
name = getattr(person, 'name')
print(name)  # 输出:'Tom'

结果


Tom

上述代码中,首先创建了一个 Person 对象 person,然后使用 getattr 函数获取 person 对象的 name 属性的值,并将其赋值给变量 name,最后输出 name 的值,结果为 'Tom'

如果尝试获取一个对象中不存在的属性,getattr 函数会抛出一个 AttributeError 异常,例如:


class Person:
    def __init__(self, name):
        self.name = name

person = Person('Tom')
age = getattr(person, 'age')

结果:


Traceback (most recent call last):
  File "D:/code/python_code/TSM_pra/main.py", line 6, in <module>
    age = getattr(person, 'age')
AttributeError: 'Person' object has no attribute 'age'

上述代码中,尝试获取 person 对象的 age 属性的值,但是 Person 类中并没有定义 age 属性,因此 getattr 函数会抛出一个 AttributeError 异常。

pytorch库中的setter方法和getter方法

setter方法

torch.nn.Module 是 PyTorch 框架中一个重要的类,表示神经网络中的模块。在 torch.nn.Module 类中,有一个 setattr 方法,其作用与 Python 内置的 setattr 函数类似,用于设置模块中的属性值。

setattr 方法有两个参数:

  • name:属性名。

  • value:属性值。

setattr 方法将 name 属性的值设置为 value。需要注意的是,在 torch.nn.Module 中,调用 setattr 方法设置的属性值会被 PyTorch 自动跟踪并记录在模型中,从而支持模型参数的自动求导。

举个例子,假设有一个自定义的神经网络模块 MyNet,其中有一个名为 fc 的全连接层。我们可以使用 setattr 方法将 fc 层的权重参数和偏置参数分别设置为 weightbias


import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = nn.Linear(10, 5)

net = MyNet()
setattr(net.fc, 'weight', torch.randn(5, 10))
setattr(net.fc, 'bias', torch.randn(5))

上述代码中,我们首先定义了一个名为 MyNet 的自定义模块,其中包含一个名为 fc 的全连接层。然后,我们创建了一个 MyNet 对象 net,并使用 setattr 方法设置 fc 层的权重参数和偏置参数,其中 weight 的形状为 (5, 10)bias 的形状为 (5,)

需要注意的是,在实际使用中,我们通常使用 nn.Module 类的 __setattr__ 方法来设置模块中的属性值,而不是直接调用 setattr 方法。这是因为 __setattr__ 方法会调用 PyTorch 内置的 _parameters_buffers 等属性来跟踪模型参数,从而保证参数的自动求导。例如:


import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        x = self.fc(x)
        return x

net = MyNet()
net.fc.weight = nn.Parameter(torch.randn(5, 10))
net.fc.bias = nn.Parameter(torch.randn(5))

上述代码中,我们使用 PyTorch 内置的 nn.Parameter 类来将权重参数和偏置参数转换为可训练的参数,然后直接对 fc 层的 weightbias 属性赋值即可。这种方式比直接调用 setattr 方法更为方便和安全。

getter方法

torch.nn.Module 是 PyTorch 框架中一个重要的类,表示神经网络中的模块。在 torch.nn.Module 类中,有一个 getattr 方法,其作用与 Python 内置的 getattr 函数类似,用于获取模块中的属性值。

getattr 方法有一个参数:

  • name:属性名。

getattr 方法返回 name 属性的值。需要注意的是,在 torch.nn.Module 中,调用 getattr 方法获取的属性值可能会是一个子模块,因为在神经网络中通常会嵌套使用多个子模块。

举个例子,假设有一个自定义的神经网络模块 MyNet,其中有一个名为 fc 的全连接层。我们可以使用 getattr 方法获取 fc 层的权重参数和偏置参数:


import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = nn.Linear(10, 5)

net = MyNet()
weight = getattr(net.fc, 'weight')
bias = getattr(net.fc, 'bias')

上述代码中,我们首先定义了一个名为 MyNet 的自定义模块,其中包含一个名为 fc 的全连接层。然后,我们创建了一个 MyNet 对象 net,并使用 getattr 方法获取 fc 层的权重参数和偏置参数,分别赋值给 weightbias 变量。

需要注意的是,如果 getattr 方法获取的属性值是一个子模块,我们可以通过调用子模块的 parameters 方法获取其所有参数,或者通过调用子模块的 children 方法获取其所有子模块。例如:


import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

net = MyNet()
fc1_params = list(getattr(net, 'fc1').parameters())
fc2_children = list(getattr(net, 'fc2').children())

上述代码中,我们定义了一个名为 MyNet 的自定义模块,其中包含两个全连接层 fc1fc2。然后,我们创建了一个 MyNet 对象 net,并使用 getattr 方法获取 fc1 层和 fc2 层,分别调用它们的 parameterschildren 方法获取其所有参数和子模块。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值