来自于 https://tangshusen.me/Dive-into-DL-PyTorch/#/
官方文档 https://pytorch.org/docs/stable/tensors.html
文章目录
模型构建
Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型。下面继承Module类构造一个多层感知机(输入784维,隐藏层256个单元,输出10维)。这里定义的MLP类重写(override)了Module类的__init__函数和forward函数,它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。
import torch
from torch import nn
class MLP(nn.Module):
    """Two-layer perceptron: 784 -> 256 (ReLU) -> 10.

    ``__init__`` declares the layers that own model parameters and ``forward``
    describes how the model output is computed from an input ``x``.
    """

    def __init__(self, **kwargs):
        # Let nn.Module perform its own set-up first; any extra keyword
        # arguments are forwarded to it unchanged.
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        """Map ``x`` through hidden layer, ReLU, then the output layer."""
        hidden_out = self.hidden(x)
        activated = self.act(hidden_out)
        return self.output(activated)
实例化MLP
# Random input: a batch of 2 samples with 784 features each.
x = torch.rand(2,784)
net = MLP()
print(net)
# Calling the module runs MLP.forward on x (a notebook displays the result).
net(x)
MLP(
(hidden): Linear(in_features=784, out_features=256, bias=True)
(act): ReLU()
(output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[-0.0370, -0.0794, 0.1573, -0.0460, -0.0318, -0.0653, -0.2214, 0.0340,
0.0324, 0.1943],
[-0.1231, 0.0116, 0.2601, -0.0035, 0.0520, -0.0406, -0.0722, 0.0060,
-0.1080, 0.3512]], grad_fn=<AddmmBackward>)
Module的子类
Sequential
当模型的前向计算为简单串联各个层的计算时,Sequential类可以通过更加简单的方式定义模型。这正是Sequential类的目的:它可以接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算。
下面我们实现一个与Sequential类有相同功能的MySequential类。这或许可以帮助读者更加清晰地理解Sequential类的工作机制。
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential`` for illustration.

    Accepts either a single ``OrderedDict`` mapping names to modules, or any
    number of modules passed positionally (registered as ``"0"``, ``"1"``, ...).
    ``forward`` feeds its input through the registered sub-modules in
    insertion order.
    """

    def __init__(self, *args):
        # Imported here rather than in the class body: names bound in a class
        # body are not visible inside its methods, so a class-level import
        # would make ``OrderedDict`` a NameError in this method.
        from collections import OrderedDict

        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            # ``items()`` yields the (name, module) pairs of the mapping.
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        """Apply every registered sub-module to ``input`` in insertion order."""
        # Iterate ``_modules`` directly (not ``children()``, which skips
        # repeated instances) so a module added twice is applied twice.
        for module in self._modules.values():
            input = module(input)
        return input
# Build the same 784-256-10 MLP with the hand-written MySequential.
net = MySequential(
    nn.Linear(784,256),
    nn.ReLU(),
    nn.Linear(256,10)
)
print(net)
net(x)
MySequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[ 0.1224, -0.0237, -0.0957, 0.0353, -0.0509, -0.1116, -0.0750, -0.0186,
-0.0620, 0.1194],
[ 0.1232, -0.0997, -0.1981, -0.0289, 0.0448, 0.0705, -0.1342, 0.0394,
-0.1696, 0.1937]], grad_fn=<AddmmBackward>)
ModuleList
ModuleList接收一个子模块的列表作为输入,然后也可以类似List那样进行append和extend操作:
# A ModuleList is built from a list of modules and supports list-style append.
net = nn.ModuleList([nn.Linear(784,256),nn.ReLU()])
net.append(nn.Linear(256,10))
print(net[-1])  # list-style indexing: the layer appended last
print(net)
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
既然Sequential和ModuleList都可以进行列表化构造网络,那二者区别是什么呢?ModuleList仅仅是一个储存各种模块的列表,虽然可以按索引访问,但这些模块之间没有连接关系,也不隐含执行顺序(所以不用保证相邻层的输入输出维度匹配);而且它没有实现forward功能,需要自己实现,所以如果对上面的net执行net(torch.zeros(1, 784))会报NotImplementedError。而Sequential内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部forward功能已经实现。
class MyModule(nn.Module):
    """Ten 10->10 linear layers stored in an ``nn.ModuleList``.

    ``forward`` shows that a ModuleList can be both iterated and indexed.
    """

    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints.
        for idx, layer in enumerate(self.linears):
            partner = self.linears[idx // 2]
            x = partner(x) + layer(x)
        return x
另外,ModuleList不同于一般的Python的list,加入到ModuleList里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。
class Module_ModuleList(nn.Module):
    """Holds one Linear(10, 10) inside an ``nn.ModuleList``.

    Because the list is an ``nn.ModuleList``, the Linear's weight and bias are
    registered and show up in ``parameters()``.
    """

    def __init__(self):
        super(Module_ModuleList, self).__init__()
        # Attribute named ``linears`` to match Module_List / MyModule.
        self.linears = nn.ModuleList([nn.Linear(10, 10)])
class Module_List(nn.Module):
    """Counter-example: a Linear kept in a plain Python list.

    A plain list is not tracked by ``nn.Module``, so the layer's weight and
    bias do not appear in ``parameters()``.
    """

    def __init__(self):
        super(Module_List, self).__init__()
        layers = [nn.Linear(10, 10)]
        self.linears = layers
# Compare which parameters each module registers.
net1 = Module_ModuleList()
net2 = Module_List()
print('net1:')
for p in net1.parameters():
    print(p.size())
print('net2:')
# Prints nothing: the plain-list layer is not registered.
for p in net2.parameters():
    print(p.size())
net1:
torch.Size([10, 10])
torch.Size([10])
net2:
ModuleDict
ModuleDict接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:
# A ModuleDict is built from a dict of modules and supports dict-style access.
net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # add an entry
print(net['linear']) # access by key
print(net.output)  # entries are also reachable as attributes
print(net)
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
(act): ReLU()
(linear): Linear(in_features=784, out_features=256, bias=True)
(output): Linear(in_features=256, out_features=10, bias=True)
)
和ModuleList一样,ModuleDict实例仅仅是存放了一些模块的字典,并没有定义forward函数,需要自己定义。同样,ModuleDict也与Python的dict有所不同,ModuleDict里的所有模块的参数会被自动添加到整个网络中。
构造复杂的模型
虽然上面介绍的这些类可以使模型构造更加简单(其中Sequential还不需要自己定义forward函数),但直接继承Module类可以极大地拓展模型构造的灵活性。下面我们构造一个稍微复杂点的网络FancyMLP。在这个网络中,我们用torch.rand创建一个requires_grad=False的张量rand_weight,作为训练中不被迭代的参数,即常数参数(它不是Parameter,因此不会出现在模型的参数列表中)。在前向计算中,除了使用创建的常数参数外,我们还使用Tensor的函数和Python的控制流,并多次调用相同的层。
class FancyMLP(nn.Module):
    """MLP mixing a constant weight, Python control flow and a reused layer.

    ``rand_weight`` is a plain tensor created with ``requires_grad=False``
    (not an ``nn.Parameter``), so it is not trained and is not listed in
    ``parameters()``.
    """

    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)
        self.rand_weight = torch.rand((20, 20), requires_grad=False)
        self.linear = nn.Linear(20, 20)

    def forward(self, x):
        """Return the scalar sum of the rescaled activations."""
        h = self.linear(x)
        # Multiply by the constant weight, add 1, then apply ReLU.
        h = nn.functional.relu(torch.mm(h, self.rand_weight.data) + 1)
        # Reuse the same Linear: equivalent to two layers sharing parameters.
        h = self.linear(h)
        # Data-dependent control flow: halve until the norm is at most 1 ...
        while h.norm().item() > 1:
            h /= 2
        # ... then scale up by 10 if the norm ended below 0.8.
        if h.norm().item() < 0.8:
            h *= 10
        return h.sum()
# Batch of 2 random 20-feature inputs for FancyMLP.
X = torch.rand(2, 20)
net = FancyMLP()
print(net)
net(X)
FancyMLP(
(linear): Linear(in_features=20, out_features=20, bias=True)
)
tensor(-0.4913, grad_fn=<SumBackward0>)
因为FancyMLP和Sequential类都是Module类的子类,所以我们可以嵌套调用它们。
class NestMLP(nn.Module):
    """Module that wraps an ``nn.Sequential`` of Linear(40, 30) and ReLU."""

    def __init__(self, **kwargs):
        super(NestMLP, self).__init__(**kwargs)
        layers = [nn.Linear(40, 30), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Delegate to the wrapped Sequential."""
        return self.net(x)
