本程序主要用到的函数总结如下：
# 1 Access model parameters
# NOTE(review): the lines in this section are excerpts (the loop headers below
# have no bodies), so this section is a reference list, not runnable code.
for name, param in net.named_parameters():  # iterate (name, parameter) pairs
for param in net.parameters():  # iterate the parameters only, without their names
self.weight1 = nn.Parameter(torch.rand(20,20))  # wrap a 20x20 uniform-[0, 1) tensor as a learnable Parameter; presumably inside an nn.Module subclass, where attribute assignment registers it
# 2 Initialize model parameters
# NOTE(review): excerpts; `param` and `name` presumably come from the
# `for name, param in net.named_parameters():` loop in section 1, and the lines
# after each `if` were its body before indentation was lost.
init.normal_(param, mean = 0, std = 0.01)  # fill param in place with samples from N(0, 0.01^2)
init.constant_(param, val = 0)  # fill param in place with the constant 0
if 'weight' in name:  # only parameters whose name contains 'weight' ...
init.normal_(param, mean = 0, std = 0.01)
if 'bias' in name:  # ... and only parameters whose name contains 'bias'
init.constant_(param, val = 0)
# 3 Custom initialization method
def normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with samples from N(mean, std**2) and return it.

    The fill runs under ``torch.no_grad()`` so it is not recorded by autograd
    and also works on leaf tensors with ``requires_grad=True`` (for example an
    ``nn.Parameter``), where an in-place op would otherwise raise.

    Args:
        tensor: the tensor to overwrite in place.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.

    Returns:
        The same ``tensor`` object, now filled with the samples.
    """
    with torch.no_grad():
        # The in-place sampler is a Tensor method; the previous
        # ``torch.normal_(mean, std)`` never touched ``tensor`` at all.
        return tensor.normal_(mean, std)
# 4 Share model parameters
# NOTE(review): these checks assume `net` was built by repeating one module
# instance (e.g. nn.Sequential(linear, linear)). With the `net` built at the end
# of this file, net[0] is an nn.Linear and net[1] is an nn.ReLU, so the first
# line prints False and `net[1].weight` raises AttributeError.
print(id(net[0]) == id(net[1]))  # True only if both positions hold the very same module object
print(id(net[0].weight)== id(net[1].weight))  # True only if both modules hold the very same weight Parameter
模型参数的访问、初始化和共享
import torch
from torch import nn
from torch.nn import init
# Small MLP: 4 inputs -> Linear -> 3 hidden units -> ReLU -> Linear -> 1 output.
net = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 1),
)