pytorch的paramter

register_parameter

nn.Parameters 与 register_parameter 都会向 _parameters写入参数,但是后者可以支持字符串命名。
从源码中可以看到,nn.Parameters为Module添加属性的方式也是通过register_parameter向 _parameters写入参数。

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        print('before register:\n', self._parameters, end='\n\n')
        self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
        print('after register and before nn.Parameter:\n', self._parameters, end='\n\n')

        self.my_param2 = nn.Parameter(torch.randn(2, 2))
        print('after register and nn.Parameter:\n', self._parameters, end='\n\n')

    def forward(self, x):
        return x

mymodel = MyModel()

for k, v in mymodel.named_parameters():
    print(k, v)

程序返回为:

before register:
 OrderedDict()

after register and before nn.Parameter:
 OrderedDict([('my_param1', Parameter containing:
tensor([[-1.3542, -0.4591, -2.0968],
        [-0.4345, -0.9904, -0.9329],
        [ 1.4990, -1.7540, -0.4479]], requires_grad=True))])

after register and nn.Parameter:
 OrderedDict([('my_param1', Parameter containing:
tensor([[-1.3542, -0.4591, -2.0968],
        [-0.4345, -0.9904, -0.9329],
        [ 1.4990, -1.7540, -0.4479]], requires_grad=True)), ('my_param2', Parameter containing:
tensor([[ 1.0205, -1.3145],
        [-1.1108,  0.4288]], requires_grad=True))])

my_param1 Parameter containing:
tensor([[-1.3542, -0.4591, -2.0968],
        [-0.4345, -0.9904, -0.9329],
        [ 1.4990, -1.7540, -0.4479]], requires_grad=True)
my_param2 Parameter containing:
tensor([[ 1.0205, -1.3145],
        [-1.1108,  0.4288]], requires_grad=True)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值