# Module subclasses can be nested freely: NestMLP -> Linear(30, 20) -> FancyMLP.
net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
X = torch.rand(2, 40)
print(net)
net(X)
Sequential(
(0): NestMLP(
(net): Sequential(
(0): Linear(in_features=40, out_features=30, bias=True)
(1): ReLU()
)
)
(1): Linear(in_features=30, out_features=20, bias=True)
(2): FancyMLP(
(linear): Linear(in_features=20, out_features=20, bias=True)
)
)
tensor(0.3448, grad_fn=<SumBackward0>)
访问模型参数
# named_parameters() returns a generator of (name, Parameter) pairs.
print(type(net.named_parameters()))
for name, param in net.named_parameters():
    print(name, param.size())
<class 'generator'>
0.net.0.weight torch.Size([30, 40])
0.net.0.bias torch.Size([30])
1.weight torch.Size([20, 30])
1.bias torch.Size([20])
2.linear.weight torch.Size([20, 20])
2.linear.bias torch.Size([20])
可见返回的名字自动加上了层数的索引作为前缀。我们再来访问net中单层的参数。对于使用Sequential类构造的神经网络,我们可以通过方括号[]来访问网络的任一层。索引0表示Sequential实例最先添加的层,这里即NestMLP实例。
# Parameters of the first sub-module only; names are relative to net[0].
for name, param in net[0].named_parameters():
    print(name, param.size(), type(param))
net.0.weight torch.Size([30, 40]) <class 'torch.nn.parameter.Parameter'>
net.0.bias torch.Size([30]) <class 'torch.nn.parameter.Parameter'>
返回的param的类型为torch.nn.parameter.Parameter,其实这是Tensor的子类,和Tensor不同的是如果一个Tensor是Parameter,那么它会自动被添加到模型的参数列表里,来看下面这个例子。
class MyModel(nn.Module):
    """Shows that only ``nn.Parameter`` attributes are registered as parameters."""

    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        # Registered: an nn.Parameter assigned as an attribute.
        self.weight1 = nn.Parameter(torch.rand(20, 20))
        # Not registered: an ordinary tensor attribute.
        self.weight2 = torch.rand(20, 20)

    def forward(self, x):
        # Placeholder; this model only exists to inspect its parameters.
        return None
n = MyModel()
# Only weight1 is listed; the plain tensor weight2 is not a parameter.
for name,param in n.named_parameters():
    print(name)
