pytorch中参数访问

一次性访问所有参数

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(2, 4)
out = net(X)

print(*[(name, param.shape) for name, param in net[0].named_parameters()])

去除print(),[  ]内的叫做列表推导式,用于遍历net[0](网络中的第一个子模块)的所有参数,并收集每个参数的名称和形状

  • net[0]:通过索引访问网络的第一个子模块。
  • .named_parameters():(字面意思:名字和参数),生成模块中所有参数的名称和参数本身的元组。
  • for name, param in net[0].named_parameters():for循环,遍历上述元组(参数名称, 参数)name 是参数的名称,param 是参数张量。
  • (name, param.shape) :这是for循环的返回值,是一个元组包含参数名称和参数形状
  • * 操作符在这里用于将列表解包,使得列表中的每个元素作为独立的参数传递给 print 函数
  • [('weight', torch.Size([8, 4])), ('bias', torch.Size([8]))] # 不带*
    ('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))  # 带*

    使用状态字典(state_dict()),通过键值来访问具体某个参数的数据,比如:

print(net.state_dict()['2.bias'].data)

输出:tensor([0.0575])

 通过 state_dict() 函数可以获得 net 的关于参数的键值,通过 [2.bias].data访问键的具体值(在这个例子里,1.relu没有参数, 2.bias 指的时nn.Linear(8,1).bias)

  • 9
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值