pytorch学习12之延后初始化

文章介绍了在PyTorch中如何实现网络的延后初始化,允许在数据首次通过模型时动态确定层的大小。通过示例展示了如何使用nn.Sequential和nn.Linear创建网络,并解释了在不知道输入维度时如何利用nn.LazyLinear函数进行延迟初始化。在模型应用到实际数据后,参数会被自动初始化并打印出相应的权重矩阵形状。
摘要由CSDN通过智能技术生成

回顾之前的学习中,建立网络时:

  • 仅定义了网络架构,但没有指定输入维度。

  • 添加层时没有指定前一层的输出维度。

  • 在初始化参数时,甚至没有足够的信息来确定模型应该包含多少参数。

所以需要在框架的延后初始化(defers initialization)。到数据第一次通过模型传递时,框架动态判断每层大小。

实例化网络

初始化一个模型

import torch
from torch import nn


# 实例化网络
def getnet(in_features, out_features):
    net = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, out_features))
    return net


net = getnet(20, 10)
print(net)
# 使用pytorch制作一个自定义的nn模块,使网络从输入的数据中学习大小。
for name, param in net.named_parameters():
    print(name, '----', param.shape)

输出:

Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

0.weight ---- torch.Size([256, 20])
0.bias ---- torch.Size([256])
2.weight ---- torch.Size([10, 256])
2.bias ---- torch.Size([10])

延后初始化

输入维度,x是2,20,就可以定义第一层的权重矩阵,即W1是R 256 * 20。

进入第二层,将其维度定义为10*256,并通过计算图以此类推,在所有维度可用时将其绑定。

一旦知道这一点,我们就可以通过初始化参数来进行。

延迟初始化会引起混乱。

原因:在第一次正向计算之前,无法直接操作模型参数。例如,无法使用data和set_data函数来获取和修改参数。

解决方法:通过网络发送一个样本观测来强制初始化。

"""延后初始化"""


def init_weights(m):
    print("Init", m)


net.apply(init_weights)  # 每层都循环一下,最后对整个Sequential也进行操作
print('网络结构:\n', net)
print('第一层网络结构:\n', net[0].weight)
print('第二层网络结构:\n', net[2].weight)

x = torch.rand((2, 20))
y = net(x)  # Forward computation
for name, param in net.named_parameters():
    print(name, param.shape)


输出

Init Linear(in_features=20, out_features=256, bias=True)
Init ReLU()
Init Linear(in_features=256, out_features=10, bias=True)
Init Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)
网络结构:
 Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)
第一层网络结构:
 Parameter containing:
tensor([[-0.1813, -0.0647, -0.0027,  ...,  0.0237, -0.1742,  0.0907],
        [-0.1039, -0.1742, -0.0910,  ..., -0.0406, -0.2190, -0.1909],
        [ 0.1341,  0.2200, -0.0211,  ...,  0.0994,  0.0191,  0.0513],
        ...,
        [ 0.1448, -0.1018,  0.1805,  ..., -0.1612,  0.0956,  0.0072],
        [-0.1799,  0.1403, -0.1646,  ..., -0.0846,  0.0458, -0.2082],
        [-0.1763,  0.1748,  0.0277,  ..., -0.0043, -0.1371, -0.1358]],
       requires_grad=True)
第二层网络结构:
 Parameter containing:
tensor([[ 0.0470,  0.0207,  0.0165,  ..., -0.0037, -0.0545,  0.0299],
        [-0.0146,  0.0397,  0.0515,  ..., -0.0487, -0.0226,  0.0076],
        [ 0.0272, -0.0109,  0.0053,  ..., -0.0285, -0.0571, -0.0107],
        ...,
        [ 0.0485,  0.0430,  0.0483,  ..., -0.0312, -0.0467,  0.0132],
        [ 0.0242,  0.0481, -0.0282,  ...,  0.0625, -0.0525,  0.0211],
        [ 0.0292, -0.0370, -0.0082,  ...,  0.0034, -0.0544, -0.0559]],
       requires_grad=True)
0.weight torch.Size([256, 20])
0.bias torch.Size([256])
2.weight torch.Size([10, 256])
2.bias torch.Size([10])

pytorch延迟初始化

使用nn.LazyLinear函数

# PyTorch都是提前将网络输入输出指定的
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
# print(net[0].weight)  # 尚未初始化
print("net", net)
# [net[i].state_dict() for i in range(len(net))]
X = torch.rand(2, 20)
print("net(X)", net(X))
print(net)

输出:

net Sequential(
  (0): LazyLinear(in_features=0, out_features=256, bias=True)
  (1): ReLU()
  (2): LazyLinear(in_features=0, out_features=10, bias=True)
)
net(X) tensor([[ 0.3109, -0.3234,  0.1678, -0.2240, -0.0862, -0.3375,  0.1075, -0.1553,
         -0.0348,  0.0950],
        [ 0.4328, -0.3107,  0.1345, -0.2561, -0.0927, -0.3601,  0.2577, -0.1795,
          0.0985, -0.0108]], grad_fn=<AddmmBackward0>)
Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值