weight1
因为Parameter是Tensor,即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度。
# First parameter of net[0] (the NestMLP): the weight of its inner Linear(40, 30).
weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)  # None: no backward pass has run yet
Y=net(X)
Y.backward()
# After backward the gradient has been populated.
print(weight_0.grad)
tensor([[ 0.0622, 0.0654, 0.1042, ..., -0.1514, 0.0846, 0.1257],
[ 0.0139, 0.0422, -0.1171, ..., 0.0142, -0.0565, -0.1016],
[ 0.0405, 0.1393, 0.0782, ..., 0.0151, -0.0972, -0.1105],
...,
[ 0.1332, 0.0998, -0.0605, ..., -0.0061, 0.0571, -0.1063],
[-0.1366, -0.1384, 0.0207, ..., -0.0695, 0.0810, -0.0280],
[ 0.0659, 0.1505, 0.1109, ..., -0.0085, -0.0435, -0.0354]])
None
tensor([[-0.0022, -0.0218, -0.0053, ..., -0.0030, -0.0048, -0.0103],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0864, 0.3693, 0.1030, ..., 0.1747, 0.1237, 0.4375],
...,
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0101, -0.0433, -0.0121, ..., -0.0205, -0.0145, -0.0514]])
初始化模型参数
from torch.nn import init
# Re-initialise every weight with normal samples (mean 0, std 0.01); biases untouched.
for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)
0.net.0.weight tensor([[ 0.0216, -0.0093, 0.0016, ..., 0.0019, 0.0130, -0.0026],
[ 0.0042, -0.0017, 0.0056, ..., -0.0125, -0.0168, -0.0058],
[-0.0046, -0.0081, 0.0077, ..., 0.0097, -0.0123, -0.0045],
...,
[-0.0062, -0.0014, -0.0027, ..., 0.0026, 0.0015, -0.0152],
[-0.0068, -0.0066, -0.0049, ..., 0.0021, 0.0121, -0.0018],
[ 0.0132, 0.0139, 0.0135, ..., 0.0050, 0.0105, -0.0064]])
1.weight tensor([[ 2.0540e-02, 1.3424e-02, -1.4864e-02, -1.3467e-02, -9.4628e-03,
9.7067e-04, -1.5009e-02, -1.5539e-02, -7.9077e-03, 1.0551e-02,
-9.0278e-03, 1.7459e-02, -1.8690e-03, 5.0541e-03, -1.9378e-03,
1.6113e-02, 2.4179e-03, -3.9226e-03, -2.4951e-03, -1.0519e-02,
-1.7979e-03, -3.0638e-03, 5.2334e-04, 1.1851e-02, -1.5307e-02,
-2.3555e-02, -1.2260e-02, -8.4837e-03, 1.4811e-03, -8.6963e-03],
[-7.8196e-03, 1.4540e-02, 2.6614e-03, -1.0711e-02, 1.5188e-02,
3.4464e-04, -1.6673e-02, 1.4700e-02, -1.1430e-02, -3.0225e-03,
1.7065e-03, -1.7021e-02, 1.0656e-02, 1.1011e-02, 4.2047e-03,
-1.2465e-02, -6.2476e-03, -1.0956e-02, 3.1987e-03, -4.9426e-03,
-1.2004e-03, 4.4904e-03, -1.0595e-02, -1.1566e-02, 1.2842e-02,
2.6868e-03, 4.0029e-03, 1.3305e-02, 4.9962e-03, 3.2664e-03],
[-3.7490e-03, -9.1373e-03, -7.6418e-03, 1.1468e-02, -7.3277e-03,
1.0504e-02, -1.2510e-02, 2.0066e-02, -3.8254e-03, -1.3952e-03,
-8.9643e-03, -7.9015e-03, 3.0163e-03, 4.8269e-04, -1.6445e-02,
-4.2277e-03, -1.2489e-02, 7.0872e-03, -4.0340e-04, -6.6422e-03,
2.0632e-03, -4.4652e-03, -2.5231e-03, -1.1253e-02, 1.6601e-02,
-4.5058e-03, 4.3239e-03, 1.0291e-02, 2.0269e-02, -1.0369e-02],
[-1.3724e-02, -8.1422e-04, 1.7627e-03, -5.2986e-03, -3.4685e-03,
-2.0445e-03, -1.1192e-02, 1.0989e-02, -1.1012e-03, -1.7664e-02,
-3.6856e-03, 9.6825e-03, 1.5568e-03, 8.0502e-03, 5.8132e-04,
-9.6213e-03, -7.0347e-04, 5.8105e-03, 2.8272e-04, 1.1797e-02,
-7.6907e-03, 2.2619e-02, 1.2308e-02, 1.8024e-02, 1.7888e-03,
-1.7745e-03, 7.1993e-03, -3.5203e-04, -1.9954e-02, -1.8694e-02],
[ 7.0645e-03, 8.1665e-04, 1.4387e-02, 1.2969e-02, 1.8918e-03,
-5.8327e-03, -1.0239e-02, 8.8525e-03, -5.7654e-03, 1.7141e-03,
-2.8454e-03, 1.3615e-02, -1.0575e-02, -3.0326e-03, 6.0094e-03,
-1.0895e-02, 2.0619e-03, 1.5329e-03, 1.4229e-03, -3.2453e-03,
-1.7779e-03, 9.6478e-03, -2.3171e-02, 1.4701e-02, 1.9202e-02,
9.1253e-03, -6.1181e-03, -1.0587e-03, 2.0379e-04, 1.8809e-03],
[ 1.2714e-02, 3.9214e-03, 5.0589e-04, 6.2322e-03, 1.3992e-02,
1.4027e-03, -4.3150e-03, 1.0659e-03, -7.9712e-04, -1.9232e-02,
-1.0893e-02, 2.5453e-02, -2.4113e-03, -8.1766e-03, -5.1438e-03,
-1.5338e-02, -3.9863e-03, 5.6107e-03, -1.1374e-03, 5.2210e-03,
-1.6869e-02, -5.0722e-03, -6.0378e-03, 3.9245e-05, -1.0497e-02,
2.8167e-03, 1.5779e-02, -2.7614e-03, -2.5383e-03, -3.9343e-03],
[ 6.0877e-03, 8.4987e-03, 1.3891e-02, -1.4258e-03, -8.7571e-03,
-4.2930e-03, -9.5155e-04, -1.2515e-03, -2.7573e-03, -1.8363e-03,
3.6556e-03, 1.8146e-03, -1.5488e-02, 4.5685e-04, -1.0984e-02,
-3.8873e-03, -2.3433e-03, -8.7609e-03, 4.1699e-03, 9.4601e-03,
3.1651e-04, 8.2427e-03, 2.3207e-03, 4.8457e-03, -1.5066e-02,
6.2566e-03, 1.7090e-02, 9.2237e-03, 3.8094e-03, -3.4015e-03],
[ 1.0200e-02, 1.2802e-02, 6.2122e-03, -1.1477e-02, 2.2691e-03,
-4.6306e-04, 2.8861e-06, 9.0622e-03, -2.6586e-02, -7.3941e-04,
6.5375e-07, -1.5064e-03, 2.4399e-02, -1.4029e-03, -4.1262e-03,
1.2517e-02, -6.1972e-03, -8.5488e-03, -5.8397e-03, 1.9754e-03,
4.5322e-03, 4.9351e-03, -8.4586e-03, -1.0668e-02, -4.0087e-04,
7.3720e-03, -6.5349e-03, -1.7660e-03, -9.3602e-03, 1.5721e-03],
[ 1.6188e-02, -2.0150e-03, -1.6816e-02, -1.6911e-02, 6.5990e-03,
-1.6605e-02, -9.2304e-03, -8.1321e-03, 2.1806e-02, -7.1569e-03,
-1.0543e-02, -6.5349e-03, 2.9192e-03, -1.5003e-02, -6.7353e-03,
3.0836e-03, -1.0745e-04, -6.8194e-03, 1.0827e-02, -1.0363e-02,
-1.2346e-02, -7.0676e-03, 1.2861e-02, -6.6271e-03, -5.7568e-03,
3.4977e-03, -5.9773e-03, 3.5096e-03, 1.0743e-02, 8.9650e-03],
[ 2.9357e-03, 7.0599e-03, -8.3739e-03, -3.4416e-03, 2.7688e-02,
-1.0271e-02, -1.3291e-02, -9.5348e-03, 9.4741e-03, 5.1175e-03,
9.3990e-03, -2.4245e-03, -7.4773e-03, -1.8540e-03, -1.7982e-02,
-9.4906e-03, 1.2980e-02, 1.3263e-02, 7.2149e-04, 2.5788e-04,
-2.8871e-04, -1.0207e-02, 4.8015e-03, -2.5008e-03, -1.0212e-02,
-1.6521e-02, -1.0582e-02, 2.5365e-03, 2.0865e-03, 4.9698e-03],
[ 4.8506e-03, 1.2074e-02, -1.1068e-02, 6.3328e-03, -7.2685e-03,
8.6962e-03, 8.3850e-03, 3.9330e-04, -1.4435e-02, 2.6184e-03,
-9.1552e-03, -9.1968e-03, -7.5484e-03, 1.7064e-02, -1.0926e-02,
5.8421e-03, 3.2155e-03, -1.7801e-02, -7.1032e-03, 1.5225e-02,
4.3300e-03, -1.1750e-02, 2.3955e-03, 5.8431e-03, -3.3767e-03,
8.4297e-03, 5.9677e-03, -5.3130e-03, -6.2581e-03, 7.2684e-04],
[ 1.0966e-03, 5.1526e-03, 1.2423e-02, 2.7335e-04, -2.2481e-03,
-6.6638e-03, 1.0812e-02, -5.5681e-03, -3.6775e-03, 1.7190e-02,
-8.9243e-03, 1.0826e-02, 2.1140e-02, -3.1026e-03, -7.8873e-03,
1.1538e-02, -1.9044e-03, 1.0823e-02, 9.9094e-03, 2.5864e-03,
-1.1288e-02, 7.8564e-03, 1.6430e-02, -5.0827e-03, -1.3557e-03,
1.5250e-02, -7.7385e-03, -1.2449e-02, 1.0931e-02, 4.5603e-03],
[ 1.0957e-02, -5.0123e-03, -9.2494e-03, -5.7311e-03, 1.0838e-02,
3.2123e-03, -2.0542e-03, -7.1564e-03, -1.0734e-03, 1.9866e-02,
1.6477e-02, -5.1313e-04, 3.8111e-03, -7.6825e-03, -6.2488e-03,
5.7776e-03, 6.2000e-04, -1.7886e-02, -1.0113e-02, 1.1926e-03,
-1.7908e-02, -9.1477e-03, -4.8008e-03, -4.7199e-03, 1.5774e-02,
2.0051e-02, -1.7119e-02, 3.6467e-03, -2.3746e-03, -3.1700e-03],
[-6.5938e-03, -8.9639e-03, 7.4318e-03, -7.2729e-03, -7.1252e-03,
-2.4921e-03, 3.9307e-03, 3.5104e-03, 1.1045e-02, -5.6922e-03,
4.3981e-03, 1.8079e-02, -4.0234e-03, -9.3731e-03, 2.8033e-03,
1.2702e-02, -6.7108e-03, 1.1401e-03, 4.4380e-03, 3.1948e-02,
-5.5021e-03, 1.4521e-03, -1.3206e-02, -2.4222e-03, 4.7926e-04,
-1.1533e-02, -5.7481e-03, 1.2893e-02, 1.1545e-02, 5.2813e-03],
[ 1.5414e-02, -6.9193e-03, 1.6852e-02, -1.0622e-02, 1.4744e-02,
1.0154e-03, 1.3041e-02, 1.7373e-02, -3.9367e-03, 8.8151e-03,
-1.1236e-02, -1.9134e-02, -2.9457e-04, -1.1307e-02, 5.1315e-04,
1.5109e-02, 9.7158e-03, -2.8689e-03, 2.3455e-03, 1.0067e-02,
-1.0586e-02, -8.3479e-03, 1.1931e-02, 8.1437e-04, -7.5972e-03,
-1.2458e-03, 6.7771e-03, 1.0820e-02, 1.9163e-02, 1.5708e-02],
[-9.7202e-03, 8.6719e-04, 2.0114e-02, -1.7551e-03, 1.0009e-02,
2.3814e-06, -4.0599e-03, -1.6323e-03, -5.8922e-04, -1.9179e-02,
-6.1471e-03, 3.2786e-03, -8.2910e-03, 1.5821e-02, -6.7673e-04,
6.7152e-03, -9.1754e-03, -1.1328e-02, -2.6055e-03, 5.0403e-03,
1.2042e-03, -3.8331e-03, -1.4951e-02, -1.3479e-02, 1.1225e-03,
8.5821e-04, -6.5184e-03, -1.4773e-02, -5.9137e-03, 3.5191e-04],
[ 1.8287e-02, 6.3982e-03, 9.5072e-05, -8.5228e-03, 3.8553e-03,
-1.4717e-02, -2.2549e-02, -9.0621e-04, -1.2570e-02, 1.1718e-02,
-6.6394e-03, 4.0225e-03, 1.1599e-02, -1.2683e-02, -1.6592e-02,
6.4644e-03, -3.6911e-04, -1.5643e-02, 5.2307e-03, -6.2249e-03,
-7.1715e-03, -2.3126e-03, -8.4576e-04, 3.8474e-03, 1.0998e-02,
-4.6572e-03, 5.6035e-03, 3.4481e-02, 1.2175e-02, 5.7834e-03],
[ 8.1615e-03, 1.1394e-02, 6.7657e-03, 6.2123e-03, -2.5495e-02,
8.0885e-03, -1.4886e-02, 9.0453e-03, 7.6994e-04, 3.8311e-03,
3.3331e-03, -1.1154e-03, -1.5299e-03, -1.0464e-04, -1.7200e-02,
1.7913e-03, 6.0317e-03, -9.6406e-03, -3.2287e-03, -8.2056e-03,
-3.6062e-03, 1.4878e-02, 2.8893e-04, 5.8158e-03, 2.6536e-03,
-2.4103e-02, -1.4993e-02, -1.4791e-03, -5.8795e-06, 3.5734e-03],
[-1.9504e-02, -5.2369e-03, 4.3967e-03, -8.9265e-03, -5.5365e-03,
-3.4119e-03, -9.1115e-03, 1.0674e-02, 2.1772e-03, 1.2874e-02,
4.4263e-03, 7.4545e-03, -2.3560e-03, 2.8753e-03, -1.1261e-02,
2.3133e-02, 4.8261e-03, 8.4998e-03, -9.4744e-03, 2.2099e-02,
-4.8882e-03, -5.5454e-03, 1.4507e-02, -6.0991e-03, 1.9609e-02,
1.0477e-03, -1.3490e-02, 1.1814e-03, -8.1444e-03, 1.8736e-02],
[ 2.4450e-02, -1.7491e-03, -5.5336e-03, 3.9693e-03, -9.6735e-04,
3.7527e-03, 6.2744e-03, -2.6313e-03, 1.5444e-02, -4.2243e-03,
-8.7079e-03, 2.0071e-02, 1.2334e-03, -8.3905e-03, -1.6589e-02,
3.5387e-03, -3.9648e-03, 8.2130e-04, 1.4307e-02, -4.3372e-03,
3.1793e-04, -7.4976e-03, -7.1900e-03, 1.2343e-03, -6.7502e-03,
-1.1686e-02, 7.9501e-03, -2.0728e-02, 5.3000e-03, 2.5626e-02]])
2.linear.weight tensor([[-5.5996e-03, -1.7881e-02, 2.9029e-03, -3.8814e-04, 4.6640e-03,
1.4729e-03, -8.7504e-04, 1.7051e-03, -9.9808e-03, -5.0457e-03,
1.1433e-02, 2.6712e-03, 6.0539e-04, -1.9261e-02, 3.1117e-03,
4.8112e-03, -7.1367e-03, -8.3227e-03, -8.6805e-03, -1.3699e-02],
[ 9.2634e-03, -4.6123e-04, 8.3162e-05, 4.1494e-03, 5.1273e-03,
-1.6681e-03, 1.4864e-02, -3.0277e-04, 1.7692e-03, -9.4125e-03,
3.1612e-03, -1.3875e-02, -1.1496e-03, 1.5235e-02, -1.0851e-03,
8.6537e-03, 9.0799e-03, -2.8304e-03, 7.5469e-03, -7.0789e-03],
[-5.3225e-03, -8.5825e-03, 5.8378e-03, 1.1406e-03, -3.2881e-03,
7.9792e-03, 1.5086e-02, -8.4687e-03, -1.3789e-02, -1.8288e-02,
1.3957e-03, 2.2536e-03, -3.3606e-03, 3.2420e-03, -9.3056e-03,
1.5156e-02, 6.4650e-03, -1.3013e-03, 1.0178e-02, 1.3913e-02],
[-1.2846e-02, 4.8990e-04, 7.1575e-03, 1.3793e-03, -5.9738e-03,
3.9912e-03, 2.5881e-03, -1.5927e-02, -9.1614e-03, -1.6528e-02,
1.1311e-02, -4.5142e-03, 4.0138e-03, -2.2480e-02, -3.1659e-03,
-1.8865e-02, 1.4666e-02, 1.0195e-02, -9.0439e-03, 3.8127e-03],
[ 1.2961e-02, 7.7461e-03, 1.1123e-02, 8.3137e-03, -1.2393e-02,
-3.1891e-03, 4.7997e-03, 1.3994e-02, 6.5888e-03, 1.6503e-02,
-8.6732e-05, 1.0505e-02, -5.6654e-03, -5.1383e-03, -6.2080e-03,
1.7060e-03, 1.0218e-02, -5.8861e-03, 1.1582e-02, 2.7832e-03],
[ 7.2542e-03, 2.7732e-03, 6.6612e-03, 3.4965e-03, -4.0417e-03,
-7.7164e-03, -1.9723e-02, -1.8491e-03, 9.4730e-03, -1.7118e-04,
1.1873e-02, -7.9715e-03, 1.0753e-03, 1.0804e-02, 1.3700e-02,
1.6791e-02, -1.4478e-02, 1.7035e-02, 9.3034e-04, 6.3716e-03],
[-5.1573e-03, -1.3071e-03, 1.5302e-02, -6.1851e-03, 6.7677e-04,
1.4603e-02, 7.0722e-03, 2.6819e-03, 1.5753e-02, -3.4153e-03,
-2.0578e-03, -8.0549e-03, -1.2534e-02, 7.2134e-03, -1.6519e-02,
7.4687e-04, 4.4496e-03, -2.0857e-03, -1.4627e-02, 7.0603e-03],
[-3.0296e-04, 1.3288e-03, -7.3660e-03, -9.3445e-04, 7.9644e-03,
4.3456e-03, -1.3778e-02, 4.5655e-03, -1.2824e-02, -1.1208e-02,
1.8958e-02, 2.4357e-03, -2.1795e-02, 5.2504e-03, 8.8114e-04,
2.6217e-03, -5.6610e-03, 3.5770e-04, 9.2574e-03, 7.0462e-04],
[ 5.5887e-03, -2.3937e-02, 1.3323e-02, 1.2788e-02, -1.2492e-02,
-8.2638e-03, 3.7073e-04, -6.0787e-03, 4.5608e-03, -2.0334e-02,
-1.5575e-02, -7.6481e-03, 1.2293e-02, -4.0633e-03, -4.2219e-03,
-8.4397e-03, -7.8156e-03, 1.4628e-02, 8.0971e-04, 4.0196e-04],
[-1.3949e-02, 3.2856e-03, 3.0923e-03, -1.2260e-04, -1.1090e-02,
-1.0921e-02, 1.2069e-02, -1.4299e-02, 6.6826e-03, -5.0484e-03,
-1.6159e-02, 5.4746e-03, 4.6988e-03, -2.0970e-02, -1.2481e-03,
-9.0344e-03, -1.4410e-02, -3.9734e-03, 4.2471e-03, 1.3298e-02],
[-2.9720e-03, 1.4633e-02, 3.5870e-03, 2.0081e-03, -8.4944e-04,
-1.0992e-02, 1.3077e-02, -1.0255e-03, 4.1050e-05, -9.8274e-03,
-4.1936e-04, 1.1226e-02, -1.2842e-02, 1.6335e-02, -8.9160e-03,
4.1783e-03, 3.9959e-03, -4.7716e-03, -3.1262e-03, -1.0138e-02],
[ 2.0540e-02, -2.0969e-03, -3.8366e-04, -2.7185e-03, -3.2523e-03,
7.9718e-03, 8.1142e-03, 3.6300e-03, -7.4875e-03, 1.3166e-02,
1.3331e-03, 1.0332e-02, 1.1699e-02, 1.0125e-02, 1.0621e-02,
-1.6513e-02, -3.5310e-03, -1.9825e-02, -8.0858e-03, -3.1577e-03],
[ 9.9199e-03, -1.4978e-03, -2.9754e-03, -1.2934e-02, -1.4795e-02,
-1.6737e-02, -6.4125e-03, -1.3493e-03, -6.9752e-03, 1.2560e-02,
-4.3187e-03, -1.4194e-02, 3.6981e-04, -2.8193e-03, 8.1851e-03,
-1.1203e-02, 1.9439e-02, 1.0795e-02, -8.2543e-03, 2.3599e-03],
[-1.1758e-02, 9.5521e-03, 6.3462e-04, 3.5751e-03, 1.3527e-02,
9.1637e-03, -4.8209e-03, -1.0592e-02, 2.0195e-03, 1.6119e-02,
1.1654e-02, -1.0050e-02, -7.9837e-03, 1.9223e-02, 2.3013e-03,
1.4753e-03, 6.0071e-03, 4.8298e-04, 6.6951e-03, -1.2826e-02],
[-2.0467e-02, -4.2140e-03, 6.8931e-03, -5.0351e-03, -1.0566e-02,
-6.9593e-03, 4.7499e-03, 8.2981e-03, 2.5675e-02, 1.5285e-02,
-9.5692e-03, -4.5086e-03, -5.6572e-03, -1.1610e-03, -1.7814e-03,
1.2086e-02, 9.0804e-03, -9.1113e-03, -1.9917e-02, -1.6530e-02],
[-8.3130e-03, 1.3344e-02, -1.0923e-02, -2.2851e-02, 6.8411e-03,
2.1356e-03, 8.5286e-03, -5.2713e-03, -1.2193e-02, -1.1948e-03,
3.3531e-04, -4.8021e-03, 5.7277e-03, 1.1665e-02, -1.0539e-04,
7.7064e-03, 1.4751e-02, -4.4743e-03, 1.4951e-02, 1.2392e-02],
[-1.3800e-02, 1.3546e-02, -6.8238e-03, 1.4270e-02, -1.2051e-02,
5.9353e-03, 2.4619e-03, 1.5746e-03, 7.2940e-03, -9.0082e-04,
-5.7518e-03, -2.0381e-02, -7.1935e-03, 2.1555e-03, -5.1282e-03,
-1.9255e-02, -7.8506e-03, -1.4166e-02, 3.6891e-03, -6.8586e-03],
[-1.3660e-02, 2.2254e-02, 1.7642e-02, -3.0450e-03, 1.3232e-02,
-1.0315e-02, 1.0142e-02, 3.5027e-04, -5.0814e-04, -2.6084e-03,
2.5116e-02, 5.9374e-03, 5.4630e-03, 1.1566e-02, -8.3270e-03,
-7.3726e-03, 1.1869e-02, -1.8963e-02, 5.7513e-03, -8.7186e-03],
[-2.0966e-03, -1.7567e-03, 1.4226e-03, -6.6999e-03, 4.1837e-03,
-7.0836e-04, 5.6019e-03, -6.0745e-03, 7.3432e-03, 6.0721e-03,
5.1656e-03, -1.8158e-03, -8.8598e-03, 2.4378e-02, 4.7904e-03,
7.9625e-03, 5.9612e-03, -2.0552e-03, -4.8240e-03, -1.4750e-02],
[ 7.3819e-03, -6.9134e-03, 8.1345e-03, -1.4451e-02, 7.6821e-03,
-1.5124e-03, 3.8898e-03, 1.2265e-03, -9.0760e-03, -1.5998e-02,
2.7463e-04, -6.5673e-03, -7.7155e-04, -2.0952e-02, -2.3035e-03,
-1.2220e-03, 1.2617e-02, 5.1720e-03, -1.7524e-02, 5.6977e-03]])
# Set every bias to the constant 0.
for name, param in net.named_parameters():
    if 'bias' in name:
        init.constant_(param, val=0)
        print(name, param.data)
