"""
参数管理
首先关注具有隐藏层的多层感知机
"""
import torch
from torch import nn
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
print(net(X))
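"""
For reference, printing the model shows the Sequential structure and the
integer index assigned to each sub-module, which is what the net[2] indexing
below relies on (a quick check).
"""
print(net)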
"""
网络第二个块的参数
"""
print(net[2].state_dict())
"""
偏移的类型
"""
print(type(net[2].bias))
"""
偏移
"""
print(type(net[2].bias))
"""
偏移的数据
"""
print(net[2].bias.data)
print(net[2].weight.grad is None)  # True: no backward pass has run yet
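"""
A small sketch: the gradient is only populated once a backward pass runs. The
scalar loss below (the sum of the outputs) is just an illustrative choice.
"""
net(X).sum().backward()
print(net[2].weight.grad is None)  # now False
net.zero_grad()  # reset gradients; the rest of the script does not use them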
"""
一次访问网络所有参数
"""
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])
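"""
A single parameter can also be looked up by its qualified name in the state
dict; the key '2.bias' below matches the names printed by named_parameters().
"""
print(net.state_dict()['2.bias'].data)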
"""
从嵌套块收集参数
"""
def block1():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'block{i}', block1())
    return net
rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
print(rgnet(X))
print(rgnet)
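"""
Because the blocks are nested, a single parameter can be reached by chained
indexing; for example, the bias of the first layer inside the second sub-block
of block2 (a quick sketch).
"""
print(rgnet[0][1][0].bias.data)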
"""
内置初始化
"""
def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)

net.apply(init_normal)
print(net[0].weight.data[0], net[0].bias.data[0])
def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 1)
        nn.init.zeros_(m.bias)

net.apply(init_constant)
print(net[0].weight.data[0], net[0].bias.data[0])
"""
对某些块应用不同的初始化方法
"""
def xavier(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

def init_42(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 42)
net[0].apply(xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
"""
自定义初始化
"""
def my_init(m):
    if type(m) == nn.Linear:
        print(
            "Init",
            *[(name, param.shape) for name, param in m.named_parameters()][0]
        )
        nn.init.uniform_(m.weight, -10, 10)
        # Keep only weights with magnitude >= 5; zero out the rest
        m.weight.data *= m.weight.data.abs() >= 5
print(net[0].apply(my_init))
print(net[0].weight[:2])
"""
更简单更暴力的模型初始化
"""
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
print(net[0].weight.data[0])
"""
参数绑定
"""
# Define the shared layer once so the same parameters can be reused below
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(), shared,
                    nn.ReLU(), nn.Linear(8, 1))
print(net(X))
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
print(net[2].weight.data[0] == net[4].weight.data[0])
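"""
The two occurrences of `shared` are the very same Module object and therefore
hold one parameter tensor; during backpropagation the gradients from both
positions are accumulated into that single tensor. A quick identity check:
"""
print(net[2] is net[4])                # True: same module object
print(net[2].weight is net[4].weight)  # True: same parameter tensor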