0.net.0.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.])
1.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
2.linear.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
自定义初始化方法
有时候我们需要的初始化方法并没有在init模块中提供。这时,可以实现一个初始化方法,从而能够像使用其他初始化方法那样使用它。在这之前我们先来看看PyTorch是怎么实现这些初始化方法的,例如torch.nn.init.normal_:
def normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with normal samples and return it.

    Args:
        tensor: the tensor to overwrite.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.

    The write happens under ``torch.no_grad()`` so autograd does not record it.
    """
    with torch.no_grad():
        filled = tensor.normal_(mean, std)
    return filled
可以看到这就是一个inplace改变Tensor值的函数,而且这个过程是不记录梯度的。类似的我们来实现一个自定义的初始化方法。在下面的例子里,我们令权重有一半概率初始化为0,有另一半概率初始化为$[-10,-5]$和$[5,10]$两个区间里均匀分布的随机数。
def init_weight_(tensor):
    """Re-initialise ``tensor`` in place; returns None.

    Samples U(-10, 10), then zeroes every entry with magnitude below 5 by
    multiplying with a 0/1 mask. About half the entries become zero and the
    rest lie in [-10, -5] or [5, 10]. Multiplication keeps the sign of zeroed
    negatives (``-0.``). Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        tensor.mul_(tensor.abs().ge(5).float())
# Apply the custom initializer to every weight in the network.
for name, param in net.named_parameters():
    if 'weight' in name:
        init_weight_(param)
        print(name, param.data)
0.net.0.weight tensor([[-0.0000, -6.2573, -9.6670, ..., -6.2055, -0.0000, 0.0000],
[ 0.0000, -6.1450, -0.0000, ..., 9.4352, 8.5185, 0.0000],
[ 0.0000, 0.0000, -0.0000, ..., 0.0000, -0.0000, -5.4052],
...,
[-0.0000, -9.5803, -5.8481, ..., -0.0000, -0.0000, -0.0000],
[-9.8709, 0.0000, -0.0000, ..., 0.0000, 7.7648, -7.7570],
[-8.7503, -9.4466, 8.8122, ..., 5.5594, -0.0000, 9.3528]])
1.weight tensor([[-5.3031, -7.1580, -9.9964, -6.4769, 0.0000, -5.8270, 5.7411, -5.8713,
-0.0000, -0.0000, -0.0000, 5.3456, 0.0000, 0.0000, -0.0000, 6.5026,
0.0000, -0.0000, 0.0000, 0.0000, -6.3099, 0.0000, -0.0000, 0.0000,
-7.5503, -0.0000, 9.9910, -0.0000, -0.0000, 6.3966],
[ 5.5881, 0.0000, -0.0000, -8.1178, 9.6688, -0.0000, 0.0000, 8.6963,
0.0000, 8.0869, 0.0000, 6.0899, 8.3442, -6.8183, -8.3160, -0.0000,
-0.0000, -7.4499, -6.6586, -0.0000, 0.0000, -6.7071, -0.0000, 7.4049,
0.0000, -9.0628, -0.0000, 7.3452, 0.0000, -8.3752],
[-6.3758, 0.0000, 0.0000, 9.1926, 7.5408, 0.0000, -9.9305, 0.0000,
-0.0000, -9.0137, 0.0000, 8.7384, -0.0000, -6.3803, -0.0000, 0.0000,
-8.8573, 8.8939, 0.0000, 6.2515, 5.3137, 7.6621, -0.0000, -8.0161,
-8.4190, 6.1407, -0.0000, -6.5521, 5.4117, -0.0000],
[ 0.0000, -9.7822, 0.0000, -0.0000, 0.0000, -9.9513, 0.0000, -0.0000,
9.5243, 6.7586, -6.7599, -7.2046, 5.7396, 0.0000, -7.4426, 6.5923,
5.4310, 0.0000, 9.5645, -0.0000, 9.5840, -7.3611, -0.0000, 5.2647,
-5.9115, -0.0000, -0.0000, -0.0000, -5.2489, -8.3718],
[ 5.0601, 8.8166, 0.0000, 6.9072, -8.9312, 0.0000, 7.9705, -9.0079,
6.0310, -9.9306, 0.0000, 5.2559, 0.0000, -0.0000, 0.0000, 6.6453,
-8.2131, -0.0000, -9.7696, -0.0000, -6.1296, -6.1960, 0.0000, -7.1193,
6.0872, 8.9191, -7.5285, 8.8542, 0.0000, -0.0000],
[ 5.7640, -8.2749, 0.0000, -0.0000, -0.0000, 8.4516, -0.0000, 8.1646,
-7.5594, -0.0000, -7.6956, 0.0000, 0.0000, 8.4151, 0.0000, -0.0000,
-9.5694, -6.6565, 5.6100, -7.5257, -0.0000, 7.6025, 6.7463, -9.1704,
-0.0000, 0.0000, 8.9362, 7.7404, -9.6161, 0.0000],
[-5.9001, 0.0000, -0.0000, -7.5054, 0.0000, -7.8240, 7.8767, 7.6504,
9.0638, 7.9266, 7.0275, -0.0000, -6.6694, 5.2457, 9.6615, 0.0000,
5.6970, -7.9317, -0.0000, 0.0000, 6.3138, -0.0000, -0.0000, 6.8197,
-0.0000, -5.5665, 5.6393, 6.9409, 8.4405, -6.9706],
[ 0.0000, -6.5044, -5.4181, 0.0000, -8.2382, 6.3029, -0.0000, -0.0000,
0.0000, -0.0000, -8.7580, 6.9954, 8.0429, -9.6811, 0.0000, 7.1997,
9.6289, -6.6778, 8.3476, 0.0000, 0.0000, -5.1837, -7.5868, 8.8954,
0.0000, 9.4241, -5.2790, 9.6536, -6.5575, -7.8887],
[ 8.6683, -9.6715, 0.0000, 5.7765, 6.6276, 8.2964, -0.0000, -6.1496,
8.9688, -0.0000, -0.0000, 8.1510, -0.0000, 6.2618, 0.0000, -7.1139,
-9.1306, 8.7948, -8.2000, 0.0000, -7.7904, -0.0000, 8.5050, -7.8431,
-0.0000, -0.0000, 7.9535, -6.2632, -6.1001, -6.1616],
[-8.8006, -8.2721, -6.8235, 0.0000, -9.7454, 0.0000, -0.0000, 0.0000,
6.3692, 9.8793, 0.0000, 0.0000, -7.2012, 0.0000, 0.0000, 0.0000,
-0.0000, 0.0000, 8.0920, -0.0000, -9.2214, -0.0000, -8.5031, -8.3191,
-0.0000, 7.7950, -0.0000, 0.0000, 0.0000, -7.7035],
[ 7.3713, 6.2752, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 9.8827,
-9.9048, -0.0000, 0.0000, 8.6781, 7.9049, -0.0000, -0.0000, -0.0000,
0.0000, -0.0000, 0.0000, 5.6662, -5.7249, 7.5572, 7.9236, -8.0282,
0.0000, -0.0000, -8.1768, 0.0000, -0.0000, -0.0000],
[-9.1349, 8.1316, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -0.0000,
6.6178, -0.0000, -0.0000, 0.0000, 7.0107, 8.4189, -0.0000, 5.0933,
9.8502, -0.0000, -9.1036, 0.0000, -5.6974, 0.0000, -9.8530, -8.0555,
0.0000, -5.2797, -7.7184, -0.0000, -5.4591, 8.9083],
[-0.0000, 7.5084, 0.0000, -7.2532, 0.0000, -0.0000, 0.0000, 8.1774,
-9.4518, -5.2770, -0.0000, -0.0000, 5.0882, 7.4918, -8.7120, -0.0000,
-7.0417, 0.0000, -0.0000, 5.1748, 0.0000, -0.0000, -0.0000, -9.3591,
0.0000, -8.2938, -8.9266, -0.0000, 0.0000, 0.0000],
[ 9.5399, 0.0000, 7.4996, -0.0000, 9.8394, 0.0000, 0.0000, -0.0000,
0.0000, -0.0000, 0.0000, 0.0000, -5.3790, -9.6425, 9.8764, 0.0000,
7.0963, -8.9114, -0.0000, 9.0510, 5.4095, -5.0019, 0.0000, -7.0848,
-0.0000, 0.0000, 0.0000, -0.0000, 0.0000, -9.8849],
[ 8.1293, 0.0000, -6.7563, 0.0000, 7.6843, 8.4486, 6.6997, 7.9062,
-5.6233, 0.0000, 6.2966, 7.4003, 0.0000, -0.0000, 8.2519, 0.0000,
8.9715, -8.4958, -7.0926, -0.0000, -0.0000, -5.9202, 6.4348, 8.4848,
0.0000, -7.7669, -7.1864, 0.0000, 9.1471, -0.0000],
[ 0.0000, 8.9136, -6.4488, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000,
9.3667, 7.9826, 0.0000, 0.0000, -6.9993, 0.0000, 9.6924, 0.0000,
8.0840, 9.4185, 0.0000, 6.2945, 7.2634, -0.0000, -6.5016, -0.0000,
0.0000, -0.0000, 6.4698, -0.0000, 5.2106, -0.0000],
[-0.0000, 6.1383, -9.1950, -5.4267, -0.0000, -8.8607, 9.7967, 0.0000,
0.0000, 0.0000, 5.7624, -9.8647, 7.0602, -9.4853, -0.0000, 9.4742,
-5.6148, -6.2168, 0.0000, 9.9788, -9.1114, 0.0000, 0.0000, -0.0000,
6.7838, 0.0000, 0.0000, -0.0000, 6.0502, -5.0156],
[ 8.1110, -0.0000, 6.7124, -6.7490, 5.9337, -8.4773, -0.0000, -7.2125,
0.0000, -0.0000, 5.2543, 0.0000, 0.0000, 9.5925, 0.0000, -0.0000,
-0.0000, -0.0000, -8.4826, 0.0000, -0.0000, -6.4074, 0.0000, -0.0000,
7.2494, -8.8358, 7.8122, 8.8839, 7.9619, 7.8244],
[ 0.0000, -6.1305, 0.0000, 6.7742, 5.9301, -0.0000, 0.0000, -9.8249,
7.4770, 0.0000, 6.4683, -0.0000, 0.0000, -0.0000, -7.5854, -0.0000,
-9.8200, -8.4659, -0.0000, -0.0000, 0.0000, 8.8622, 8.4771, 5.8597,
0.0000, 0.0000, 6.5139, 0.0000, -6.5754, -0.0000],
[-0.0000, 8.0858, -0.0000, 0.0000, -0.0000, 9.6130, -0.0000, -5.2527,
-9.6432, -5.0170, 5.0056, 7.9539, 0.0000, -0.0000, -8.1660, -0.0000,
0.0000, -0.0000, 0.0000, -7.4998, 0.0000, -0.0000, -0.0000, 9.5478,
8.9482, -0.0000, 0.0000, -6.8827, 8.9620, -9.7837]])
2.linear.weight tensor([[ 0.0000, -6.9812, 6.5521, -0.0000, -9.5926, -0.0000, -0.0000, -5.0386,
6.5204, 0.0000, 8.5534, 9.0793, 0.0000, -0.0000, -0.0000, -5.9300,
-5.9106, -0.0000, 8.0045, 7.3943],
[-5.1885, -0.0000, -0.0000, 5.7570, 0.0000, -0.0000, -8.7268, 9.2548,
0.0000, -7.7151, -6.8379, -8.5315, -0.0000, 0.0000, 6.4908, 0.0000,
-0.0000, -0.0000, -5.9161, -5.8564],
[ 0.0000, -9.4816, -9.4379, -0.0000, -9.3362, -0.0000, 0.0000, -8.2363,
-0.0000, 8.0405, 8.9864, 0.0000, -9.2516, -6.2904, 5.7813, 5.3719,
-0.0000, -8.8506, 0.0000, -9.7931],
[ 9.4502, -9.2087, 7.3578, -0.0000, -0.0000, -5.6518, 6.7095, 0.0000,
-0.0000, -0.0000, 5.9396, 7.8704, 0.0000, -0.0000, -7.2353, 0.0000,
-0.0000, 7.0156, 0.0000, -5.0929],
[ 0.0000, -9.9551, -0.0000, 6.6445, 7.0596, -0.0000, -0.0000, -7.0890,
-0.0000, -0.0000, 0.0000, 8.3706, 0.0000, -7.1886, 0.0000, -8.4613,
-0.0000, -7.1553, 0.0000, -5.9845],
[-0.0000, 7.2888, -0.0000, -0.0000, 0.0000, -9.2040, 5.5700, -7.6277,
0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -8.3413, 5.5270, 0.0000,
7.4749, -0.0000, -7.7825, -6.2418],
[ 0.0000, -0.0000, -9.5633, 0.0000, 7.8421, 0.0000, -0.0000, -0.0000,
0.0000, 7.5993, 7.9292, 7.3422, 9.5524, -5.4055, -0.0000, 7.3166,
-0.0000, 0.0000, 9.8042, 6.4248],
[ 0.0000, -0.0000, 9.5345, 9.1690, -0.0000, -0.0000, 5.9536, -6.3916,
0.0000, 8.0737, -0.0000, 9.2005, -0.0000, 0.0000, -0.0000, 0.0000,
0.0000, 0.0000, -0.0000, 5.8255],
[-0.0000, -0.0000, 0.0000, 0.0000, -8.6615, 0.0000, -6.7372, -8.7178,
-7.8605, 8.0300, -9.7292, -0.0000, -6.1143, 9.4508, 0.0000, -0.0000,
0.0000, 8.0177, -7.3347, 0.0000],
[ 6.3698, 0.0000, -0.0000, -0.0000, -7.8566, -0.0000, 0.0000, -0.0000,
9.4590, -8.1037, 0.0000, -6.8476, -5.9440, -0.0000, 0.0000, -7.6990,
0.0000, -8.9955, -0.0000, -0.0000],
[-0.0000, 7.5792, 0.0000, -0.0000, 0.0000, -9.8703, 9.7579, -8.2181,
-0.0000, -6.6664, 9.2518, -0.0000, 0.0000, 6.6096, -9.9753, 0.0000,
7.4476, -8.6817, -7.4076, -6.4940],
[-8.1521, 9.0149, 6.9163, -0.0000, 0.0000, -0.0000, 7.4414, -0.0000,
-5.4136, 0.0000, 0.0000, 0.0000, 5.6016, 8.2639, 0.0000, -0.0000,
8.3907, 5.6836, -7.3025, 6.4559],
[-0.0000, -0.0000, -6.3013, 0.0000, 0.0000, 0.0000, 7.3791, -0.0000,
9.9710, 0.0000, -6.0639, 0.0000, -5.8077, 5.1654, 0.0000, 7.9023,
-0.0000, -0.0000, -0.0000, -8.7207],
[ 5.5755, 0.0000, -6.9844, 7.6312, -0.0000, 0.0000, 8.0948, 0.0000,
0.0000, 0.0000, -0.0000, 0.0000, -0.0000, -0.0000, -0.0000, -7.1881,
-9.8673, -0.0000, 6.4959, 5.2270],
[ 0.0000, 0.0000, -9.7369, -8.0855, 0.0000, 7.7631, 0.0000, -6.6331,
-0.0000, 0.0000, -0.0000, -8.9024, -0.0000, 7.6198, -0.0000, 5.1892,
0.0000, 8.1393, 0.0000, 9.0153],
[-0.0000, -7.5369, 7.0477, -0.0000, 0.0000, 0.0000, 0.0000, -0.0000,
-6.6719, -9.6702, -0.0000, 6.2202, 5.0695, -0.0000, 0.0000, -5.9180,
-0.0000, 0.0000, 9.3806, 5.2014],
[-0.0000, 5.1159, -7.1775, -0.0000, -5.8047, 7.2177, -7.7319, -0.0000,
-0.0000, -9.0223, -0.0000, 0.0000, 0.0000, -0.0000, 8.1462, -8.1729,
6.9001, -0.0000, -0.0000, 8.9680],
[-0.0000, 0.0000, -0.0000, 0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
8.0092, -8.0028, 0.0000, -0.0000, 7.1211, -0.0000, 0.0000, 0.0000,
9.3772, -0.0000, 0.0000, -0.0000],
[ 0.0000, 5.9549, 5.3347, -0.0000, 0.0000, 6.8372, -0.0000, 7.0890,
6.8374, 6.8444, 5.3985, 5.1464, -0.0000, -0.0000, 8.7471, 6.0348,
-0.0000, 0.0000, 6.7883, 8.7566],
[ 7.5425, 0.0000, -0.0000, -6.2917, 9.6904, -6.8175, -9.9985, 0.0000,
5.8819, 8.6874, 0.0000, 6.9018, 0.0000, -7.0253, -0.0000, 0.0000,
7.5909, 0.0000, 0.0000, -0.0000]])
参考2.3.2节,我们还可以通过改变这些参数的data来改写模型参数值同时不会影响梯度:
# Writing through .data changes the stored values without autograd tracking it.
for name, param in net.named_parameters():
    if 'bias' in name:
        param.data += 1
        print(name, param.data)
0.net.0.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
1.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1.])
2.linear.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1.])
共享模型参数
在有些情况下,我们希望在多个层之间共享模型参数。4.1.3节提到了如何共享模型参数: Module类的forward函数里多次调用同一个层。此外,如果我们传入Sequential的模块是同一个Module实例的话参数也是共享的,下面来看一个例子:
# One Linear(1, 1) instance placed twice: both positions share one weight.
linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear)
print(net)
# The shared weight is yielded only once here (the output lists just 0.weight).
for name, param in net.named_parameters():
    init.constant_(param, val=3)
    print(name, param.data)
Sequential(
(0): Linear(in_features=1, out_features=1, bias=False)
(1): Linear(in_features=1, out_features=1, bias=False)
)
0.weight tensor([[3.]])
在内存中,这两个线性层其实是同一个对象:
# Both Sequential positions hold the same Linear object, hence the same weight.
print(id(net[0]) == id(net[1]))
print(id(net[0].weight) == id(net[1].weight))
True
True
因为模型参数里包含了梯度,所以在反向传播计算时,这些共享的参数的梯度是累加的:
x = torch.ones(1, 1)
# Output is 3 * 3 * 1 = 9 because the weight 3 is applied twice.
y = net(x).sum()
print(y)
y.backward()
print(net[0].weight.grad) # each use contributes a gradient of 3; the two uses accumulate to 6
tensor(9., grad_fn=<SumBackward0>)
tensor([[6.]])
自定义层
不含模型参数的自定义层
我们先介绍如何定义一个不含模型参数的自定义层。事实上,这和4.1节(模型构造)中介绍的使用Module类构造模型类似。下面的CenteredLayer类通过继承Module类自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward函数里。这个层里不含模型参数。
import torch
from torch import nn
class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the mean of its input from every element."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are forwarded unchanged to nn.Module.
        super().__init__(**kwargs)

    def forward(self, x):
        # x.mean() without a dim argument averages over all elements.
        centre = x.mean()
        return x - centre
# Apply the parameter-free layer to 1..5; the displayed result is the input minus its mean (3).
layer = CenteredLayer()
layer(torch.tensor([1,2,3,4,5],dtype = torch.float))
tensor([-2., -1., 0., 1., 2.])
# Compose a Linear layer with the parameter-free CenteredLayer. CenteredLayer
# must be its own Sequential entry: nn.Linear's third positional parameter is
# `bias`, so nn.Linear(8, 128, CenteredLayer()) would merely switch the bias on
# and never apply the centering.
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
# The output has been centered, so its mean should be ~0 up to float rounding.
y.mean().item()
0.005987975746393204
含模型参数的自定义层
在4.2节(模型参数的访问、初始化和共享)中介绍了Parameter类其实是Tensor的子类,如果一个Tensor是Parameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter,除了像4.2.1节那样直接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。
ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。
class MyDense(nn.Module):
    """Chain of four matrix multiplications whose weights live in an nn.ParameterList.

    Holds three 4x4 weights followed by one 4x1 weight, so an input of shape
    (N, 4) produces an output of shape (N, 1).
    """

    def __init__(self):
        super().__init__()
        # Wrapping the weights in ParameterList registers them as module parameters.
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        )
        # ParameterList supports list-style append after construction.
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        # Multiply by each weight in registration order.
        for weight in self.params:
            x = torch.mm(x, weight)
        return x
# Instantiate the ParameterList-based layer and print its registered parameters.
net = MyDense()
print(net)
MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
而ParameterDict接收一个Parameter实例的字典作为输入,得到一个参数字典,之后就可以按照字典的规则来使用它。例如使用update()新增参数,使用keys()返回所有键,使用items()返回所有键值对等等,可参考官方文档。
class MyDictDense(nn.Module):
    """Layer holding several named weight matrices in an nn.ParameterDict.

    forward() multiplies the input by the weight selected with `choice`
    ('linear1' -> 4x4, 'linear2' -> 4x1, 'linear3' -> 4x2).
    """

    def __init__(self):
        super(MyDictDense, self).__init__()
        # Weights are created in the order linear1, linear2, linear3.
        initial = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(initial)
        # ParameterDict can be extended after construction, like a dict.
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})

    def forward(self, x, choice='linear1'):
        weight = self.params[choice]
        return torch.mm(x, weight)
# Instantiate the ParameterDict-based layer and print its named parameters.
net = MyDictDense()
print(net)
MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
这样就可以根据传入的键值来进行不同的前向传播:
# Feed the same all-ones row vector through each named weight; the key picks the weight.
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
tensor([[ 5.4923, 0.4334, -1.8005, -0.0718]], grad_fn=<MmBackward>)
tensor([[2.9095]], grad_fn=<MmBackward>)
tensor([[ 1.0343, -3.6477]], grad_fn=<MmBackward>)
我们也可以使用自定义层构造模型。它和PyTorch的其他层在使用上很类似。
# Custom layers compose with nn.Sequential like built-in ones: MyDictDense
# (default key 'linear1', 4 -> 4) feeds MyDense (4 -> 1).
net = nn.Sequential(
MyDictDense(),
MyDense(),
)
print(net)
print(net(x))
Sequential(
(0): MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
(1): MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
)
tensor([[-8.2394]], grad_fn=<MmBackward>)
读取与存储
读写Tensor
我们可以直接使用save函数和load函数分别存储和读取Tensor。save使用Python的pickle工具将对象序列化,然后将序列化后的对象保存到磁盘;使用save可以保存各种对象,包括模型、张量和字典等。而load则使用pickle的反序列化(unpickle)功能,将磁盘上的pickle对象文件反序列化到内存中。
下面的例子创建了Tensor变量x,并将其存在文件名同为x.pt的文件里。
import torch
from torch import nn
# Create a tensor and serialize it to the file x.pt.
x= torch.ones(3)
print(x)
torch.save(x,'x.pt')
tensor([1., 1., 1.])
# Read the tensor back. torch.load unpickles the file, so only load files you trust.
x2=torch.load('x.pt')
x2
tensor([1., 1., 1.])
还可以存储一个Tensor列表并读回内存。
# A list of tensors can be saved and loaded as a single object.
y = torch.zeros(4)
torch.save([x,y],'xy.pt')
xy_list = torch.load('xy.pt')
xy_list
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
存储并读取一个从字符串映射到Tensor的字典。
# A dict mapping strings to tensors round-trips the same way.
torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
读写模型
在PyTorch中,Module的可学习参数(即权重和偏差)包含在模型的参数中(通过model.parameters()访问)。state_dict是一个从参数名称映射到参数Tensor的字典对象。
class MLP(nn.Module):
    """Two-layer perceptron: Linear(3, 2) -> ReLU -> Linear(2, 1)."""

    def __init__(self):
        super().__init__()
        # Submodules are created in this order: hidden, act, output.
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        hidden_out = self.hidden(x)
        activated = self.act(hidden_out)
        return self.output(activated)
# state_dict() maps each parameter name (e.g. 'hidden.weight') to its tensor.
net = MLP()
net.state_dict()
OrderedDict([('hidden.weight', tensor([[ 0.3508, -0.0527, 0.5097],
[ 0.1291, 0.4404, 0.2258]])),
('hidden.bias', tensor([0.0332, 0.4611])),
('output.weight', tensor([[0.1315, 0.0635]])),
('output.bias', tensor([-0.4548]))])
# Optimizers also expose state_dict(): hyperparameters per param group plus internal state.
# In the displayed output, the 'params' entries are integer identifiers rather than tensors.
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
{'param_groups': [{'dampening': 0,
'lr': 0.001,
'momentum': 0.9,
'nesterov': False,
'params': [2501567222120, 2501567222480, 2501566216088, 2501565231320],
'weight_decay': 0}],
'state': {}}
仅保存和加载模型参数(state_dict)
# Template only: `model`, `PATH`, `TheModelClass`, `args` and `kwargs` are placeholders
# that are not defined here, so running this cell as-is raises NameError.
torch.save(model.state_dict(),PATH) # recommended file extension: .pt or .pth
# Load: rebuild the model, then copy the saved parameters into it.
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-55-60ddd59ce579> in <module>()
----> 1 torch.save(model.state_dict(),PATH) # 推荐的文件后缀名是pt或pth
2
3 #加载
4 model = TheModelClass(*args, **kwargs)
5 model.load_state_dict(torch.load(PATH))
NameError: name 'model' is not defined
保存和加载整个模型
# Template only (placeholders `model` and `PATH`): pickle the entire model object.
torch.save(model, PATH)
# Load: returns the unpickled model object; only load files you trust.
model = torch.load(PATH)
# Round-trip check: save the MLP's parameters, load them into a fresh MLP,
# and compare both networks' outputs on the same input element-wise.
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
tensor([[True],
[True]])
模型的GPU计算
用torch.cuda.is_available()查看GPU是否可用
# Whether a CUDA GPU can be used; this run displayed False.
torch.cuda.is_available()
False
查看GPU数量:
# Number of visible CUDA devices; this run displayed 0.
torch.cuda.device_count()
0
查看当前GPU索引号,索引号从0开始:
torch.cuda.current_device() # outputs 0 on the original author's GPU machine
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-59-55c583da45aa> in <module>()
----> 1 torch.cuda.get_device_name(0)
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_name(device)
291 if :attr:`device` is ``None`` (default).
292 """
--> 293 return get_device_properties(device).name
294
295
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_properties(device)
313 def get_device_properties(device):
314 if not _initialized:
--> 315 init() # will define _get_device_properties and _CudaDeviceProperties
316 device = _get_device_index(device, optional=True)
317 if device < 0 or device >= device_count():
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in init()
159 Does nothing if the CUDA state is already initialized.
160 """
--> 161 _lazy_init()
162
163
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _lazy_init()
176 raise RuntimeError(
177 "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178 _check_driver()
179 torch._C._cuda_init()
180 _cudart = _load_cudart()
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _check_driver()
90 def _check_driver():
91 if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92 raise AssertionError("Torch not compiled with CUDA enabled")
93 if not torch._C._cuda_isDriverSufficient():
94 if torch._C._cuda_getDriverVersion() == 0:
AssertionError: Torch not compiled with CUDA enabled
根据索引号查看GPU名字:
# On a build without CUDA this raises AssertionError: Torch not compiled with CUDA enabled.
torch.cuda.get_device_name(0) # outputs 'GeForce GTX 1050' on the original author's machine
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-60-45381b0646d2> in <module>()
----> 1 torch.cuda.get_device_name(0) # 输出 'GeForce GTX 1050'
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_name(device)
291 if :attr:`device` is ``None`` (default).
292 """
--> 293 return get_device_properties(device).name
294
295
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_properties(device)
313 def get_device_properties(device):
314 if not _initialized:
--> 315 init() # will define _get_device_properties and _CudaDeviceProperties
316 device = _get_device_index(device, optional=True)
317 if device < 0 or device >= device_count():
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in init()
159 Does nothing if the CUDA state is already initialized.
160 """
--> 161 _lazy_init()
162
163
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _lazy_init()
176 raise RuntimeError(
177 "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178 _check_driver()
179 torch._C._cuda_init()
180 _cudart = _load_cudart()
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _check_driver()
90 def _check_driver():
91 if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92 raise AssertionError("Torch not compiled with CUDA enabled")
93 if not torch._C._cuda_isDriverSufficient():
94 if torch._C._cuda_getDriverVersion() == 0:
AssertionError: Torch not compiled with CUDA enabled
同Tensor类似,PyTorch模型也可以通过.cuda转换到GPU上。我们可以通过检查模型的参数的device属性来查看存放模型的设备。
# A freshly built model's parameters live on the device reported here (CPU in this run).
net = nn.Linear(3, 1)
list(net.parameters())[0].device
device(type='cpu')
可见模型在CPU上,将其转换到GPU上:
# Move the model's parameters to the GPU. On this run's build without CUDA, this
# raises AssertionError: Torch not compiled with CUDA enabled.
net.cuda()
list(net.parameters())[0].device
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-62-a9404ab4efc6> in <module>()
----> 1 net.cuda()
2 list(net.parameters())[0].device
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)
309 Module: self
310 """
--> 311 return self._apply(lambda t: t.cuda(device))
312
313 def cpu(self):
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)
228 # `with torch.no_grad():`
229 with torch.no_grad():
--> 230 param_applied = fn(param)
231 should_use_set_data = compute_should_use_set_data(param, param_applied)
232 if should_use_set_data:
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)
309 Module: self
310 """
--> 311 return self._apply(lambda t: t.cuda(device))
312
313 def cpu(self):
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _lazy_init()
176 raise RuntimeError(
177 "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178 _check_driver()
179 torch._C._cuda_init()
180 _cudart = _load_cudart()
D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _check_driver()
90 def _check_driver():
91 if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92 raise AssertionError("Torch not compiled with CUDA enabled")
93 if not torch._C._cuda_isDriverSufficient():
94 if torch._C._cuda_getDriverVersion() == 0:
AssertionError: Torch not compiled with CUDA enabled